diff --git a/fairseq/__init__.py b/fairseq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..080c988b2da326c2fe356630d5641d367b37a546 --- /dev/null +++ b/fairseq/__init__.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import os +import sys + +try: + from .version import __version__ # noqa +except ImportError: + version_txt = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_txt) as f: + __version__ = f.read().strip() + +__all__ = ["pdb"] + +# backwards compatibility to support `from fairseq.X import Y` +from fairseq.distributed import utils as distributed_utils +from fairseq.logging import meters, metrics, progress_bar # noqa + +sys.modules["fairseq.distributed_utils"] = distributed_utils +sys.modules["fairseq.meters"] = meters +sys.modules["fairseq.metrics"] = metrics +sys.modules["fairseq.progress_bar"] = progress_bar + +# initialize hydra +from fairseq.dataclass.initialize import hydra_init + +hydra_init() + +import fairseq.criterions # noqa +import fairseq.distributed # noqa +import fairseq.models # noqa +import fairseq.modules # noqa +import fairseq.optim # noqa +import fairseq.optim.lr_scheduler # noqa +import fairseq.pdb # noqa +import fairseq.scoring # noqa +import fairseq.tasks # noqa +import fairseq.token_generation_constraints # noqa + +import fairseq.benchmark # noqa +import fairseq.model_parallel # noqa diff --git a/fairseq/__pycache__/__init__.cpython-310.pyc b/fairseq/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc300748279c892437ebe72564a22e3a249b9a1d Binary files /dev/null and b/fairseq/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/__pycache__/__init__.cpython-311.pyc b/fairseq/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a93408f4f267e2d92c3862f166c94fc91975c8a Binary files /dev/null and b/fairseq/__pycache__/__init__.cpython-311.pyc differ diff --git a/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc b/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f04776f79434ad04f562b21a7de07154d493abd0 Binary files /dev/null and b/fairseq/__pycache__/checkpoint_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/file_chunker_utils.cpython-310.pyc b/fairseq/__pycache__/file_chunker_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6335ed6abf8110680ab7d6a97bff83ecaadba241 Binary files /dev/null and b/fairseq/__pycache__/file_chunker_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/file_io.cpython-310.pyc b/fairseq/__pycache__/file_io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daabf488516403913095a816c0cf4569235ec717 Binary files /dev/null and b/fairseq/__pycache__/file_io.cpython-310.pyc differ diff --git a/fairseq/__pycache__/file_utils.cpython-310.pyc b/fairseq/__pycache__/file_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3e58538a3660e3cc37b0413402ad5cfc6817ad Binary files /dev/null and b/fairseq/__pycache__/file_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/hub_utils.cpython-310.pyc b/fairseq/__pycache__/hub_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a528b1bf60e5b71f3c1b0befde2adc5b8d36aa4 Binary files /dev/null and b/fairseq/__pycache__/hub_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc b/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5271d9f907028b2ec47a9f3eef4888fe92576472 Binary files /dev/null and b/fairseq/__pycache__/incremental_decoding_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc b/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f5c530f4144676f829f7fb7ffa180d75bdeec08 Binary files /dev/null and b/fairseq/__pycache__/iterative_refinement_generator.cpython-310.pyc differ diff --git a/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc b/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8defc5d4f226acdb6b0785d5ad60abd093ec2d95 Binary files /dev/null and b/fairseq/__pycache__/ngram_repeat_block.cpython-310.pyc differ diff --git a/fairseq/__pycache__/options.cpython-310.pyc b/fairseq/__pycache__/options.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb43ec2a5cb571d16c6ff18ffd9be53f1b8cac6e Binary files /dev/null and b/fairseq/__pycache__/options.cpython-310.pyc differ diff --git a/fairseq/__pycache__/pdb.cpython-310.pyc b/fairseq/__pycache__/pdb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c46de80e1afc4508e5aa48d8a9711995063c38c Binary files /dev/null and b/fairseq/__pycache__/pdb.cpython-310.pyc differ diff --git a/fairseq/__pycache__/quantization_utils.cpython-310.pyc b/fairseq/__pycache__/quantization_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d33fa9b3884c55b5e2772fd86acb6322cb2f49bb Binary files /dev/null and b/fairseq/__pycache__/quantization_utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/registry.cpython-310.pyc b/fairseq/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11e20c29982aba89dabceafdb8ad9caa05521dbc Binary files /dev/null and b/fairseq/__pycache__/registry.cpython-310.pyc differ diff --git a/fairseq/__pycache__/search.cpython-310.pyc b/fairseq/__pycache__/search.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fbf26c0ba4b7d42132784c40a46f0c5e9ae6a23 Binary files /dev/null and b/fairseq/__pycache__/search.cpython-310.pyc differ diff --git a/fairseq/__pycache__/sequence_generator.cpython-310.pyc b/fairseq/__pycache__/sequence_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90320b6657fa62e1cb78e7be5f8147e9ebd36548 Binary files /dev/null and b/fairseq/__pycache__/sequence_generator.cpython-310.pyc differ diff --git a/fairseq/__pycache__/speech_generator.cpython-310.pyc b/fairseq/__pycache__/speech_generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35e6b6e2171be17d8cd2a91ee5ae066f0031bb20 Binary files /dev/null and b/fairseq/__pycache__/speech_generator.cpython-310.pyc differ diff --git a/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc b/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa11dd07fda0da92e575f4cf8eb3bb4668089255 Binary files /dev/null and b/fairseq/__pycache__/token_generation_constraints.cpython-310.pyc differ diff --git a/fairseq/__pycache__/tokenizer.cpython-310.pyc b/fairseq/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77cc4e59e8ba2db1a9344db7b9ca2761b421c8ab Binary files /dev/null and b/fairseq/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/fairseq/__pycache__/utils.cpython-310.pyc b/fairseq/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d3799f830a33a363d1037019acb15f2f24eb6c2 Binary files /dev/null and b/fairseq/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/__pycache__/version.cpython-310.pyc b/fairseq/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5fe9d10594fb2063df6003affda0812f4d21bc Binary files /dev/null and b/fairseq/__pycache__/version.cpython-310.pyc differ diff --git a/fairseq/__pycache__/version.cpython-311.pyc b/fairseq/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39ee51391bc3040c0717420f71cbb0805c023db Binary files /dev/null and b/fairseq/__pycache__/version.cpython-311.pyc differ diff --git a/fairseq/benchmark/__init__.py b/fairseq/benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0317d5c623778fe40b7bf07b77769cd10c243244 --- /dev/null +++ b/fairseq/benchmark/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# import models/tasks to register them +from . import dummy_dataset, dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa diff --git a/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc b/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779db67dab5c1409fa5c68a7cf7f2c996df6c829 Binary files /dev/null and b/fairseq/benchmark/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc b/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50328f267a3db8885afc62597f143d36db867bb6 Binary files /dev/null and b/fairseq/benchmark/__pycache__/dummy_dataset.cpython-310.pyc differ diff --git a/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc b/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c16576cd107788a4aa715d1840197bf017ff76fa Binary files /dev/null and b/fairseq/benchmark/__pycache__/dummy_lm.cpython-310.pyc differ diff --git a/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc b/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b88fbdd96ad1251190fca39750f990b5114b2a Binary files /dev/null and b/fairseq/benchmark/__pycache__/dummy_masked_lm.cpython-310.pyc differ diff --git a/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc b/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eabfd933f60282c334ed1f4c7968a0bd4a897d58 Binary files /dev/null and b/fairseq/benchmark/__pycache__/dummy_model.cpython-310.pyc differ diff --git a/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc b/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c74ef0cbe2f1c2ea396625cda86b8c625e7e93d7 Binary files /dev/null and b/fairseq/benchmark/__pycache__/dummy_mt.cpython-310.pyc differ diff --git a/fairseq/benchmark/benchmark_multihead_attention.py b/fairseq/benchmark/benchmark_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a44847f25031ff2e4490ca47d560167af786f64d --- /dev/null +++ b/fairseq/benchmark/benchmark_multihead_attention.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import random + +import torch +from torch.utils import benchmark + +from fairseq.modules.multihead_attention import MultiheadAttention + +BATCH = [20, 41, 97] +SEQ = 64 +EMB = 48 +HEADS = 4 +DROP = 0.1 +DEVICE = torch.device("cuda") +ATTN_MASK_DTYPE = [torch.uint8, torch.bool, torch.float] +KEY_PADDING_MASK_DTYPE = [torch.uint8, torch.bool] + + +def _reset_seeds(): + torch.manual_seed(0) + random.seed(0) + + +def _get_mask(to_dtype: torch.dtype, dim0: int, dim1: int): + if to_dtype == torch.float: + mask = torch.randint(0, 2, (dim0, dim1)).to(dtype=torch.bool) + return mask.to(dtype=to_dtype).masked_fill(mask, -float("inf")) + return torch.randint(0, 2, (dim0, dim1)).to(dtype=to_dtype) + + +def benchmark_multihead_attention( + label="", + attn_dtype=torch.uint8, + key_padding_dtype=torch.uint8, + add_bias_kv=False, + add_zero_attn=False, + static_kv=False, + batch_size=20, + embedding=EMB, + seq_len=SEQ, + num_heads=HEADS, +): + + results = [] + # device = torch.device("cuda") + + xformers_att_config = '{"name": "scaled_dot_product"}' + + attn_mask = _get_mask(to_dtype=attn_dtype, dim0=seq_len, dim1=seq_len) + key_padding_mask = _get_mask( + to_dtype=key_padding_dtype, dim0=batch_size, dim1=seq_len + ) + + q = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + k = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + v = torch.rand(seq_len, batch_size, embedding, requires_grad=True) + + _reset_seeds() + + original_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=None, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + xformers_mha = MultiheadAttention( + embedding, + num_heads, + dropout=0.0, + xformers_att_config=xformers_att_config, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + def original_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def xformers_bench_fw(q, k, v, key_padding_mask, attn_mask, static_kv): + xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + + def original_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = original_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + loss = torch.norm(output) + loss.backward() + + def xformers_bench_fw_bw(q, k, v, key_padding_mask, attn_mask, static_kv): + output, _ = xformers_mha( + query=q, + key=k, + value=v, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + static_kv=static_kv, + ) + loss = torch.norm(output) + loss.backward() + + fns = [ + original_bench_fw, + xformers_bench_fw, + original_bench_fw_bw, + xformers_bench_fw_bw, + ] + + for fn in fns: + results.append( + benchmark.Timer( + stmt="fn(q, k, v, key_padding_mask, attn_mask, static_kv)", + globals={ + "q": q, + "k": k, + "v": v, + "key_padding_mask": key_padding_mask, + "attn_mask": attn_mask, + "static_kv": static_kv, + "fn": fn, + }, + label="multihead fw + bw", + sub_label=f"{fn.__name__}", + description=label, + ).blocked_autorange(min_run_time=1) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def run_benchmarks(): + for attn_dtype, key_padding_dtype, add_bias_kv, add_zero_attn in itertools.product( + ATTN_MASK_DTYPE, KEY_PADDING_MASK_DTYPE, [True, False], [True, False] + ): + label = f"attn_dtype {attn_dtype}, key_padding_dtype {key_padding_dtype}, \ + add_bias_kv {add_bias_kv}, add_zero_attn {add_zero_attn}" + benchmark_multihead_attention( + label=label, + attn_dtype=attn_dtype, + key_padding_dtype=key_padding_dtype, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + +run_benchmarks() diff --git a/fairseq/benchmark/dummy_dataset.py b/fairseq/benchmark/dummy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2f051754af55966e26850e94c121e0ff439bfd28 --- /dev/null +++ b/fairseq/benchmark/dummy_dataset.py @@ -0,0 +1,36 @@ +import numpy as np +from fairseq.data import FairseqDataset + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..c6246a0c0e338fa36244b3aa4fb57f189fbffcb6 --- /dev/null +++ b/fairseq/benchmark/dummy_lm.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Optional + +import torch +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class DummyLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + tokens_per_sample: int = field( + default=512, metadata={"help": "max sequence length"} + ) + add_bos_token: bool = False + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + + +@register_task("dummy_lm", dataclass=DummyLMConfig) +class DummyLMTask(FairseqTask): + def __init__(self, cfg: DummyLMConfig): + super().__init__(cfg) + + # load dictionary + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + logger.info("dictionary: {} types".format(len(self.dictionary))) + + seq = torch.arange(cfg.tokens_per_sample + 1) + self.dictionary.pad() + 1 + + self.dummy_src = seq[:-1] + self.dummy_tgt = seq[1:] + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size + else: + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.cfg.tokens_per_sample, + }, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..12b9c5d0f55993bf8750564882a351fc3f8055f0 --- /dev/null +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Optional + +import torch +from omegaconf import II + +from .dummy_dataset import DummyDataset +from fairseq.data import Dictionary +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@dataclass +class DummyMaskedLMConfig(FairseqDataclass): + dict_size: int = 49996 + dataset_size: int = 100000 + tokens_per_sample: int = field( + default=512, + metadata={ + "help": "max number of total tokens over all" + " segments per sample for BERT dataset" + }, + ) + batch_size: Optional[int] = II("dataset.batch_size") + max_tokens: Optional[int] = II("dataset.max_tokens") + max_target_positions: int = II("task.tokens_per_sample") + + +@register_task("dummy_masked_lm", dataclass=DummyMaskedLMConfig) +class DummyMaskedLMTask(FairseqTask): + def __init__(self, cfg: DummyMaskedLMConfig): + super().__init__(cfg) + + self.dictionary = Dictionary() + for i in range(cfg.dict_size): + self.dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(self.dictionary))) + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + self.dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + mask_idx = 0 + pad_idx = 1 + seq = torch.arange(cfg.tokens_per_sample) + pad_idx + 1 + mask = torch.arange(2, cfg.tokens_per_sample, 7) # ~15% + src = seq.clone() + src[mask] = mask_idx + tgt = torch.full_like(seq, pad_idx) + tgt[mask] = seq[mask] + + self.dummy_src = src + self.dummy_tgt = tgt + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if self.cfg.batch_size is not None: + bsz = self.cfg.batch_size + else: + bsz = max(1, self.cfg.max_tokens // self.cfg.tokens_per_sample) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.cfg.tokens_per_sample, dtype=torch.long + ), + }, + "target": torch.stack([self.dummy_tgt for _ in range(bsz)]), + "nsentences": bsz, + "ntokens": bsz * self.cfg.tokens_per_sample, + }, + num_items=self.cfg.dataset_size, + item_size=self.cfg.tokens_per_sample, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/benchmark/dummy_model.py b/fairseq/benchmark/dummy_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff26e4fe655d8e8d7f9942c4bd3df7cd267405fb --- /dev/null +++ b/fairseq/benchmark/dummy_model.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch.nn.functional as F +from fairseq.data import Dictionary +from fairseq.models import ( + FairseqDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + + +@register_model("dummy_model") +class DummyModel(FairseqLanguageModel): + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + + @staticmethod + def add_args(parser): + parser.add_argument("--num-layers", type=int, default=24) + parser.add_argument("--embed-dim", type=int, default=1024) + + @classmethod + def build_model(cls, args, task): + encoder = DummyEncoder( + num_embed=len(task.target_dictionary), + embed_dim=args.embed_dim, + num_layers=args.num_layers, + ) + return cls(args, encoder) + + def forward(self, src_tokens, masked_tokens=None, **kwargs): + return self.decoder(src_tokens, masked_tokens=masked_tokens) + + +class DummyEncoder(FairseqDecoder): + def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24): + super().__init__(Dictionary()) + self.embed = nn.Embedding( + num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0 + ) + self.layers_a = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection + nn.Linear(3 * embed_dim, embed_dim), # skip self-attention + nn.Linear(embed_dim, embed_dim), # output projection + nn.Dropout(), + ) + for i in range(num_layers) + ] + ) + self.layers_b = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 4 * embed_dim), # FFN + nn.ReLU(), + nn.Linear(4 * embed_dim, embed_dim), # FFN + nn.Dropout(0.1), + ) + for i in range(num_layers) + ] + ) + self.out_proj = nn.Linear(embed_dim, num_embed) + + def forward(self, tokens, masked_tokens=None): + x = self.embed(tokens) + for layer_a, layer_b in zip(self.layers_a, self.layers_b): + x = x + layer_a(x) + x = x + layer_b(x) + x = self.out_proj(x) + if masked_tokens is not None: + x = x[masked_tokens] + return (x,) + + def max_positions(self): + return 1024 + + def get_normalized_probs(self, net_output, log_probs, sample=None): + logits = net_output[0].float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + + +@register_model_architecture("dummy_model", "dummy_model") +def base_architecture(args): + pass diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py new file mode 100644 index 0000000000000000000000000000000000000000..28d78cffdbf8c2bcee69b454a79891cb34def200 --- /dev/null +++ b/fairseq/benchmark/dummy_mt.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch + +from fairseq.data import Dictionary, FairseqDataset +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("dummy_mt") +class DummyMTTask(LegacyFairseqTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("--dict-size", default=49996, type=int) + parser.add_argument("--dataset-size", default=100000, type=int) + parser.add_argument("--src-len", default=30, type=int) + parser.add_argument("--tgt-len", default=30, type=int) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + dictionary.pad_to_multiple_(8) # often faster if divisible by 8 + + self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1 + self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1 + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + dictionary = Dictionary() + for i in range(args.dict_size): + dictionary.add_symbol("word{}".format(i)) + logger.info("dictionary: {} types".format(len(dictionary))) + + args.max_source_positions = args.src_len + dictionary.pad() + 2 + args.max_target_positions = args.tgt_len + dictionary.pad() + 2 + + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the split (e.g., train, valid, test) + """ + item_size = max(self.args.src_len, self.args.tgt_len) + if self.args.batch_size is not None: + bsz = self.args.batch_size + else: + bsz = max(1, self.args.max_tokens // item_size) + tgt = torch.stack([self.dummy_tgt for _ in range(bsz)]) + self.datasets[split] = DummyDataset( + { + "id": 1, + "net_input": { + "src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]), + "src_lengths": torch.full( + (bsz,), self.args.src_len, dtype=torch.long + ), + "prev_output_tokens": tgt.clone(), + }, + "target": tgt, + "nsentences": bsz, + "ntokens": bsz * self.args.tgt_len, + }, + num_items=self.args.dataset_size, + item_size=item_size, + ) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class DummyDataset(FairseqDataset): + def __init__(self, batch, num_items, item_size): + super().__init__() + self.batch = batch + self.num_items = num_items + self.item_size = item_size + + def __getitem__(self, index): + return index + + def __len__(self): + return self.num_items + + def collater(self, samples): + return self.batch + + @property + def sizes(self): + return np.array([self.item_size] * self.num_items) + + def num_tokens(self, index): + return self.item_size + + def size(self, index): + return self.item_size + + def ordered_indices(self): + return np.arange(self.num_items) + + @property + def supports_prefetch(self): + return False diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6f03d7a2cbb16db6aa218713211c1323adbc7d45 --- /dev/null +++ b/fairseq/binarizer.py @@ -0,0 +1,381 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import typing as tp +from abc import ABC, abstractmethod +from collections import Counter +from dataclasses import dataclass +from multiprocessing import Pool + +import torch + +from fairseq.data import Dictionary, indexed_dataset +from fairseq.file_chunker_utils import Chunker, find_offsets +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + +logger = logging.getLogger("binarizer") + + +@dataclass +class BinarizeSummary: + """ + Keep track of what's going on in the binarizer + """ + + num_seq: int = 0 + replaced: tp.Optional[Counter] = None + num_tok: int = 0 + + @property + def num_replaced(self) -> int: + if self.replaced is None: + return 0 + return sum(self.replaced.values()) + + @property + def replaced_percent(self) -> float: + return 100 * self.num_replaced / self.num_tok + + def __str__(self) -> str: + base = f"{self.num_seq} sents, {self.num_tok} tokens" + if self.replaced is None: + return base + + return f"{base}, {self.replaced_percent:.3}% replaced" + + def merge(self, other: "BinarizeSummary"): + replaced = None + if self.replaced is not None: + replaced = self.replaced + if other.replaced is not None: + if replaced is None: + replaced = other.replaced + else: + replaced += other.replaced + self.replaced = replaced + self.num_seq += other.num_seq + self.num_tok += other.num_tok + + +class Binarizer(ABC): + """ + a binarizer describes how to take a string and build a tensor out of it + """ + + @abstractmethod + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ) -> torch.IntTensor: + ... + + +def _worker_prefix(output_prefix: str, worker_id: int): + return f"{output_prefix}.pt{worker_id}" + + +class FileBinarizer: + """ + An file binarizer can take a file, tokenize it, and binarize each line to a tensor + """ + + @classmethod + def multiprocess_dataset( + cls, + input_file: str, + dataset_impl: str, + binarizer: Binarizer, + output_prefix: str, + vocab_size=None, + num_workers=1, + ) -> BinarizeSummary: + final_summary = BinarizeSummary() + + offsets = find_offsets(input_file, num_workers) + # find_offsets returns a list of position [pos1, pos2, pos3, pos4] but we would want pairs: + # [(pos1, pos2), (pos2, pos3), (pos3, pos4)] to process the chunks with start/end info + # we zip the list with itself shifted by one to get all the pairs. + (first_chunk, *more_chunks) = zip(offsets, offsets[1:]) + pool = None + if num_workers > 1: + pool = Pool(processes=num_workers - 1) + worker_results = [ + pool.apply_async( + cls._binarize_chunk_and_finalize, + args=( + binarizer, + input_file, + start_offset, + end_offset, + _worker_prefix( + output_prefix, + worker_id, + ), + dataset_impl, + ), + kwds={ + "vocab_size": vocab_size, + } + if vocab_size is not None + else {}, + ) + for worker_id, (start_offset, end_offset) in enumerate( + more_chunks, start=1 + ) + ] + + pool.close() + pool.join() + for r in worker_results: + summ = r.get() + final_summary.merge(summ) + + # do not close the bin file as we need to merge the worker results in + final_ds, summ = cls._binarize_file_chunk( + binarizer, + input_file, + offset_start=first_chunk[0], + offset_end=first_chunk[1], + output_prefix=output_prefix, + dataset_impl=dataset_impl, + vocab_size=vocab_size if vocab_size is not None else None, + ) + final_summary.merge(summ) + + if num_workers > 1: + for worker_id in range(1, num_workers): + # merge the worker outputs + worker_output_prefix = _worker_prefix( + output_prefix, + worker_id, + ) + final_ds.merge_file_(worker_output_prefix) + try: + os.remove(indexed_dataset.data_file_path(worker_output_prefix)) + os.remove(indexed_dataset.index_file_path(worker_output_prefix)) + except Exception as e: + logger.error( + f"couldn't remove {worker_output_prefix}.*", exc_info=e + ) + + # now we can close the file + idx_file = indexed_dataset.index_file_path(output_prefix) + final_ds.finalize(idx_file) + return final_summary + + @staticmethod + def _binarize_file_chunk( + binarizer: Binarizer, + filename: str, + offset_start: int, + offset_end: int, + output_prefix: str, + dataset_impl: str, + vocab_size=None, + ) -> tp.Tuple[tp.Any, BinarizeSummary]: # (dataset builder, BinarizeSummary) + """ + creates a dataset builder and append binarized items to it. This function does not + finalize the builder, this is useful if you want to do other things with your bin file + like appending/merging other files + """ + bin_file = indexed_dataset.data_file_path(output_prefix) + ds = indexed_dataset.make_builder( + bin_file, + impl=dataset_impl, + vocab_size=vocab_size, + ) + summary = BinarizeSummary() + + with Chunker( + PathManager.get_local_path(filename), offset_start, offset_end + ) as line_iterator: + for line in line_iterator: + ds.add_item(binarizer.binarize_line(line, summary)) + + return ds, summary + + @classmethod + def _binarize_chunk_and_finalize( + cls, + binarizer: Binarizer, + filename: str, + offset_start: int, + offset_end: int, + output_prefix: str, + dataset_impl: str, + vocab_size=None, + ): + """ + same as above, but also finalizes the builder + """ + ds, summ = cls._binarize_file_chunk( + binarizer, + filename, + offset_start, + offset_end, + output_prefix, + dataset_impl, + vocab_size=vocab_size, + ) + + idx_file = indexed_dataset.index_file_path(output_prefix) + ds.finalize(idx_file) + + return summ + + +class VocabularyDatasetBinarizer(Binarizer): + """ + Takes a Dictionary/Vocabulary, assign ids to each + token using the dictionary encode_line function. + """ + + def __init__( + self, + dict: Dictionary, + tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line, + append_eos: bool = True, + reverse_order: bool = False, + already_numberized: bool = False, + ) -> None: + self.dict = dict + self.tokenize = tokenize + self.append_eos = append_eos + self.reverse_order = reverse_order + self.already_numberized = already_numberized + super().__init__() + + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ): + if summary.replaced is None: + summary.replaced = Counter() + + def replaced_consumer(word, idx): + if idx == self.dict.unk_index and word != self.dict.unk_word: + summary.replaced.update([word]) + + if self.already_numberized: + id_strings = line.strip().split() + id_list = [int(id_string) for id_string in id_strings] + if self.reverse_order: + id_list.reverse() + if self.append_eos: + id_list.append(self.dict.eos()) + ids = torch.IntTensor(id_list) + else: + ids = self.dict.encode_line( + line=line, + line_tokenizer=self.tokenize, + add_if_not_exist=False, + consumer=replaced_consumer, + append_eos=self.append_eos, + reverse_order=self.reverse_order, + ) + + summary.num_seq += 1 + summary.num_tok += len(ids) + return ids + + +class AlignmentDatasetBinarizer(Binarizer): + """ + binarize by parsing a set of alignments and packing + them in a tensor (see utils.parse_alignment) + """ + + def __init__( + self, + alignment_parser: tp.Callable[[str], torch.IntTensor], + ) -> None: + super().__init__() + self.alignment_parser = alignment_parser + + def binarize_line( + self, + line: str, + summary: BinarizeSummary, + ): + ids = self.alignment_parser(line) + summary.num_seq += 1 + summary.num_tok += len(ids) + return ids + + +class LegacyBinarizer: + @classmethod + def binarize( + cls, + filename: str, + dico: Dictionary, + consumer: tp.Callable[[torch.IntTensor], None], + tokenize: tp.Callable[[str], tp.List[str]] = tokenize_line, + append_eos: bool = True, + reverse_order: bool = False, + offset: int = 0, + end: int = -1, + already_numberized: bool = False, + ) -> tp.Dict[str, int]: + binarizer = VocabularyDatasetBinarizer( + dict=dico, + tokenize=tokenize, + append_eos=append_eos, + reverse_order=reverse_order, + already_numberized=already_numberized, + ) + return cls._consume_file( + filename, + binarizer, + consumer, + offset_start=offset, + offset_end=end, + ) + + @classmethod + def binarize_alignments( + cls, + filename: str, + alignment_parser: tp.Callable[[str], torch.IntTensor], + consumer: tp.Callable[[torch.IntTensor], None], + offset: int = 0, + end: int = -1, + ) -> tp.Dict[str, int]: + binarizer = AlignmentDatasetBinarizer(alignment_parser) + return cls._consume_file( + filename, + binarizer, + consumer, + offset_start=offset, + offset_end=end, + ) + + @staticmethod + def _consume_file( + filename: str, + binarizer: Binarizer, + consumer: tp.Callable[[torch.IntTensor], None], + offset_start: int, + offset_end: int, + ) -> tp.Dict[str, int]: + summary = BinarizeSummary() + + with Chunker( + PathManager.get_local_path(filename), offset_start, offset_end + ) as line_iterator: + for line in line_iterator: + consumer(binarizer.binarize_line(line, summary)) + + return { + "nseq": summary.num_seq, + "nunk": summary.num_replaced, + "ntok": summary.num_tok, + "replaced": summary.replaced, + } diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4447f74507763d7b4843b5f52eda8cfe2954ea63 --- /dev/null +++ b/fairseq/checkpoint_utils.py @@ -0,0 +1,936 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import collections +import contextlib +import inspect +import logging +import os +import re +import time +import traceback +from collections import OrderedDict +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from fairseq.data import data_utils +from fairseq.dataclass.configs import CheckpointConfig +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) +from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP +from fairseq.file_io import PathManager +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, OmegaConf, open_dict + +logger = logging.getLogger(__name__) + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return None + + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state + + if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() + return None + + write_timer = meters.StopwatchMeter() + write_timer.start() + + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) + # add random digits to resolve ties + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix + ) + ] = worst_best is None or is_better(val_loss, worst_best) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not cfg.no_last_checkpoints + + extra_state = { + "train_iterator": epoch_itr.state_dict(), + "val_loss": val_loss, + } + + # Going forward, different tasks could expose an API like this to dump all + # the checkpoint worthy attributes in a dictionary which then will be + # merged with the parent dictionary to create the "extra_state". This + # allows for an extensible yet simple design to checkpoint task level + # attributes + if hasattr(trainer.task, "get_checkpoint_dict"): + extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()} + logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint") + + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + saved_cp = None + if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" + + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + if ( + not end_of_epoch + and cfg.keep_interval_updates > 0 + and trainer.should_save_checkpoint_on_current_rank + ): + # remove old checkpoints; checkpoints are sorted in descending order + if cfg.keep_interval_updates_pattern == -1: + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + ) + else: + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + keep_match=True, + ) + checkpoints = [ + x[0] + for x in checkpoints + if x[1] % cfg.keep_interval_updates_pattern != 0 + ] + + for old_chk in checkpoints[cfg.keep_interval_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: + # only keep the best N checkpoints according to validation metric + checkpoints = checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + return saved_cp + + +def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + suffix = trainer.checkpoint_suffix + if ( + cfg.restore_file == "checkpoint_last.pt" + ): # default value of restore_file is 'checkpoint_last.pt' + checkpoint_path = os.path.join( + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) + ) + first_launch = not PathManager.exists(checkpoint_path) + if first_launch and getattr(cfg, "continue_once", None) is not None: + checkpoint_path = cfg.continue_once + elif cfg.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif suffix is not None: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = cfg.restore_file + + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + epoch_itr = trainer.get_train_iterator( + epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + ) + epoch_itr.load_state_dict(itr_state) + + # Preload the checkpoint for the task + task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {}) + if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"): + trainer.task.set_checkpoint_dict(task_cp_dict) + else: + epoch_itr = trainer.get_train_iterator( + epoch=1, load_dataset=True, **passthrough_args + ) + + trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch_itr + + +def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False): + """Loads a checkpoint to CPU (with upgrading for backward compatibility). + + If doing single-GPU training or if the checkpoint is only being loaded by at + most one process on each node (current default behavior is for only rank 0 + to read the checkpoint from disk), load_on_all_ranks should be False to + avoid errors from torch.distributed not having been initialized or + torch.distributed.barrier() hanging. + + If all processes on each node may be loading the checkpoint + simultaneously, load_on_all_ranks should be set to True to avoid I/O + conflicts. + + There's currently no support for > 1 but < all processes loading the + checkpoint on each node. + """ + local_path = PathManager.get_local_path(path) + # The locally cached file returned by get_local_path() may be stale for + # remote files that are periodically updated/overwritten (ex: + # checkpoint_last.pt) - so we remove the local copy, sync across processes + # (if needed), and then download a fresh copy. + if local_path != path and PathManager.path_requires_pathmanager(path): + try: + os.remove(local_path) + except FileNotFoundError: + # With potentially multiple processes removing the same file, the + # file being missing is benign (missing_ok isn't available until + # Python 3.8). + pass + if load_on_all_ranks: + torch.distributed.barrier() + local_path = PathManager.get_local_path(path) + + with open(local_path, "rb") as f: + state = torch.load(f, map_location=torch.device("cpu"), weights_only=False) + + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] + for arg_name, arg_val in arg_overrides.items(): + setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None: + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import __version__ as oc_version + from omegaconf import _utils + + if oc_version < "2.2": + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + state["cfg"] = OmegaConf.create(state["cfg"]) + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(state["cfg"], True) + else: + state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True}) + + if arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + + state = _upgrade_state_dict(state) + return state + + +def load_model_ensemble( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + """Loads an ensemble of models. + + Args: + filenames (List[str]): checkpoint files to load + arg_overrides (Dict[str,Any], optional): override model args that + were used during model training + task (fairseq.tasks.FairseqTask, optional): task to use for loading + """ + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble, args, _task = load_model_ensemble_and_task( + filenames, + arg_overrides, + task, + strict, + suffix, + num_shards, + state, + ) + return ensemble, args + + +def get_maybe_sharded_checkpoint_filename( + filename: str, suffix: str, shard_idx: int, num_shards: int +) -> str: + orig_filename = filename + filename = filename.replace(".pt", suffix + ".pt") + fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt" + model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if PathManager.exists(fsdp_filename): + return fsdp_filename + elif num_shards > 1: + return model_parallel_filename + else: + return filename + + +def load_model_ensemble_and_task( + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, +): + assert state is None or len(filenames) == 1 + + from fairseq import tasks + + assert not ( + strict and num_shards > 1 + ), "Cannot load state dict with strict=True and checkpoint shards > 1" + ensemble = [] + cfg = None + for filename in filenames: + orig_filename = filename + model_shard_state = {"shard_weights": [], "shard_metadata": []} + assert num_shards > 0 + st = time.time() + for shard_idx in range(num_shards): + filename = get_maybe_sharded_checkpoint_filename( + orig_filename, suffix, shard_idx, num_shards + ) + + if not PathManager.exists(filename): + raise IOError("Model file not found: {}".format(filename)) + if state is None: + state = load_checkpoint_to_cpu(filename, arg_overrides) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) + + if task is None: + task = tasks.setup_task(cfg.task, from_checkpoint=True) + + if "task_state" in state: + task.load_state_dict(state["task_state"]) + + argspec = inspect.getfullargspec(task.build_model) + + if "fsdp_metadata" in state and num_shards > 1: + model_shard_state["shard_weights"].append(state["model"]) + model_shard_state["shard_metadata"].append(state["fsdp_metadata"]) + # check FSDP import before the code goes too far + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if shard_idx == num_shards - 1: + consolidated_model_state = FSDP.consolidate_shard_weights( + shard_weights=model_shard_state["shard_weights"], + shard_metadata=model_shard_state["shard_metadata"], + ) + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates( + state["optimizer_history"][-1]["num_updates"] + ) + model.load_state_dict( + consolidated_model_state, strict=strict, model_cfg=cfg.model + ) + else: + # model parallel checkpoint or unsharded checkpoint + # support old external tasks + + if "from_checkpoint" in argspec.args: + model = task.build_model(cfg.model, from_checkpoint=True) + else: + model = task.build_model(cfg.model) + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + model.set_num_updates(state["optimizer_history"][-1]["num_updates"]) + model.load_state_dict( + state["model"], strict=strict, model_cfg=cfg.model + ) + + # reset state so it gets loaded for the next model in ensemble + state = None + if shard_idx % 10 == 0 and shard_idx > 0: + elapsed = time.time() - st + logger.info( + f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard" + ) + + # build model for ensemble + ensemble.append(model) + return ensemble, cfg, task + + +def load_model_ensemble_and_task_from_hf_hub( + model_id, + cache_dir: Optional[str] = None, + arg_overrides: Optional[Dict[str, Any]] = None, + **kwargs: Any, +): + try: + from huggingface_hub import snapshot_download + except ImportError: + raise ImportError( + "You need to install huggingface_hub to use `load_from_hf_hub`. " + "See https://pypi.org/project/huggingface-hub/ for installation." + ) + + library_name = "fairseq" + cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix() + cache_dir = snapshot_download( + model_id, cache_dir=cache_dir, library_name=library_name, **kwargs + ) + + _arg_overrides = arg_overrides or {} + _arg_overrides["data"] = cache_dir + return load_model_ensemble_and_task( + [p.as_posix() for p in Path(cache_dir).glob("*.pt")], + arg_overrides=_arg_overrides, + ) + + +def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False): + """Retrieves all checkpoints found in `path` directory. + + Checkpoints are identified by matching filename to the specified pattern. If + the pattern contains groups, the result will be sorted by the first group in + descending order. + """ + pt_regexp = re.compile(pattern) + files = PathManager.ls(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + if keep_match: + return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)] + else: + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: + with PathManager.opena(filename, "wb") as f: + _torch_persistent_save(obj, f) + else: + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + _torch_persistent_save(obj, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + _torch_persistent_save(obj, f) + + +def _torch_persistent_save(obj, f): + if isinstance(f, str): + with PathManager.open(f, "wb") as h: + torch_persistent_save(obj, h) + return + for i in range(3): + try: + return torch.save(obj, f) + except Exception: + if i == 2: + logger.error(traceback.format_exc()) + raise + else: + time.sleep(2.5) + + +def _upgrade_state_dict(state): + """Helper for upgrading old model checkpoints.""" + + # add optimizer_history + if "optimizer_history" not in state: + state["optimizer_history"] = [ + {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} + ] + state["last_optimizer_state"] = state["optimizer"] + del state["optimizer"] + del state["best_loss"] + # move extra_state into sub-dictionary + if "epoch" in state and "extra_state" not in state: + state["extra_state"] = { + "epoch": state["epoch"], + "batch_offset": state["batch_offset"], + "val_loss": state["val_loss"], + } + del state["epoch"] + del state["batch_offset"] + del state["val_loss"] + # reduce optimizer history's memory usage (only keep the last state) + if "optimizer" in state["optimizer_history"][-1]: + state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] + for optim_hist in state["optimizer_history"]: + del optim_hist["optimizer"] + # record the optimizer class name + if "optimizer_name" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" + # move best_loss into lr_scheduler_state + if "lr_scheduler_state" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["lr_scheduler_state"] = { + "best": state["optimizer_history"][-1]["best_loss"] + } + del state["optimizer_history"][-1]["best_loss"] + # keep track of number of updates + if "num_updates" not in state["optimizer_history"][-1]: + state["optimizer_history"][-1]["num_updates"] = 0 + # use stateful training data iterator + if "train_iterator" not in state["extra_state"]: + state["extra_state"]["train_iterator"] = { + "epoch": state["extra_state"].get("epoch", 0), + "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), + } + + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # old model checkpoints may not have separate source/target positions + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" + ): + state["args"].max_source_positions = state["args"].max_positions + state["args"].max_target_positions = state["args"].max_positions + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + # --remove-bpe ==> --postprocess + if hasattr(state["args"], "remove_bpe"): + state["args"].post_process = state["args"].remove_bpe + # --min-lr ==> --stop-min-lr + if hasattr(state["args"], "min_lr"): + state["args"].stop_min_lr = state["args"].min_lr + del state["args"].min_lr + # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion + if hasattr(state["args"], "criterion") and state["args"].criterion in [ + "binary_cross_entropy", + "kd_binary_cross_entropy", + ]: + state["args"].criterion = "wav2vec" + # remove log_keys if it's None (criteria will supply a default value of []) + if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: + delattr(state["args"], "log_keys") + # speech_pretraining => audio pretraining + if ( + hasattr(state["args"], "task") + and state["args"].task == "speech_pretraining" + ): + state["args"].task = "audio_pretraining" + # audio_cpc => wav2vec + if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": + state["args"].arch = "wav2vec" + # convert legacy float learning rate to List[float] + if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): + state["args"].lr = [state["args"].lr] + # convert task data arg to a string instead of List[string] + if ( + hasattr(state["args"], "data") + and isinstance(state["args"].data, list) + and len(state["args"].data) > 0 + ): + state["args"].data = state["args"].data[0] + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + with open_dict(cfg): + # any upgrades for Hydra-based configs + if ( + "task" in cfg + and "eval_wer_config" in cfg.task + and isinstance(cfg.task.eval_wer_config.print_alignment, bool) + ): + cfg.task.eval_wer_config.print_alignment = "hard" + if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): + cfg.generation.print_alignment = ( + "hard" if cfg.generation.print_alignment else None + ) + if ( + "model" in cfg + and "w2v_args" in cfg.model + and cfg.model.w2v_args is not None + and ( + hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args + ) + and hasattr(cfg.model.w2v_args.task, "eval_wer_config") + and cfg.model.w2v_args.task.eval_wer_config is not None + and isinstance( + cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool + ) + ): + cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" + + return state + + +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): + """Prune the given state_dict if desired for LayerDrop + (https://arxiv.org/abs/1909.11556). + + Training with LayerDrop allows models to be robust to pruning at inference + time. This function prunes state_dict to allow smaller models to be loaded + from a larger model and re-maps the existing state_dict for this to occur. + + It's called by functions that load models from checkpoints and does not + need to be called directly. + """ + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": + # args should not be none, but don't crash if it is. + return state_dict + + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) + + if not encoder_layers_to_keep and not decoder_layers_to_keep: + return state_dict + + # apply pruning + logger.info( + "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop" + ) + + def create_pruning_pass(layers_to_keep, layer_name): + keep_layers = sorted( + int(layer_string) for layer_string in layers_to_keep.split(",") + ) + mapping_dict = {} + for i in range(len(keep_layers)): + mapping_dict[str(keep_layers[i])] = str(i) + + regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)) + return {"substitution_regex": regex, "mapping_dict": mapping_dict} + + pruning_passes = [] + if encoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder")) + if decoder_layers_to_keep: + pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder")) + + new_state_dict = {} + for layer_name in state_dict.keys(): + match = re.search(r"\.layers\.(\d+)\.", layer_name) + # if layer has no number in it, it is a supporting layer, such as an + # embedding + if not match: + new_state_dict[layer_name] = state_dict[layer_name] + continue + + # otherwise, layer should be pruned. + original_layer_number = match.group(1) + # figure out which mapping dict to replace from + for pruning_pass in pruning_passes: + if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[ + "substitution_regex" + ].search(layer_name): + new_layer_number = pruning_pass["mapping_dict"][original_layer_number] + substitution_match = pruning_pass["substitution_regex"].search( + layer_name + ) + new_state_key = ( + layer_name[: substitution_match.start(1)] + + new_layer_number + + layer_name[substitution_match.end(1) :] + ) + new_state_dict[new_state_key] = state_dict[layer_name] + + # Since layers are now pruned, *_layers_to_keep are no longer needed. + # This is more of "It would make it work fix" rather than a proper fix. + if isinstance(model_cfg, DictConfig): + context = open_dict(model_cfg) + else: + context = contextlib.ExitStack() + with context: + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None + + return new_state_dict + + +def load_pretrained_component_from_model( + component: Union[FairseqEncoder, FairseqDecoder], + checkpoint: str, + strict: bool = True, +): + """ + Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the + provided `component` object. If state_dict fails to load, there may be a + mismatch in the architecture of the corresponding `component` found in the + `checkpoint` file. + """ + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = load_checkpoint_to_cpu(checkpoint) + if isinstance(component, FairseqEncoder): + component_type = "encoder" + elif isinstance(component, FairseqDecoder): + component_type = "decoder" + else: + raise ValueError( + "component to load must be either a FairseqEncoder or " + "FairseqDecoder. Loading other component types are not supported." + ) + component_state_dict = OrderedDict() + for key in state["model"].keys(): + if key.startswith(component_type): + # encoder.input_layers.0.0.weight --> input_layers.0.0.weight + component_subkey = key[len(component_type) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=strict) + return component + + +def verify_checkpoint_directory(save_dir: str) -> None: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + temp_file_path = os.path.join(save_dir, "dummy") + try: + with open(temp_file_path, "w"): + pass + except OSError as e: + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) + raise e + else: + os.remove(temp_file_path) + + +def save_ema_as_checkpoint(src_path, dst_path): + state = load_ema_from_checkpoint(src_path) + torch_persistent_save(state, dst_path) + + +def load_ema_from_checkpoint(fpath): + """Loads exponential moving averaged (EMA) checkpoint from input and + returns a model with ema weights. + + Args: + fpath: A string path of checkpoint to load from. + + Returns: + A dict of string keys mapping to various values. The 'model' key + from the returned dict should correspond to an OrderedDict mapping + string parameter names to torch Tensors. + """ + params_dict = collections.OrderedDict() + new_state = None + + with PathManager.open(fpath, "rb") as f: + new_state = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + + # EMA model is stored in a separate "extra state" + model_params = new_state["extra_state"]["ema"] + + for key in list(model_params.keys()): + p = model_params[key] + if isinstance(p, torch.HalfTensor): + p = p.float() + if key not in params_dict: + params_dict[key] = p.clone() + # NOTE: clone() is needed in case of p is a shared parameter + else: + raise ValueError("Key {} is repeated in EMA model params.".format(key)) + + if len(params_dict) == 0: + raise ValueError( + f"Input checkpoint path '{fpath}' does not contain " + "ema model weights, is this model trained with EMA?" + ) + + new_state["model"] = params_dict + return new_state diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..707219105a17a691e43de1296a72bbaffa0c7fe9 --- /dev/null +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp @@ -0,0 +1,55 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +#include +#include + +/* +CPP Binding for CUDA OP +*/ + +// CUDA forward declarations +torch::Tensor ngram_repeat_block_cuda_forward( + torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +// Input check and call to CUDA OP +// Backward method not required +torch::Tensor ngram_repeat_block_forward( + torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { + CHECK_INPUT(tokens); + CHECK_INPUT(lprobs); + assert(bsz > 0); + assert(step >= 0); + assert(beam_size > 0); + assert(no_repeat_ngram_size > 0); + + return ngram_repeat_block_cuda_forward( + tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "forward", + &ngram_repeat_block_forward, + "No Repeat Ngram Block forward (CUDA)"); +} diff --git a/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd6106cba0672c3ff29c925b0f5cea557ab3eced --- /dev/null +++ b/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu @@ -0,0 +1,82 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for blocking repeated n-grams. +*/ + +#include +#include +#include +#include +#include + +// Ban repeated ngrams of length = 'no_repeat_ngram_size' +__global__ void banRepeatedTokens( + long* __restrict__ tokens, + float* __restrict__ lprobs, + int max_predict_len, + int vocab_size, + int no_repeat_ngram_size) { + auto row = blockIdx.x; + auto col = threadIdx.x; + auto start = row * (max_predict_len) + col; + // Each thread compares ngram starting from + // thread index with final ngram starting from + // step - no_repeat_ngram_size +2 + auto check_start_pos = blockDim.x; + auto lprob_start = row * vocab_size; + bool is_banned = true; + extern __shared__ long tokens_shm[]; + tokens_shm[col] = tokens[start]; + if (col == blockDim.x - 1) { + for (int i = 1; i < no_repeat_ngram_size; i++) { + if (col + i < max_predict_len) { + tokens_shm[col + i] = tokens[start + i]; + } + } + } + __syncthreads(); + + for (int k = 0; k < no_repeat_ngram_size - 1; k++) { + if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) { + is_banned = false; + } + } + if (is_banned == true) { + auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1]; + lprobs[lprob_start + token_to_be_banned] = -INFINITY; + } +} + +// Allocate blocks and threads based on +// batch size and sequence length and launch +// kernel +torch::Tensor ngram_repeat_block_cuda_forward( + const torch::Tensor tokens, + torch::Tensor lprobs, + int bsz, + int step, + int beam_size, + int no_repeat_ngram_size) { + int threads = step - no_repeat_ngram_size + 2; + if (threads <= 0) + return lprobs; + int max_predict_len = tokens.size(1); + int vocab_size = lprobs.size(1); + auto token_ptr = tokens.data_ptr(); + auto lprob_ptr = lprobs.data_ptr(); + int blocks = bsz * beam_size; + int shared_mem_size = (step + 1) * sizeof(long); + + // Launching N blocks where N is number of samples in a batch (beams*bsz) + // Launching T threads where T is number of previous ngrams in a sample + // Allocating shared mem per block for fastser access of input tokens since + // each token will be accessed N times to compare with current Ngram where + // N is Ngram size. + banRepeatedTokens<<>>( + token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size); + return lprobs; +} diff --git a/fairseq/clib/libbase/balanced_assignment.cpp b/fairseq/clib/libbase/balanced_assignment.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1a5a1061f3892be5a17e49192f744c39e0d395e8 --- /dev/null +++ b/fairseq/clib/libbase/balanced_assignment.cpp @@ -0,0 +1,109 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* +C++ code for solving the linear assignment problem. +Based on the Auction Algorithm from +https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the +implementation from: https://github.com/bkj/auction-lap Adapted to be more +efficient when each worker is looking for k jobs instead of 1. +*/ +#include +#include +using namespace torch::indexing; +torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) { + int max_iterations = 100; + torch::Tensor epsilon = + (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50; + epsilon.clamp_min_(1e-04); + torch::Tensor worker_and_job_to_score = + job_and_worker_to_score.detach().transpose(0, 1).contiguous(); + int num_workers = worker_and_job_to_score.size(0); + int num_jobs = worker_and_job_to_score.size(1); + auto device = worker_and_job_to_score.device(); + int jobs_per_worker = num_jobs / num_workers; + torch::Tensor value = worker_and_job_to_score.clone(); + int counter = 0; + torch::Tensor max_value = worker_and_job_to_score.max(); + + torch::Tensor bid_indices; + torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs}); + torch::Tensor bids = + worker_and_job_to_score.new_empty({num_workers, num_jobs}); + torch::Tensor bid_increments = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker}); + torch::Tensor top_values = + worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1}); + torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs}); + + torch::Tensor top_index = top_values.to(torch::kLong); + torch::Tensor high_bidders = top_index.new_empty({num_jobs}); + torch::Tensor have_bids = high_bidders.to(torch::kBool); + torch::Tensor jobs_indices = + torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device)); + torch::Tensor true_tensor = + torch::ones({1}, torch::dtype(torch::kBool).device(device)); + + while (true) { + bids.zero_(); + torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1); + + // Each worker bids the difference in value between that job and the k+1th + // job + torch::sub_out( + bid_increments, + top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}), + top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1)); + + bid_increments.add_(epsilon); + bids.scatter_( + 1, + top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}), + bid_increments); + + if (counter < max_iterations && counter > 0) { + // Put in a minimal bid to retain items from the last round if no-one else + // bids for them this round + bids.view(-1).index_put_({bid_indices}, epsilon); + } + + // Find the highest bidding worker per job + torch::max_out(high_bids, high_bidders, bids, 0); + torch::gt_out(have_bids, high_bids, 0); + + if (have_bids.all().item()) { + // All jobs were bid for + break; + } + + // Make popular items more expensive + cost.add_(high_bids); + torch::sub_out(value, worker_and_job_to_score, cost); + + bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids}); + + if (counter < max_iterations) { + // Make sure that this item will be in the winning worker's top-k next + // time. + value.view(-1).index_put_({bid_indices}, max_value); + } else { + // Suboptimal approximation that converges quickly from current solution + value.view(-1).index_put_( + {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices})); + } + + counter += 1; + } + + return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}) + .reshape(-1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment"); +} diff --git a/fairseq/clib/libbleu/libbleu.cpp b/fairseq/clib/libbleu/libbleu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..939d9e1174e398fa48c840009b592c753a67939a --- /dev/null +++ b/fairseq/clib/libbleu/libbleu.cpp @@ -0,0 +1,157 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +// NOLINTNEXTLINE +typedef struct { + size_t reflen; + size_t predlen; + size_t match1; + size_t count1; + size_t match2; + size_t count2; + size_t match3; + size_t count3; + size_t match4; + size_t count4; +} bleu_stat; + +// left trim (remove pad) +void bleu_ltrim(size_t* len, int** sent, int pad) { + size_t start = 0; + while (start < *len) { + if (*(*sent + start) != pad) { + break; + } + start++; + } + *sent += start; + *len -= start; +} + +// right trim remove (eos) +void bleu_rtrim(size_t* len, int** sent, int pad, int eos) { + size_t end = *len - 1; + while (end > 0) { + if (*(*sent + end) != eos && *(*sent + end) != pad) { + break; + } + end--; + } + *len = end + 1; +} + +// left and right trim +void bleu_trim(size_t* len, int** sent, int pad, int eos) { + bleu_ltrim(len, sent, pad); + bleu_rtrim(len, sent, pad, eos); +} + +size_t bleu_hash(int len, int* data) { + size_t h = 14695981039346656037ul; + size_t prime = 0x100000001b3; + char* b = (char*)data; + size_t blen = sizeof(int) * len; + + while (blen-- > 0) { + h ^= *b++; + h *= prime; + } + + return h; +} + +void bleu_addngram( + size_t* ntotal, + size_t* nmatch, + size_t n, + size_t reflen, + int* ref, + size_t predlen, + int* pred) { + if (predlen < n) { + return; + } + + predlen = predlen - n + 1; + (*ntotal) += predlen; + + if (reflen < n) { + return; + } + + reflen = reflen - n + 1; + + std::map count; + while (predlen > 0) { + size_t w = bleu_hash(n, pred++); + count[w]++; + predlen--; + } + + while (reflen > 0) { + size_t w = bleu_hash(n, ref++); + if (count[w] > 0) { + (*nmatch)++; + count[w] -= 1; + } + reflen--; + } +} + +extern "C" { + +#ifdef _WIN64 +__declspec(dllexport) +#endif + void bleu_zero_init(bleu_stat* stat) { + std::memset(stat, 0, sizeof(bleu_stat)); +} + +#ifdef _WIN64 +__declspec(dllexport) +#endif + void bleu_one_init(bleu_stat* stat) { + bleu_zero_init(stat); + stat->count1 = 0; + stat->count2 = 1; + stat->count3 = 1; + stat->count4 = 1; + stat->match1 = 0; + stat->match2 = 1; + stat->match3 = 1; + stat->match4 = 1; +} + +#ifdef _WIN64 +__declspec(dllexport) +#endif + void bleu_add( + bleu_stat* stat, + size_t reflen, + int* ref, + size_t predlen, + int* pred, + int pad, + int eos) { + + bleu_trim(&reflen, &ref, pad, eos); + bleu_trim(&predlen, &pred, pad, eos); + stat->reflen += reflen; + stat->predlen += predlen; + + bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred); + bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred); + bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred); + bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred); +} +} diff --git a/fairseq/clib/libbleu/module.cpp b/fairseq/clib/libbleu/module.cpp new file mode 100644 index 0000000000000000000000000000000000000000..35288b3177185670135f7bdc1f1589c5bb992304 --- /dev/null +++ b/fairseq/clib/libbleu/module.cpp @@ -0,0 +1,33 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +static PyMethodDef method_def[] = {{NULL, NULL, 0, NULL}}; // NOLINT + +static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "libbleu", /* name of module */ + // NOLINTNEXTLINE + NULL, /* module documentation, may be NULL */ + -1, /* size of per-interpreter state of the module, + or -1 if the module keeps state in global variables. */ + method_def}; // NOLINT + +#if PY_MAJOR_VERSION == 2 +PyMODINIT_FUNC init_libbleu() +#else +PyMODINIT_FUNC PyInit_libbleu() +#endif +{ + PyObject* m = PyModule_Create(&module_def); + if (!m) { + return NULL; + } + return m; +} diff --git a/fairseq/clib/libnat/edit_dist.cpp b/fairseq/clib/libnat/edit_dist.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ffb60569d74d2868ed8113b7c787ef870e9da20 --- /dev/null +++ b/fairseq/clib/libnat/edit_dist.cpp @@ -0,0 +1,231 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include // @manual=//caffe2:torch_extension +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ::std; + +vector> edit_distance2_with_dp( + vector& x, + vector& y) { + uint32_t lx = x.size(); + uint32_t ly = y.size(); + vector> d(lx + 1, vector(ly + 1)); + for (uint32_t i = 0; i < lx + 1; i++) { + d[i][0] = i; + } + for (uint32_t j = 0; j < ly + 1; j++) { + d[0][j] = j; + } + for (uint32_t i = 1; i < lx + 1; i++) { + for (uint32_t j = 1; j < ly + 1; j++) { + d[i][j] = + min(min(d[i - 1][j], d[i][j - 1]) + 1, + d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1)); + } + } + return d; +} + +vector> edit_distance2_backtracking( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol) { + vector seq; + vector> edit_seqs(x.size() + 2, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(x.size() + 1).push_back(1); + } else { + edit_seqs.at(x.size() + 1).push_back(0); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs[k].size() == 0) { + edit_seqs[k].push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector> edit_distance2_backtracking_with_delete( + vector>& d, + vector& x, + vector& y, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector seq; + vector> edit_seqs(x.size() + 1, vector()); + /* + edit_seqs: + 0~x.size() cell is the insertion sequences + last cell is the delete sequence + */ + + if (x.size() == 0) { + edit_seqs.at(0) = y; + return edit_seqs; + } + + uint32_t i = d.size() - 1; + uint32_t j = d.at(0).size() - 1; + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) { + seq.push_back(1); // insert + seq.push_back(y.at(j - 1)); + j--; + } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) { + seq.push_back(2); // delete + seq.push_back(x.at(i - 1)); + i--; + } else { + seq.push_back(3); // keep + seq.push_back(x.at(i - 1)); + i--; + j--; + } + } + + uint32_t prev_op, op, s, word; + prev_op = 0, s = 0; + for (uint32_t k = 0; k < seq.size() / 2; k++) { + op = seq.at(seq.size() - 2 * k - 2); + word = seq.at(seq.size() - 2 * k - 1); + if (prev_op != 1) { + s++; + } + if (op == 1) // insert + { + edit_seqs.at(s - 1).push_back(word); + } else if (op == 2) // delete + { + edit_seqs.at(s - 1).push_back(deletion_symbol); + } + + prev_op = op; + } + + for (uint32_t k = 0; k < edit_seqs.size(); k++) { + if (edit_seqs.at(k).size() == 0) { + edit_seqs.at(k).push_back(terminal_symbol); + } + } + return edit_seqs; +} + +vector compute_ed2( + vector>& xs, + vector>& ys) { + vector distances(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size()); + } + return distances; +} + +vector>> suggested_ed2_path( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = + edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol); + } + return seq; +} + +vector>> suggested_ed2_path_with_delete( + vector>& xs, + vector>& ys, + uint32_t terminal_symbol, + uint32_t deletion_symbol) { + vector>> seq(xs.size()); + for (uint32_t i = 0; i < xs.size(); i++) { + vector> d = edit_distance2_with_dp(xs.at(i), ys.at(i)); + seq.at(i) = edit_distance2_backtracking_with_delete( + d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol); + } + return seq; +} + +PYBIND11_MODULE(libnat, m) { + m.def("compute_ed2", &compute_ed2, "compute_ed2"); + m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path"); + m.def( + "suggested_ed2_path_with_delete", + &suggested_ed2_path_with_delete, + "suggested_ed2_path_with_delete"); +} diff --git a/fairseq/clib/libnat_cuda/binding.cpp b/fairseq/clib/libnat_cuda/binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ced91c0d0afab9071842911d9876e6360d90284a --- /dev/null +++ b/fairseq/clib/libnat_cuda/binding.cpp @@ -0,0 +1,67 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + This code is partially adpoted from + https://github.com/1ytic/pytorch-edit-distance + */ + +#include +#include "edit_dist.h" + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor LevenshteinDistance( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + CHECK_INPUT(source); + CHECK_INPUT(target); + CHECK_INPUT(source_length); + CHECK_INPUT(target_length); + return LevenshteinDistanceCuda(source, target, source_length, target_length); +} + +torch::Tensor GenerateDeletionLabel( + torch::Tensor source, + torch::Tensor operations) { + CHECK_INPUT(source); + CHECK_INPUT(operations); + return GenerateDeletionLabelCuda(source, operations); +} + +std::pair GenerateInsertionLabel( + torch::Tensor target, + torch::Tensor operations) { + CHECK_INPUT(target); + CHECK_INPUT(operations); + return GenerateInsertionLabelCuda(target, operations); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance"); + m.def( + "generate_deletion_labels", + &GenerateDeletionLabel, + "Generate Deletion Label"); + m.def( + "generate_insertion_labels", + &GenerateInsertionLabel, + "Generate Insertion Label"); +} diff --git a/fairseq/clib/libnat_cuda/edit_dist.cu b/fairseq/clib/libnat_cuda/edit_dist.cu new file mode 100644 index 0000000000000000000000000000000000000000..1ea5ec7e3cb31557fde20bc457f986bbcecc9cb2 --- /dev/null +++ b/fairseq/clib/libnat_cuda/edit_dist.cu @@ -0,0 +1,344 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "edit_dist.h" + +#include +#include +#include +#include +#include // std::pair + +template +__global__ void generate_deletion_label_kernel( + const scalar_t* __restrict__ source, + const size_t source_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * source_size; + + for (int i = 0; i < source_size; i++) { + labels[offset_label + i] = 0; + } + + int k = 0; + for (int i = 0; i < operation_size; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 1) { + continue; + } else { + labels[offset_label + k] = 3 - operations[offset + i]; + k++; + } + } +} + +template +__global__ void generate_insertion_label_kernel( + const scalar_t* __restrict__ target, + const size_t target_size, + const size_t operation_size, + int* __restrict__ operations, + int* __restrict__ labels, + int* __restrict__ masks) { + const int index = blockIdx.x; + const int offset = index * operation_size; + const int offset_label = index * target_size; + + int k = 0; + int u = 0; + int m = 0; + + for (int i = 0; i < target_size; i++) { + labels[offset_label + i] = 0; + masks[offset_label + i] = 0; + } + + for (int i = 0; i < operation_size - 1; i++) { + if (operations[offset + i] == 0) { + break; + } else if (operations[offset + i] == 2) { + continue; + } else if (operations[offset + i] == 1) { + masks[offset_label + m] = 1; + u++; + m++; + } else { + labels[offset_label + k] = u; + masks[offset_label + m] = 0; + k++; + m++; + u = 0; + } + } +} + +template +__global__ void levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations, + int* __restrict__ errors_curr) { + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int d = index * (source_size + 1) * (target_size + 1); + const int t = target_size + 1; + + auto err_idx = [d, t](int i, int j) { return d + i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for (int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1)); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } +} + +template +__global__ void faster_levenshtein_distance_kernel( + const scalar_t* __restrict__ source, + const scalar_t* __restrict__ target, + const int* __restrict__ source_length, + const int* __restrict__ target_length, + const size_t source_size, + const size_t target_size, + int* __restrict__ operations) { + extern __shared__ short errors[]; + auto errors_curr = errors; + + const int index = blockIdx.x; + const int offset = index * (source_size + target_size); + const int t = target_size + 1; + + auto err_idx = [t](int i, int j) { return i * t + j; }; + auto opt_idx = [offset](int k) { return offset + k; }; + + const int hyp_len = source_length[index]; + const int ref_len = target_length[index]; + const scalar_t* hyp_begin = source + index * source_size; + const scalar_t* ref_begin = target + index * target_size; + + // dynamic programming + for (int i = 0; i <= hyp_len; i++) { + errors_curr[err_idx(i, 0)] = i; + } + for (int j = 0; j <= ref_len; j++) { + errors_curr[err_idx(0, j)] = j; + } + for (int i = 1; i <= hyp_len; i++) { + for (int j = 1; j <= ref_len; j++) { + errors_curr[err_idx(i, j)] = min( + min(errors_curr[err_idx(i - 1, j)], errors_curr[err_idx(i, j - 1)]) + + 1, + errors_curr[err_idx(i - 1, j - 1)] + + 2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1)); + } + } + + // back-tracing + int i = hyp_len; + int j = ref_len; + int o = hyp_len + ref_len; + + for (int k = 0; k < source_size + target_size; k++) { + operations[opt_idx(k)] = 0; + } + + while ((i >= 0) && (j >= 0)) { + if ((i == 0) && (j == 0)) { + break; + } + + if ((j > 0) && + (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 1; + j--; // insertion + } else if ( + (i > 0) && + (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) { + o--; + operations[opt_idx(o)] = 2; + i--; // deletion + } else { + o--; + operations[opt_idx(o)] = 3; + i--; + j--; // do nothing + } + } + + // moving to the left + for (int k = 0; k < hyp_len + ref_len; k++) { + if (k + o < hyp_len + ref_len) { + operations[opt_idx(k)] = operations[opt_idx(k + o)]; + } else { + operations[opt_idx(k)] = 0; // padding + } + } +} + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations) { + const auto batch_size = source.size(0); + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, source.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] { + generate_deletion_label_kernel + <<>>( + source.data_ptr(), + source.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr()); + })); + + return labels; +} + +std::pair GenerateInsertionLabelCuda( + torch::Tensor target, + torch::Tensor operations) { + const auto batch_size = target.size(0); + at::TensorOptions options(target.device()); + options = options.dtype(at::ScalarType::Int); + auto labels = torch::empty({batch_size, target.size(1)}, options); + auto masks = torch::empty({batch_size, target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(target.device().index()); + + AT_DISPATCH_ALL_TYPES( + target.scalar_type(), "generate_insertion_labels", ([&] { + generate_insertion_label_kernel<<>>( + target.data_ptr(), + target.size(1), + operations.size(1), + operations.data_ptr(), + labels.data_ptr(), + masks.data_ptr()); + })); + + return std::make_pair(labels, masks); +} + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length) { + const auto batch_size = source.size(0); + const auto shared_size = + (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short); + + at::TensorOptions options(source.device()); + options = options.dtype(at::ScalarType::Int); + auto operations = + torch::empty({batch_size, source.size(1) + target.size(1)}, options); + auto stream = at::cuda::getCurrentCUDAStream(source.device().index()); + + if (shared_size > 40000) { + auto distances = torch::empty( + {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options); + AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] { + levenshtein_distance_kernel + <<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr(), + distances.data_ptr()); + })); + } else { + AT_DISPATCH_ALL_TYPES( + source.scalar_type(), "faster_levenshtein_distance", ([&] { + faster_levenshtein_distance_kernel + <<>>( + source.data_ptr(), + target.data_ptr(), + source_length.data_ptr(), + target_length.data_ptr(), + source.size(1), + target.size(1), + operations.data_ptr()); + })); + } + + return operations; +} diff --git a/fairseq/clib/libnat_cuda/edit_dist.h b/fairseq/clib/libnat_cuda/edit_dist.h new file mode 100644 index 0000000000000000000000000000000000000000..5220c52fd80529b90a67ba74e9ca73c668dab099 --- /dev/null +++ b/fairseq/clib/libnat_cuda/edit_dist.h @@ -0,0 +1,25 @@ +/** + * Copyright 2017-present, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +torch::Tensor LevenshteinDistanceCuda( + torch::Tensor source, + torch::Tensor target, + torch::Tensor source_length, + torch::Tensor target_length); + +torch::Tensor GenerateDeletionLabelCuda( + torch::Tensor source, + torch::Tensor operations); + +std::pair GenerateInsertionLabelCuda( + torch::Tensor source, + torch::Tensor operations); diff --git a/fairseq/config/__init__.py b/fairseq/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ed7168cb7f7473c43d864478c5c6ce51639e030 --- /dev/null +++ b/fairseq/config/config.yaml @@ -0,0 +1,19 @@ +# @package _group_ + +hydra: + run: + dir: . + +defaults: + - _self_ + - task: null + - model: null + - criterion: cross_entropy + - optimizer: null + - lr_scheduler: fixed + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/fairseq/config/fb_run_config/slurm.yaml b/fairseq/config/fb_run_config/slurm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..20cf8f52016ddc23d5bcb09ef94d900a035b81ca --- /dev/null +++ b/fairseq/config/fb_run_config/slurm.yaml @@ -0,0 +1,29 @@ +# @package _global_ + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - fb_run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} + launcher: + cpus_per_task: 60 + gpus_per_node: ??? + tasks_per_node: 1 + nodes: 1 + partition: learnfair + mem_gb: 400 + timeout_min: 4320 + max_num_timeout: 10 + name: ${env:PREFIX}_${hydra.job.config_name} + submitit_folder: ${hydra.sweep.dir} + +distributed_training: + ddp_backend: c10d + distributed_world_size: ??? + distributed_port: ??? diff --git a/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml b/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30b1a4f1e0f5e7f7c2671ff8ec995cc32363f10f --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_baevski_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml b/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1154cfa660ee5ce6a272cd1a0049eead1e92c117 --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_baevski_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_big.yaml b/fairseq/config/model/transformer_lm/transformer_lm_big.yaml new file mode 100644 index 0000000000000000000000000000000000000000..309575310bfc5d9c5cde31563073bef18abc646e --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.0 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30b1a4f1e0f5e7f7c2671ff8ec995cc32363f10f --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_gbw.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 512 +decoder_output_dim: 512 +decoder_input_dim: 512 +decoder_ffn_embed_dim: 4096 +decoder_layers: 12 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2c6cb7be3801115371566932ffc78651c9ac6c0f --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 768 +decoder_output_dim: 768 +decoder_input_dim: 768 +decoder_ffn_embed_dim: 3072 +decoder_layers: 12 +decoder_attention_heads: 12 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a08769a1781abdb13302bf57bf1338bcaf68a0ec --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_big.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1600 +decoder_output_dim: 1600 +decoder_input_dim: 1600 +decoder_ffn_embed_dim: 6400 +decoder_layers: 48 +decoder_attention_heads: 25 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64261d793c0f1ae091c9bf5c8c77093a07326137 --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_medium.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1280 +decoder_output_dim: 1280 +decoder_input_dim: 1280 +decoder_ffn_embed_dim: 5120 +decoder_layers: 36 +decoder_attention_heads: 20 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..702e81f466c82edf40433589d389edbe0a7b96db --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_gpt2_small.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "gelu" +dropout: 0.1 +attention_dropout: 0.1 +activation_dropout: 0.0 +relu_dropout: 0.0 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 24 +decoder_attention_heads: 16 +decoder_normalize_before: true +no_decoder_final_norm: false +adaptive_softmax_cutoff: null +adaptive_softmax_dropout: 0 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: false +adaptive_input_factor: 4 +adaptive_input_cutoff: null +tie_adaptive_weights: false +tie_adaptive_proj: false +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml b/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1154cfa660ee5ce6a272cd1a0049eead1e92c117 --- /dev/null +++ b/fairseq/config/model/transformer_lm/transformer_lm_wiki103.yaml @@ -0,0 +1,36 @@ +# @package _group_ +activation_fn: "relu" +dropout: 0.3 +attention_dropout: 0.1 +activation_dropout: 0.1 +relu_dropout: 0.1 +decoder_embed_dim: 1024 +decoder_output_dim: 1024 +decoder_input_dim: 1024 +decoder_ffn_embed_dim: 4096 +decoder_layers: 16 +decoder_attention_heads: 8 +decoder_normalize_before: true +no_decoder_final_norm: true +adaptive_softmax_cutoff: "20000,60000" +adaptive_softmax_dropout: 0.2 +adaptive_softmax_factor: 4 +no_token_positional_embeddings: false +share_decoder_input_output_embed: false +character_embeddings: false +character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]" +character_embedding_dim: 4 +char_embedder_highway_layers: 2 +adaptive_input: true +adaptive_input_factor: 4 +adaptive_input_cutoff: "20000,60000" +tie_adaptive_weights: true +tie_adaptive_proj: true +decoder_learned_pos: false +decoder_layerdrop: 0 +decoder_layers_to_keep: null +layernorm_embedding: false +no_scale_embedding: false +quant_noise_pq: 0 +quant_noise_pq_block_size: 8 +quant_noise_scalar: 0 diff --git a/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee1329bf4612d8bb295c6cc3d8bc0a3bcef1777d --- /dev/null +++ b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml @@ -0,0 +1,5 @@ +# @package _group_ +activation: gelu +vq_type: gumbel +vq_depth: 2 +combine_groups: true diff --git a/fairseq/config/model/wav2vec2/wav2vec2_base.yaml b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce65499b808b9a3821cee4ca87c36e84d09005a1 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +quantize_targets: true +final_dim: 256 +encoder_layerdrop: 0.05 +dropout_input: 0.1 +dropout_features: 0.1 +feature_grad_mult: 0.1 diff --git a/fairseq/config/model/wav2vec2/wav2vec2_large.yaml b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5846f75243f27f201c85bfe6820815c015971275 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml @@ -0,0 +1,20 @@ +# @package _group_ + +quantize_targets: true +extractor_mode: layer_norm +layer_norm_first: true +final_dim: 768 +latent_temp: [2.0,0.1,0.999995] +encoder_layerdrop: 0.0 +dropout_input: 0.0 +dropout_features: 0.0 +dropout: 0.0 +attention_dropout: 0.0 +conv_bias: true + +encoder_layers: 24 +encoder_embed_dim: 1024 +encoder_ffn_embed_dim: 4096 +encoder_attention_heads: 16 + +feature_grad_mult: 1.0 diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd65d34adec4222ac8781106560ebc5dc2622f5 --- /dev/null +++ b/fairseq/criterions/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.criterions.fairseq_criterion import ( # noqa + FairseqCriterion, + LegacyFairseqCriterion, +) +from omegaconf import DictConfig + + +( + build_criterion_, + register_criterion, + CRITERION_REGISTRY, + CRITERION_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--criterion", base_class=FairseqCriterion, default="cross_entropy" +) + + +def build_criterion(cfg: DictConfig, task, from_checkpoint=False): + return build_criterion_(cfg, task, from_checkpoint=from_checkpoint) + + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.criterions." + file_name) diff --git a/fairseq/criterions/__pycache__/__init__.cpython-310.pyc b/fairseq/criterions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6da576ed5a301b1ad20fa4b1eb9ca4fd45ab08 Binary files /dev/null and b/fairseq/criterions/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc b/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e37a0231659af5e95855d42b7dd0d7313fcd799a Binary files /dev/null and b/fairseq/criterions/__pycache__/adaptive_loss.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc b/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ac23473716899107366de04b398bf9cf3b31d14 Binary files /dev/null and b/fairseq/criterions/__pycache__/composite_loss.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc b/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1e626ab57c8ca85d0c3accb630172a8193d21b3 Binary files /dev/null and b/fairseq/criterions/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/ctc.cpython-310.pyc b/fairseq/criterions/__pycache__/ctc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2a4a20b678f57c9eefe1cf43a542b5087db9d2 Binary files /dev/null and b/fairseq/criterions/__pycache__/ctc.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b78982d36742e53430096833868dff426feed8 Binary files /dev/null and b/fairseq/criterions/__pycache__/fairseq_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc b/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80839212137cda59de2cee486c568898643ffb34 Binary files /dev/null and b/fairseq/criterions/__pycache__/fastspeech2_loss.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..940d12da3f230f42ab5911ea6d355bfc46a13846 Binary files /dev/null and b/fairseq/criterions/__pycache__/hubert_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fae79baa33f19e12b759dbc9acbbeb0b9862c8c Binary files /dev/null and b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..679d695e925c8122dc85fdb4e1050b84456e2ba0 Binary files /dev/null and b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_latency_augmented.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca019141dfbd438854fff8b139727ea3e1b0cabe Binary files /dev/null and b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_alignment.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea28421604c25b4ca0c122662db78c0883cbd1f9 Binary files /dev/null and b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_ctc.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e91d191ac2d502aa988f5c1ed922d13cf91f302b Binary files /dev/null and b/fairseq/criterions/__pycache__/label_smoothed_cross_entropy_with_rdrop.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc b/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ea909fd7c6ce9764879fb9ab98ae43bb598b114 Binary files /dev/null and b/fairseq/criterions/__pycache__/legacy_masked_lm.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc b/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e193b0903873b007d20d58bf02d232522074d465 Binary files /dev/null and b/fairseq/criterions/__pycache__/masked_lm.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4cf6b90972650076b58e91ef569019ad907c80b Binary files /dev/null and b/fairseq/criterions/__pycache__/model_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc b/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e147faebbe0b59bf570bacb7389f44871a3dab8 Binary files /dev/null and b/fairseq/criterions/__pycache__/nat_loss.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc b/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd13a235458033831f8c2a290579ac991fa99d7a Binary files /dev/null and b/fairseq/criterions/__pycache__/sentence_prediction.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc b/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..048decea193f6878351b5c7913e973f332d3f265 Binary files /dev/null and b/fairseq/criterions/__pycache__/sentence_prediction_adapters.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc b/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1eb773c085434bf5d8232e07873413c0ac21d40 Binary files /dev/null and b/fairseq/criterions/__pycache__/sentence_ranking.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9168a22f2aa0be1eaba9f9b2287f2198b5606cc Binary files /dev/null and b/fairseq/criterions/__pycache__/speech_dlm_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..935bc362dd2f192773d81cc9b835216c918da136 Binary files /dev/null and b/fairseq/criterions/__pycache__/speech_to_speech_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..900f3e118be9a151390a4f2bf594af8c7cf981c5 Binary files /dev/null and b/fairseq/criterions/__pycache__/speech_ulm_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc b/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c759aaef7958485ee5e0e1830b0162b9b96c148e Binary files /dev/null and b/fairseq/criterions/__pycache__/tacotron2_loss.cpython-310.pyc differ diff --git a/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc b/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd211abd4b456ae38094752266b7ceb2acc89f2a Binary files /dev/null and b/fairseq/criterions/__pycache__/wav2vec_criterion.cpython-310.pyc differ diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1ac8540461ab402e6ba6b4b40afe363774fffc --- /dev/null +++ b/fairseq/criterions/adaptive_loss.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.constants import DDP_BACKEND_CHOICES +from omegaconf import II + + +@dataclass +class AdaptiveLossConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend") + + +@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) +class AdaptiveLoss(FairseqCriterion): + """This is an implementation of the loss function accompanying the adaptive softmax approximation for + graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs" + (http://arxiv.org/abs/1609.04309).""" + + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + @classmethod + def build_criterion(cls, cfg: AdaptiveLossConfig, task): + if cfg.ddp_backend in {"c10d", "pytorch_ddp"}: + raise Exception( + "AdaptiveLoss is not compatible with the PyTorch " + "version of DistributedDataParallel. Please use " + "`--ddp-backend=legacy_ddp` instead." + ) + return cls(task, cfg.sentence_avg) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + + assert ( + hasattr(model.decoder, "adaptive_softmax") + and model.decoder.adaptive_softmax is not None + ) + adaptive_softmax = model.decoder.adaptive_softmax + + net_output = model(**sample["net_input"]) + orig_target = model.get_targets(sample, net_output) + + nsentences = orig_target.size(0) + orig_target = orig_target.view(-1) + + bsz = orig_target.size(0) + + logits, target = adaptive_softmax(net_output[0], orig_target) + assert len(target) == len(logits) + + loss = net_output[0].new(1 if reduce else bsz).zero_() + + for i in range(len(target)): + if target[i] is not None: + assert target[i].min() >= 0 and target[i].max() <= logits[i].size(1) + loss += F.cross_entropy( + logits[i], + target[i], + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + + orig = utils.strip_pad(orig_target, self.padding_idx) + ntokens = orig.numel() + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/composite_loss.py b/fairseq/criterions/composite_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..98e835fa6e4c0bcad062df9c519701bf795c98be --- /dev/null +++ b/fairseq/criterions/composite_loss.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from torch import nn + + +@register_criterion("composite_loss") +class CompositeLoss(LegacyFairseqCriterion): + """This is a composite loss that, given a list of model outputs and a list of targets, + computes an average of losses for each output-target pair""" + + def __init__(self, args, task): + super().__init__(args, task) + self.underlying_criterion = args.underlying_criterion + + @staticmethod + def add_args(parser): + """Add criterion-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True, + help='underlying criterion to use for the composite loss') + # fmt: on + + @staticmethod + def build_underlying_criterion(args, task): + saved_criterion = args.criterion + args.criterion = args.underlying_criterion + assert saved_criterion != args.underlying_criterion + underlying_criterion = task.build_criterion(args) + args.criterion = saved_criterion + return underlying_criterion + + @classmethod + def build_criterion(cls, args, task): + underlying_criterion = CompositeLoss.build_underlying_criterion(args, task) + + class FakeModel(nn.Module): + def __init__(self, model, net_out, target): + super().__init__() + self.model = model + self.net_out = net_out + self.target = target + + def forward(self, **unused): + return self.net_out + + def get_normalized_probs(self, net_output, log_probs, sample=None): + return self.model.get_normalized_probs( + net_output, log_probs, sample=sample + ) + + def get_targets(self, *unused): + return self.target + + @property + def decoder(self): + return self.model.decoder + + class _CompositeLoss(LegacyFairseqCriterion): + def __init__(self, args, task, underlying_criterion): + super().__init__(args, task) + self.underlying_criterion = underlying_criterion + + def forward(self, model, sample, reduce=True): + net_outputs = model(**sample["net_input"]) + targets = sample["target"] + + bsz = targets[0].size(0) + loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_() + + sample_size = 0 + logging_output = {} + for o, t in zip(net_outputs[0], targets): + m = FakeModel(model, (o, net_outputs[1]), t) + sample["target"] = t + l, ss, logging_output = self.underlying_criterion(m, sample, reduce) + loss += l + sample_size += ss + + loss.div_(len(targets)) + sample_size /= len(targets) + + logging_output["loss"] = utils.item(loss.data) if reduce else loss.data + return loss, sample_size, logging_output + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + return underlying_criterion.__class__.aggregate_logging_outputs( + logging_outputs + ) + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + underlying_criterion.__class__.reduce_metrics(logging_outputs) + + return _CompositeLoss(args, task, underlying_criterion) diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..24d6bcd6128f9bfb7667adcdb3e7f001cc57a523 --- /dev/null +++ b/fairseq/criterions/cross_entropy.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class CrossEntropyCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) +class CrossEntropyCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.view(-1, lprobs.size(-1)) + target = model.get_targets(sample, net_output).view(-1) + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + return loss, loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..368213cb2b05bcfcec3ae84aef68c82bd792492b --- /dev/null +++ b/fairseq/criterions/ctc.py @@ -0,0 +1,325 @@ +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import math +from argparse import Namespace +from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import post_process +from fairseq.tasks import FairseqTask +from fairseq.logging.meters import safe_round + + +@dataclass +class CtcCriterionConfig(FairseqDataclass): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + post_process: str = field( + default="letter", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. " + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + wer_kenlm_model: Optional[str] = field( + default=None, + metadata={ + "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)" + }, + ) + wer_lexicon: Optional[str] = field( + default=None, + metadata={"help": "lexicon to use with wer_kenlm_model"}, + ) + wer_lm_weight: float = field( + default=2.0, + metadata={"help": "lm weight to use with wer_kenlm_model"}, + ) + wer_word_score: float = field( + default=-1.0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + wer_sil_weight: float = field( + default=0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + + wer_args: Optional[str] = field( + default=None, + metadata={ + "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)" + }, + ) + + +@register_criterion("ctc", dataclass=CtcCriterionConfig) +class CtcCriterion(FairseqCriterion): + def __init__( + self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: int = 0.0 + ): + super().__init__(task) + self.blank_idx = ( + task.target_dictionary.index(task.blank_symbol) + if hasattr(task, "blank_symbol") + else 0 + ) + self.pad_idx = task.target_dictionary.pad() + self.eos_idx = task.target_dictionary.eos() + self.post_process = cfg.post_process + + self.rdrop_alpha = rdrop_alpha + + if cfg.wer_args is not None: + ( + cfg.wer_kenlm_model, + cfg.wer_lexicon, + cfg.wer_lm_weight, + cfg.wer_word_score, + ) = eval(cfg.wer_args) + + if cfg.wer_kenlm_model is not None and cfg.wer_kenlm_model != "": + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + + dec_args = Namespace() + dec_args.nbest = 1 + dec_args.criterion = "ctc" + dec_args.kenlm_model = cfg.wer_kenlm_model + dec_args.lexicon = cfg.wer_lexicon + dec_args.beam = 50 + dec_args.beam_size_token = min(50, len(task.target_dictionary)) + dec_args.beam_threshold = min(50, len(task.target_dictionary)) + dec_args.lm_weight = cfg.wer_lm_weight + dec_args.word_score = cfg.wer_word_score + dec_args.sil_weight = cfg.wer_sil_weight + dec_args.unk_weight = -math.inf + dec_args.sil_weight = 0 + + self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary) + else: + self.w2l_decoder = None + + self.zero_infinity = cfg.zero_infinity + self.sentence_avg = cfg.sentence_avg + + def forward(self, model, sample, reduce=True, **kwargs): + net_output = model(**sample["net_input"]) + lprobs = model.get_normalized_probs( + net_output, log_probs=True + ).contiguous() # (T, B, C) from the encoder + + # CTC loss is calculated over duplicated inputs + # sample is already duplicated for R-Drop + if self.rdrop_alpha > 0: + for k, v in sample.items(): + if k in ["target", "target_lengths"]: + sample[k] = torch.cat([v, v.clone()], dim=0) + elif k == "net_input": + if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size( + 0 + ): + # for decoder CTC loss + sample[k]["src_lengths"] = torch.cat( + [ + sample[k]["src_lengths"], + sample[k]["src_lengths"].clone(), + ], + dim=0, + ) + + if "src_lengths" in sample["net_input"]: + input_lengths = sample["net_input"]["src_lengths"] + else: + if net_output["padding_mask"] is not None: + non_padding_mask = ~net_output["padding_mask"] + input_lengths = non_padding_mask.long().sum(-1) + else: + input_lengths = lprobs.new_full( + (lprobs.size(1),), lprobs.size(0), dtype=torch.long + ) + + pad_mask = (sample["target"] != self.pad_idx) & ( + sample["target"] != self.eos_idx + ) + targets_flat = sample["target"].masked_select(pad_mask) + if "target_lengths" in sample: + target_lengths = sample["target_lengths"] + else: + target_lengths = pad_mask.sum(-1) + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + lprobs, + targets_flat, + input_lengths, + target_lengths, + blank=self.blank_idx, + reduction="sum", + zero_infinity=self.zero_infinity, + ) + + ntokens = ( + sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item() + ) + + sample_size = sample["target"].size(0) if self.sentence_avg else ntokens + logging_output = { + "loss": utils.item(loss.data), # * sample['ntokens'], + "ntokens": ntokens, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + if not model.training: + import editdistance + + with torch.no_grad(): + lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu() + + c_err = 0 + c_len = 0 + w_errs = 0 + w_len = 0 + wv_errs = 0 + for lp, t, inp_l in zip( + lprobs_t, + sample["target_label"] + if "target_label" in sample + else sample["target"], + input_lengths, + ): + lp = lp[:inp_l].unsqueeze(0) + + decoded = None + if self.w2l_decoder is not None: + decoded = self.w2l_decoder.decode(lp) + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + if len(decoded) < 1: + decoded = None + else: + decoded = decoded[0] + + p = (t != self.task.target_dictionary.pad()) & ( + t != self.task.target_dictionary.eos() + ) + targ = t[p] + targ_units = self.task.target_dictionary.string(targ) + targ_units_arr = targ.tolist() + + toks = lp.argmax(dim=-1).unique_consecutive() + pred_units_arr = toks[toks != self.blank_idx].tolist() + + c_err += editdistance.eval(pred_units_arr, targ_units_arr) + c_len += len(targ_units_arr) + + targ_words = post_process(targ_units, self.post_process).split() + + pred_units = self.task.target_dictionary.string(pred_units_arr) + pred_words_raw = post_process(pred_units, self.post_process).split() + + if decoded is not None and "words" in decoded: + pred_words = decoded["words"] + w_errs += editdistance.eval(pred_words, targ_words) + wv_errs += editdistance.eval(pred_words_raw, targ_words) + else: + dist = editdistance.eval(pred_words_raw, targ_words) + w_errs += dist + wv_errs += dist + + w_len += len(targ_words) + + logging_output["wv_errors"] = wv_errs + logging_output["w_errors"] = w_errs + logging_output["w_total"] = w_len + logging_output["c_errors"] = c_err + logging_output["c_total"] = c_len + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..0b1e64a8e36c55be90d7e3f854effd99ed5bcc44 --- /dev/null +++ b/fairseq/criterions/fairseq_criterion.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +from typing import Any, Dict, List + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from torch.nn.modules.loss import _Loss + + +class FairseqCriterion(_Loss): + def __init__(self, task): + super().__init__() + self.task = task + if hasattr(task, "target_dictionary"): + tgt_dict = task.target_dictionary + self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 + + @classmethod + def add_args(cls, parser): + """Add criterion-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @classmethod + def build_criterion(cls, cfg: FairseqDataclass, task): + """Construct a criterion from command-line args.""" + # arguments in the __init__. + init_args = {} + for p in inspect.signature(cls).parameters.values(): + if ( + p.kind == p.POSITIONAL_ONLY + or p.kind == p.VAR_POSITIONAL + or p.kind == p.VAR_KEYWORD + ): + # we haven't implemented inference for these argument types, + # but PRs welcome :) + raise NotImplementedError("{} not supported".format(p.kind)) + + assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + + if p.name == "task": + init_args["task"] = task + elif p.name == "cfg": + init_args["cfg"] = cfg + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) + elif p.default != p.empty: + pass # we'll use the default value + else: + raise NotImplementedError( + "Unable to infer Criterion arguments, please implement " + "{}.build_criterion".format(cls.__name__) + ) + return cls(**init_args) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + raise NotImplementedError + + @staticmethod + def aggregate_logging_outputs( + logging_outputs: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + raise NotImplementedError + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + """Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "Criterions should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs) + for k, v in agg_logging_outputs.items(): + if k in {"nsentences", "ntokens", "sample_size"}: + continue + metrics.log_scalar(k, v) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False + + +class LegacyFairseqCriterion(FairseqCriterion): + def __init__(self, args, task): + super().__init__(task=task) + self.args = args + + utils.deprecation_warning( + "Criterions should take explicit arguments instead of an " + "argparse.Namespace object, please update your criterion by " + "extending FairseqCriterion instead of LegacyFairseqCriterion." + ) + + @classmethod + def build_criterion(cls, args, task): + """Construct a criterion from command-line args.""" + return cls(args, task) diff --git a/fairseq/criterions/fastspeech2_loss.py b/fairseq/criterions/fastspeech2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..ab7cd08e3bd9c8d5c4be017095034b18362d77e0 --- /dev/null +++ b/fairseq/criterions/fastspeech2_loss.py @@ -0,0 +1,137 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from typing import List, Dict, Any +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.data.data_utils import lengths_to_mask +from fairseq.models.fairseq_model import FairseqEncoderModel + + +@dataclass +class FastSpeech2CriterionConfig(FairseqDataclass): + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) + + +@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig) +class FastSpeech2Loss(FairseqCriterion): + def __init__(self, task, ctc_weight): + super().__init__(task) + self.ctc_weight = ctc_weight + + def forward(self, model: FairseqEncoderModel, sample, reduction="mean"): + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + durations=sample["durations"], + pitches=sample["pitches"], + energies=sample["energies"], + ) + + src_mask = lengths_to_mask(sample["net_input"]["src_lengths"]) + tgt_mask = lengths_to_mask(sample["target_lengths"]) + + pitches, energies = sample["pitches"], sample["energies"] + pitch_out, pitches = pitch_out[src_mask], pitches[src_mask] + energy_out, energies = energy_out[src_mask], energies[src_mask] + + feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask] + l1_loss = F.l1_loss(feat_out, feat, reduction=reduction) + if _feat_out_post is not None: + l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction) + + pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction) + energy_loss = F.mse_loss(energy_out, energies, reduction=reduction) + + log_dur_out = log_dur_out[src_mask] + dur = sample["durations"].float() + dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur + log_dur = torch.log(dur + 1)[src_mask] + dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction) + + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: + lprobs = model.get_normalized_probs((_feat_out,), log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + + loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss + + sample_size = sample["nsentences"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "dur_loss": utils.item(dur_loss.data), + "pitch_loss": utils.item(pitch_loss.data), + "energy_loss": utils.item(energy_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in [ + "loss", + "l1_loss", + "dur_loss", + "pitch_loss", + "energy_loss", + "ctc_loss", + ]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/criterions/hubert_criterion.py b/fairseq/criterions/hubert_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..262874b582aa4765981fd9cc958c7221596d681e --- /dev/null +++ b/fairseq/criterions/hubert_criterion.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class HubertCriterionConfig(FairseqDataclass): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("hubert", dataclass=HubertCriterionConfig) +class HubertCriterion(FairseqCriterion): + def __init__( + self, + task, + pred_masked_weight, + pred_nomask_weight, + loss_weights=None, + log_keys=None, + ): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0.0 + sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list = model.get_logits(net_output, True) + targ_m_list = model.get_targets(net_output, True) + assert self.pred_masked_weight == 0 or len(logp_m_list) > 0 + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list = model.get_logits(net_output, False) + targ_u_list = model.get_targets(net_output, False) + assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0 + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + **logging_output, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + def compute_correct(logits): + if logits.numel() == 0: + return 0, 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = max.numel() + return corr, count + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + corr_m, count_m = compute_correct(logp_m) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + corr_u, count_u = compute_correct(logp_u) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..325679bb1678928b9fe644293b39f00115300a15 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field + +import torch +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + report_accuracy: bool = field( + default=False, + metadata={"help": "report accuracy metric"}, + ) + ignore_prefix_size: int = field( + default=0, + metadata={"help": "Ignore first N tokens"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +@register_criterion( + "label_smoothed_cross_entropy", dataclass=LabelSmoothedCrossEntropyCriterionConfig +) +class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + + def compute_accuracy(self, model, net_output, sample): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + mask = target.ne(self.padding_idx) + n_correct = torch.sum( + lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)) + ) + total = torch.sum(mask) + return n_correct, total + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar("total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in logging_outputs) + ) + metrics.log_scalar("n_correct", n_correct) + metrics.log_derived( + "accuracy", + lambda meters: round( + meters["n_correct"].sum * 100.0 / meters["total"].sum, 3 + ) + if meters["total"].sum > 0 + else float("nan"), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py new file mode 100644 index 0000000000000000000000000000000000000000..6eaedab9cf35dd6636e3463fdab1f0d4f9dda7e4 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import torch +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) + +try: + from simuleval.metrics.latency import ( + AverageLagging, + AverageProportion, + DifferentiableAverageLagging, + ) + + LATENCY_METRICS = { + "average_lagging": AverageLagging, + "average_proportion": AverageProportion, + "differentiable_average_lagging": DifferentiableAverageLagging, + } +except ImportError: + LATENCY_METRICS = None + + +@dataclass +class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + latency_avg_weight: float = field( + default=0.0, + metadata={"help": "weight fot average latency loss."}, + ) + latency_var_weight: float = field( + default=0.0, + metadata={"help": "weight fot variance latency loss."}, + ) + latency_avg_type: str = field( + default="differentiable_average_lagging", + metadata={"help": "latency type for average loss"}, + ) + latency_var_type: str = field( + default="variance_delay", + metadata={"help": "latency typ for variance loss"}, + ) + latency_gather_method: str = field( + default="weighted_average", + metadata={"help": "method to gather latency loss for all heads"}, + ) + latency_update_after: int = field( + default=0, + metadata={"help": "Add latency loss after certain steps"}, + ) + + +@register_criterion( + "latency_augmented_label_smoothed_cross_entropy", + dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig, +) +class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( + LabelSmoothedCrossEntropyCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + latency_avg_weight, + latency_var_weight, + latency_avg_type, + latency_var_type, + latency_gather_method, + latency_update_after, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed." + + self.latency_avg_weight = latency_avg_weight + self.latency_var_weight = latency_var_weight + self.latency_avg_type = latency_avg_type + self.latency_var_type = latency_var_type + self.latency_gather_method = latency_gather_method + self.latency_update_after = latency_update_after + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + # 1. Compute cross entropy loss + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + + # 2. Compute cross latency loss + latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss( + model, sample, net_output + ) + + if self.latency_update_after > 0: + num_updates = getattr(model.decoder, "num_updates", None) + assert ( + num_updates is not None + ), "model.decoder doesn't have attribute 'num_updates'" + if num_updates <= self.latency_update_after: + latency_loss = 0 + + loss += latency_loss + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "latency": expected_latency, + "delays_var": expected_delays_var, + "latency_loss": latency_loss, + } + + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + def compute_latency_loss(self, model, sample, net_output): + assert ( + net_output[-1].encoder_padding_mask is None + or not net_output[-1].encoder_padding_mask[:, 0].any() + ), "Only right padding on source is supported." + # 1. Obtain the expected alignment + alpha_list = [item["alpha"] for item in net_output[1].attn_list] + num_layers = len(alpha_list) + bsz, num_heads, tgt_len, src_len = alpha_list[0].size() + + # bsz * num_layers * num_heads, tgt_len, src_len + alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len) + + # 2 compute expected delays + # bsz * num_heads * num_layers, tgt_len, src_len for MMA + steps = ( + torch.arange(1, 1 + src_len) + .unsqueeze(0) + .unsqueeze(1) + .expand_as(alpha_all) + .type_as(alpha_all) + ) + + expected_delays = torch.sum(steps * alpha_all, dim=-1) + + target_padding_mask = ( + model.get_targets(sample, net_output) + .eq(self.padding_idx) + .unsqueeze(1) + .expand(bsz, num_layers * num_heads, tgt_len) + .contiguous() + .view(-1, tgt_len) + ) + + src_lengths = ( + sample["net_input"]["src_lengths"] + .unsqueeze(1) + .expand(bsz, num_layers * num_heads) + .contiguous() + .view(-1) + ) + expected_latency = LATENCY_METRICS[self.latency_avg_type]( + expected_delays, src_lengths, None, target_padding_mask=target_padding_mask + ) + + # 2.1 average expected latency of heads + # bsz, num_layers * num_heads + expected_latency = expected_latency.view(bsz, -1) + if self.latency_gather_method == "average": + # bsz * tgt_len + expected_latency = expected_delays.mean(dim=1) + elif self.latency_gather_method == "weighted_average": + weights = torch.nn.functional.softmax(expected_latency, dim=1) + expected_latency = torch.sum(expected_latency * weights, dim=1) + elif self.latency_gather_method == "max": + expected_latency = expected_latency.max(dim=1)[0] + else: + raise NotImplementedError + + expected_latency = expected_latency.sum() + avg_loss = self.latency_avg_weight * expected_latency + + # 2.2 variance of expected delays + expected_delays_var = ( + expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1) + ) + expected_delays_var = expected_delays_var.sum() + var_loss = self.latency_avg_weight * expected_delays_var + + # 3. Final loss + latency_loss = avg_loss + var_loss + + return latency_loss, expected_latency, expected_delays_var + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + latency = sum(log.get("latency", 0) for log in logging_outputs) + delays_var = sum(log.get("delays_var", 0) for log in logging_outputs) + latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3) + metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3) + metrics.log_scalar( + "latency_loss", latency_loss / nsentences, nsentences, round=3 + ) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..b55f65e5cc7e9e949208786a4974b4ba09b0de66 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py @@ -0,0 +1,131 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion + +from .label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) + +from dataclasses import dataclass, field + + +@dataclass +class LabelSmoothedCrossEntropyCriterionWithAlignmentConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + alignment_lambda: float = field( + default=0.05, metadata={"help": "weight for the alignment loss"} + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_alignment", + dataclass=LabelSmoothedCrossEntropyCriterionWithAlignmentConfig, +) +class LabelSmoothedCrossEntropyCriterionWithAlignment( + LabelSmoothedCrossEntropyCriterion +): + def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda): + super().__init__(task, sentence_avg, label_smoothing) + self.alignment_lambda = alignment_lambda + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + + alignment_loss = None + + # Compute alignment loss only for training set and non dummy batches. + if "alignments" in sample and sample["alignments"] is not None: + alignment_loss = self.compute_alignment_loss(sample, net_output) + + if alignment_loss is not None: + logging_output["alignment_loss"] = utils.item(alignment_loss.data) + loss += self.alignment_lambda * alignment_loss + + return loss, sample_size, logging_output + + def compute_alignment_loss(self, sample, net_output): + attn_prob = net_output[1]["attn"][0] + bsz, tgt_sz, src_sz = attn_prob.shape + attn = attn_prob.view(bsz * tgt_sz, src_sz) + + align = sample["alignments"] + align_weights = sample["align_weights"].float() + + if len(align) > 0: + # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to + # the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing. + loss = -( + (attn[align[:, 1][:, None], align[:, 0][:, None]]).log() + * align_weights[:, None] + ).sum() + else: + return None + + return loss + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss_sum = utils.item( + sum(log.get("nll_loss", 0) for log in logging_outputs) + ) + alignment_loss_sum = utils.item( + sum(log.get("alignment_loss", 0) for log in logging_outputs) + ) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_scalar( + "alignment_loss", + alignment_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py b/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..f2e8cdf3bfe0caea99125c6f9607dff9495891cf --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_ctc.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) +from fairseq.data.data_utils import lengths_to_mask + + +@dataclass +class LabelSmoothedCrossEntropyWithCtcCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + ctc_weight: float = field(default=1.0, metadata={"help": "weight for CTC loss"}) + + +@register_criterion( + "label_smoothed_cross_entropy_with_ctc", + dataclass=LabelSmoothedCrossEntropyWithCtcCriterionConfig, +) +class LabelSmoothedCrossEntropyWithCtcCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + ctc_weight, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + self.ctc_weight = ctc_weight + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + + ctc_loss = torch.tensor(0.0).type_as(loss) + if self.ctc_weight > 0.0: + ctc_lprobs, ctc_lens = model.get_ctc_output(net_output, sample) + ctc_tgt, ctc_tgt_lens = model.get_ctc_target(sample) + ctc_tgt_mask = lengths_to_mask(ctc_tgt_lens) + ctc_tgt_flat = ctc_tgt.masked_select(ctc_tgt_mask) + reduction = "sum" if reduce else "none" + ctc_loss = ( + F.ctc_loss( + ctc_lprobs, + ctc_tgt_flat, + ctc_lens, + ctc_tgt_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + loss += ctc_loss + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data), + "nll_loss": utils.item(nll_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py b/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee263a8de63261e4c8838ba44fe269553f5f3b --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, + label_smoothed_nll_loss, +) + + +@dataclass +class RdropLabelSmoothedCrossEntropyCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + rdrop_alpha: float = field( + default=0.0, + metadata={"help": "alpha for r-drop, 0 means no r-drop"}, + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_rdrop", + dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig, +) +class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=ignore_prefix_size, + report_accuracy=report_accuracy, + ) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + self.rdrop_alpha = rdrop_alpha + + def forward(self, model, sample, reduce=True, net_output=None): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if net_output is None: + if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size( + 0 + ) == sample["target"].size(0): + sample = duplicate_input(sample) + net_output = model(**sample["net_input"]) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0): + target = torch.cat([target, target.clone()], dim=0) + + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + + if self.rdrop_alpha > 0: + pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) + rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask) + loss += self.rdrop_alpha * rdrop_kl_loss + else: + rdrop_kl_loss = loss.new_zeros(1) + return loss, nll_loss, rdrop_kl_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + super().reduce_metrics(logging_outputs) + + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + rdrop_kl_loss = utils.item( + sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs) + / sample_size + / math.log(2) + ) + if rdrop_kl_loss > 0: + metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss) + + +def duplicate_input(sample): + if "net_input" in sample.keys(): + sample_input = sample["net_input"] + else: + sample_input = sample + + for k, v in sample_input.items(): + if isinstance(v, torch.Tensor): + sample_input[k] = torch.cat([v, v.clone()], dim=0) + if "net_input" in sample.keys(): + sample["net_input"] = sample_input + else: + sample = sample_input + return sample + + +def compute_kl_loss(model, net_output, pad_mask=None, reduce=True): + net_prob = model.get_normalized_probs(net_output, log_probs=True) + net_prob_tec = model.get_normalized_probs(net_output, log_probs=False) + + net_prob = net_prob.view(-1, net_prob.size(-1)) + net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1)) + + p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0) + p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0) + + p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none") + q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none") + + if pad_mask is not None: + p_loss.masked_fill_(pad_mask, 0.0) + q_loss.masked_fill_(pad_mask, 0.0) + + if reduce: + p_loss = p_loss.sum() + q_loss = q_loss.sum() + + loss = (p_loss + q_loss) / 2 + return loss diff --git a/fairseq/criterions/legacy_masked_lm.py b/fairseq/criterions/legacy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf70df2ab97eef1ec454ddc8ccaf5a86cc3c153 --- /dev/null +++ b/fairseq/criterions/legacy_masked_lm.py @@ -0,0 +1,178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +def compute_cross_entropy_loss(logits, targets, ignore_index=-100): + """ + Function to compute the cross entropy loss. The default value of + ignore_index is the same as the default value for F.cross_entropy in + pytorch. + """ + assert logits.size(0) == targets.size( + -1 + ), "Logits and Targets tensor shapes don't match up" + + loss = F.nll_loss( + F.log_softmax(logits, -1, dtype=torch.float32), + targets, + reduction="sum", + ignore_index=ignore_index, + ) + return loss + + +@register_criterion("legacy_masked_lm_loss") +class LegacyMaskedLmLoss(FairseqCriterion): + """ + Implementation for the loss used in masked language model (MLM) training. + This optionally also computes the next sentence prediction (NSP) loss and + adds it to the overall loss based on the specified args. There are three + cases to consider: + 1) Generic MLM training without NSP loss. In this case sentence_targets + and sentence_logits are both None. + 2) BERT training without NSP loss. In this case sentence_targets is + not None but sentence_logits is None and we should not be computing + a sentence level loss. + 3) BERT training with NSP loss. In this case both sentence_targets and + sentence_logits are not None and we should be computing a sentence + level loss. The weight of the sentence level loss is specified as + an argument. + """ + + def __init__(self, task, masked_lm_only, nsp_loss_weight): + super().__init__(task) + self.masked_lm_only = masked_lm_only + self.nsp_loss_weight = nsp_loss_weight + + @staticmethod + def add_args(parser): + """Args for MaskedLM Loss""" + # Default for masked_lm_only is False so as to not break BERT training + parser.add_argument( + "--masked-lm-only", + default=False, + action="store_true", + help="compute MLM loss only", + ) + parser.add_argument( + "--nsp-loss-weight", + default=1.0, + type=float, + help="weight for next sentence prediction" " loss (default 1)", + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + lm_logits, output_metadata = model(**sample["net_input"]) + + # reshape lm_logits from (N,T,C) to (N*T,C) + lm_logits = lm_logits.view(-1, lm_logits.size(-1)) + lm_targets = sample["lm_target"].view(-1) + lm_loss = compute_cross_entropy_loss(lm_logits, lm_targets, self.padding_idx) + + # compute the number of tokens for which loss is computed. This is used + # to normalize the loss + ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel() + loss = lm_loss / ntokens + nsentences = sample["nsentences"] + # nsentences = 0 + + # Compute sentence loss if masked_lm_only is False + sentence_loss = None + if not self.masked_lm_only: + sentence_logits = output_metadata["sentence_logits"] + sentence_targets = sample["sentence_target"].view(-1) + # This needs to be recomputed due to some differences between + # TokenBlock and BlockPair dataset. This can be resolved with a + # refactor of BERTModel which we will do in the future. + # TODO: Remove this after refactor of BERTModel + nsentences = sentence_targets.size(0) + + # Check for logits being none which can happen when remove_heads + # is set to true in the BERT model. Ideally we should set + # masked_lm_only to true in this case, but that requires some + # refactor in the BERT model. + if sentence_logits is not None: + sentence_loss = compute_cross_entropy_loss( + sentence_logits, sentence_targets + ) + + loss += self.nsp_loss_weight * (sentence_loss / nsentences) + + # NOTE: as we are summing up per token mlm loss and per sentence nsp loss + # we don't need to use sample_size as denominator for the gradient + # here sample_size is just used for logging + sample_size = 1 + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "lm_loss": utils.item(lm_loss.data) if reduce else lm_loss.data, + # sentence loss is not always computed + "sentence_loss": ( + (utils.item(sentence_loss.data) if reduce else sentence_loss.data) + if sentence_loss is not None + else 0.0 + ), + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + lm_loss_sum = sum(log.get("lm_loss", 0) for log in logging_outputs) + sentence_loss_sum = sum(log.get("sentence_loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + agg_loss = sum(log.get("loss", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", + agg_loss / sample_size / math.log(2) if sample_size > 0 else 0.0, + sample_size, + round=3, + ) + metrics.log_scalar( + "lm_loss", + lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0, + ntokens, + round=3, + ) + metrics.log_scalar( + "sentence_loss", + sentence_loss_sum / nsentences / math.log(2) if nsentences > 0 else 0.0, + nsentences, + round=3, + ) + metrics.log_scalar( + "nll_loss", + lm_loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.0, + ntokens, + round=3, + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/masked_lm.py b/fairseq/criterions/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..09ddd9f3e6d41c4771521bb187ad981223e09e95 --- /dev/null +++ b/fairseq/criterions/masked_lm.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +import math +from omegaconf import II + +import torch +from fairseq import modules, utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class MaskedLmConfig(FairseqDataclass): + tpu: bool = II("common.tpu") + + +@register_criterion("masked_lm", dataclass=MaskedLmConfig) +class MaskedLmLoss(FairseqCriterion): + """ + Implementation for the loss used in masked language model (MLM) training. + """ + + def __init__(self, cfg: MaskedLmConfig, task): + super().__init__(task) + self.tpu = cfg.tpu + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + masked_tokens = sample["target"].ne(self.padding_idx) + sample_size = masked_tokens.int().sum() + + # Rare: when all tokens are masked, project all tokens. + # We use torch.where to avoid device-to-host transfers, + # except on CPU where torch.where is not well supported + # (see github.com/pytorch/pytorch/issues/26247). + if self.tpu: + masked_tokens = None # always project all tokens on TPU + elif masked_tokens.device == torch.device("cpu"): + if not masked_tokens.any(): + masked_tokens = None + else: + masked_tokens = torch.where( + masked_tokens.any(), + masked_tokens, + masked_tokens.new([True]), + ) + + logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0] + targets = model.get_targets(sample, [logits]) + if masked_tokens is not None: + targets = targets[masked_tokens] + + loss = modules.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + reduction="sum", + ignore_index=self.padding_idx, + ) + + logging_output = { + "loss": loss if self.tpu else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..4c020ddbd2fde9a879d94fe63e745d3e6b9a627e --- /dev/null +++ b/fairseq/criterions/model_criterion.py @@ -0,0 +1,177 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCriterionConfig(FairseqDataclass): + loss_weights: Dict[str, float] = field( + default_factory=dict, + metadata={"help": "weights for the loss terms"}, + ) + log_keys: List[str] = field( + default_factory=list, + metadata={"help": "additional output keys to log"}, + ) + can_sum: bool = True + + +@register_criterion("model", dataclass=ModelCriterionConfig) +class ModelCriterion(FairseqCriterion): + """ + This criterion relies on the model to supply losses. + The losses should be a dictionary of name -> scalar returned by + the model either by including it in the net_output dict or by + implementing a get_losses(net_output, sample) method. The final loss is + a scaled sum of all losses according to weights in loss_weights. + If no weights are provided, then all losses are scaled by 1.0. + + The losses will be automatically logged. Additional keys from + net_output dict can be logged via the log_keys parameter. + """ + + def __init__(self, task, loss_weights=None, log_keys=None, can_sum=True): + super().__init__(task) + self.loss_weights = loss_weights + self.log_keys = log_keys + self.can_sum = can_sum + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + + scaled_losses = {} + + if hasattr(model, "get_losses"): + losses = model.get_losses(net_output, sample) + elif isinstance(net_output, dict) and "losses" in net_output: + losses = net_output["losses"] + else: + raise Exception("Could not retrieve losses") + + for lk, p in losses.items(): + try: + coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk] + except KeyError: + logger.error( + f"weight for loss {lk} is not in loss_weights ({self.loss_weights})" + ) + raise + if coef != 0 and p is not None: + scaled_losses[lk] = coef * p.float().sum() + + loss = sum(scaled_losses.values()) + + if "sample_size" in net_output: + sample_size = net_output["sample_size"] + else: + sample_size = loss.numel() + + if reduce and loss.numel() > 1: + loss = loss.sum() + + logging_output = { + "loss": loss.data, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + "_world_size": 1, + } + + for lk in self.log_keys: + if lk in net_output and net_output[lk] is not None: + if not torch.is_tensor(net_output[lk]) or net_output[lk].numel() == 1: + logging_output[lk] = float(net_output[lk]) + elif lk.startswith("_"): + logging_output[lk] = net_output[lk] + else: + for i, v in enumerate(net_output[lk]): + logging_output[f"{lk}_{i}"] = float(v) + + if len(scaled_losses) > 1: + for lk, l in scaled_losses.items(): + if l.numel() > 1: + l = l.sum() + logging_output[f"loss_{lk}"] = l.item() + + if "logs" in net_output: + for lgw in net_output["logs"]: + logging_output[lgw] = net_output["logs"][lgw] + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + metrics.log_scalar("sample_size", sample_size) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "_world_size", + } + + world_size = utils.item( + sum(log.get("_world_size", 0) for log in logging_outputs) + ) + + for k in logging_outputs[0]: + if k not in builtin_keys and not k.startswith("_"): + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss_"): + metrics.log_scalar(k, val / sample_size, sample_size, round=3) + else: + metrics.log_scalar(k, val / world_size, round=3) + + correct = sum(log.get("correct", 0) for log in logging_outputs) + total = sum(log.get("count", 0) for log in logging_outputs) + + if total > 0: + metrics.log_scalar("_correct", correct) + metrics.log_scalar("_total", total) + + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + def logging_outputs_can_be_summed(self) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return self.can_sum diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0bdaf8510d8feae20779aad49b53a4d84d37db --- /dev/null +++ b/fairseq/criterions/nat_loss.py @@ -0,0 +1,181 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from torch import Tensor + +from dataclasses import dataclass, field + + +@dataclass +class LabelSmoothedDualImitationCriterionConfig(FairseqDataclass): + label_smoothing: float = field( + default=0.0, + metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, + ) + + +@register_criterion("nat_loss", dataclass=LabelSmoothedDualImitationCriterionConfig) +class LabelSmoothedDualImitationCriterion(FairseqCriterion): + def __init__(self, task, label_smoothing): + super().__init__(task) + self.label_smoothing = label_smoothing + + def _compute_loss( + self, outputs, targets, masks=None, label_smoothing=0.0, name="loss", factor=1.0 + ): + """ + outputs: batch x len x d_model + targets: batch x len + masks: batch x len + + policy_logprob: if there is some policy + depends on the likelihood score as rewards. + """ + + def mean_ds(x: Tensor, dim=None) -> Tensor: + return ( + x.float().mean().type_as(x) + if dim is None + else x.float().mean(dim).type_as(x) + ) + + if masks is not None: + outputs, targets = outputs[masks], targets[masks] + + if masks is not None and not masks.any(): + nll_loss = torch.tensor(0) + loss = nll_loss + else: + logits = F.log_softmax(outputs, dim=-1) + if targets.dim() == 1: + losses = F.nll_loss(logits, targets.to(logits.device), reduction="none") + + else: # soft-labels + losses = F.kl_div(logits, targets.to(logits.device), reduction="none") + losses = losses.sum(-1) + + nll_loss = mean_ds(losses) + if label_smoothing > 0: + loss = ( + nll_loss * (1 - label_smoothing) - mean_ds(logits) * label_smoothing + ) + else: + loss = nll_loss + + loss = loss * factor + return {"name": name, "loss": loss, "nll_loss": nll_loss, "factor": factor} + + def _custom_loss(self, loss, name="loss", factor=1.0): + return {"name": name, "loss": loss, "factor": factor} + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + nsentences, ntokens = sample["nsentences"], sample["ntokens"] + + # B x T + src_tokens, src_lengths = ( + sample["net_input"]["src_tokens"], + sample["net_input"]["src_lengths"], + ) + tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"] + + outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens) + losses, nll_loss = [], [] + + for obj in outputs: + if outputs[obj].get("loss", None) is None: + _losses = self._compute_loss( + outputs[obj].get("out"), + outputs[obj].get("tgt"), + outputs[obj].get("mask", None), + outputs[obj].get("ls", 0.0), + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), + ) + else: + _losses = self._custom_loss( + outputs[obj].get("loss"), + name=obj + "-loss", + factor=outputs[obj].get("factor", 1.0), + ) + + losses += [_losses] + if outputs[obj].get("nll_loss", False): + nll_loss += [_losses.get("nll_loss", 0.0)] + + loss = sum(l["loss"] for l in losses) + nll_loss = sum(l for l in nll_loss) if len(nll_loss) > 0 else loss.new_tensor(0) + + # NOTE: + # we don't need to use sample_size as denominator for the gradient + # here sample_size is just used for logging + sample_size = 1 + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + + for l in losses: + logging_output[l["name"]] = ( + utils.item(l["loss"].data / l["factor"]) + if reduce + else l[["loss"]].data / l["factor"] + ) + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + loss = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + nll_loss = utils.item(sum(log.get("nll_loss", 0) for log in logging_outputs)) + + metrics.log_scalar( + "loss", loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "nll_loss", nll_loss / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + for key in logging_outputs[0]: + if key[-5:] == "-loss": + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar( + key[:-5], + val / sample_size / math.log(2) if sample_size > 0 else 0.0, + sample_size, + round=3, + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/sentence_prediction.py b/fairseq/criterions/sentence_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..298b80576814d785e8722ae4726698a0a928a20f --- /dev/null +++ b/fairseq/criterions/sentence_prediction.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from itertools import chain + +import numpy as np +import torch +import torch.nn.functional as F +from sklearn.metrics import f1_score +from sklearn.metrics import matthews_corrcoef as _matthews_corrcoef +from scipy.stats import pearsonr, spearmanr + +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round + + +def simple_accuracy(preds, labels): + return (preds == labels).mean() + + +def acc_and_f1(preds, labels): + acc = simple_accuracy(preds, labels) + f1 = f1_score(y_true=labels, y_pred=preds) + return { + "acc": acc, + "f1": f1, + "acc_and_f1": (acc + f1) / 2, + } + + +def pearson_and_spearman(preds, labels): + pearson_corr = pearsonr(preds, labels)[0] + spearman_corr = spearmanr(preds, labels)[0] + return { + "pearson": pearson_corr, + "spearmanr": spearman_corr, + "corr": (pearson_corr + spearman_corr) / 2, + } + + +def matthews_corrcoef(preds, labels): + # make it consistent with other metrics taking (preds, labels) as input + mcc = _matthews_corrcoef(labels, preds) + return mcc + + +@dataclass +class SentencePredictionConfig(FairseqDataclass): + classification_head_name: str = field( + default="sentence_classification_head", + metadata={"help": "name of the classification head to use"}, + ) + regression_target: bool = field( + default=False, + ) + report_mcc: bool = False + report_acc_and_f1: bool = False + report_pearson_and_spearman: bool = False + + +@register_criterion("sentence_prediction", dataclass=SentencePredictionConfig) +class SentencePredictionCriterion(FairseqCriterion): + def __init__(self, cfg: SentencePredictionConfig, task): + super().__init__(task) + self.classification_head_name = cfg.classification_head_name + self.regression_target = cfg.regression_target + self.keep_pred_and_targ = ( + cfg.report_mcc or cfg.report_acc_and_f1 or cfg.report_pearson_and_spearman + ) + self.report_mcc = cfg.report_mcc + self.report_acc_and_f1 = cfg.report_acc_and_f1 + self.report_pearson_and_spearman = cfg.report_pearson_and_spearman + self.label_dict = task.label_dictionary + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.classification_head_name in model.classification_heads + ), "model must provide sentence classification head for --criterion=sentence_prediction" + + logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + ) + targets = model.get_targets(sample, [logits]).view(-1) + sample_size = targets.numel() + + if not self.regression_target: + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + task_loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + logits = logits.view(-1).float() + targets = targets.float() + task_loss = F.mse_loss(logits, targets, reduction="sum") + + logging_output = {} + loss = task_loss + # mha & ffn regularization update + if ( + hasattr(model, "args") + and hasattr(model.args, "mha_reg_scale_factor") + and model.args.mha_reg_scale_factor != 0.0 + ): + mha_reg_loss = model._get_adaptive_head_loss() + loss += mha_reg_loss + logging_output.update({"mha_reg_loss": mha_reg_loss}) + if ( + hasattr(model, "args") + and hasattr(model.args, "ffn_reg_scale_factor") + and model.args.ffn_reg_scale_factor != 0.0 + ): + ffn_reg_loss = model._get_adaptive_ffn_loss() + loss += ffn_reg_loss + logging_output.update({"ffn_reg_loss": ffn_reg_loss}) + + logging_output.update( + { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + ) + if not self.regression_target: + preds = logits.argmax(dim=1) + logging_output["ncorrect"] = (preds == targets).sum() + if self.keep_pred_and_targ and not model.training: + if self.regression_target: + logging_output["pred"] = logits.detach().cpu().tolist() + logging_output["targ"] = targets.detach().cpu().tolist() + else: + # remove offset `self.label_dict.nspecial` from OffsetTokensDataset + preds = self.label_dict.string(preds + self.label_dict.nspecial).split() + targets = self.label_dict.string( + targets + self.label_dict.nspecial + ).split() + logging_output["pred"] = list(map(int, preds)) + logging_output["targ"] = list(map(int, targets)) + + if self.report_mcc: + logging_output["report_mcc"] = True + if self.report_acc_and_f1: + logging_output["report_acc_and_f1"] = True + if self.report_pearson_and_spearman: + logging_output["report_pearson_and_spearman"] = True + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + mha_reg_loss_sum = sum(log.get("mha_reg_loss", 0) for log in logging_outputs) + ffn_reg_loss_sum = sum(log.get("ffn_reg_loss", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if mha_reg_loss_sum: + metrics.log_scalar( + "mha_reg_loss", + mha_reg_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if ffn_reg_loss_sum: + metrics.log_scalar( + "ffn_reg_loss", + ffn_reg_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) + + # Metrics used by GLUE + pred = np.array( + list(chain.from_iterable(log.get("pred", []) for log in logging_outputs)) + ) + targ = np.array( + list(chain.from_iterable(log.get("targ", []) for log in logging_outputs)) + ) + if len(pred): + metrics.log_concat_tensor("pred", torch.from_numpy(pred), dim=0) + metrics.log_concat_tensor("targ", torch.from_numpy(targ), dim=0) + if any("report_mcc" in log for log in logging_outputs): + metrics.log_derived( + "mcc", + lambda meters: safe_round( + matthews_corrcoef( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + ) + * 100, + 1, + ), + ) + if any("report_acc_and_f1" in log for log in logging_outputs): + metrics.log_derived( + "acc_and_f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["acc_and_f1"] + * 100, + 1, + ), + ) + metrics.log_derived( + "f1", + lambda meters: safe_round( + acc_and_f1( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["f1"] + * 100, + 1, + ), + ) + if any("report_pearson_and_spearman" in log for log in logging_outputs): + metrics.log_derived( + "pearson_and_spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["corr"] + * 100, + 1, + ), + ) + metrics.log_derived( + "pearson", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["pearson"] + * 100, + 1, + ), + ) + metrics.log_derived( + "spearman", + lambda meters: safe_round( + pearson_and_spearman( + meters["pred"].tensor.numpy(), + meters["targ"].tensor.numpy(), + )["spearmanr"] + * 100, + 1, + ), + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/sentence_prediction_adapters.py b/fairseq/criterions/sentence_prediction_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..8a873a45b3b121730c8a27d64facfa2922f8eb88 --- /dev/null +++ b/fairseq/criterions/sentence_prediction_adapters.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq.criterions import register_criterion +from fairseq.criterions.sentence_prediction import ( + SentencePredictionCriterion, + SentencePredictionConfig, +) + + +@register_criterion("sentence_prediction_adapters", dataclass=SentencePredictionConfig) +class SentencePredictionCriterionAdapters(SentencePredictionCriterion): + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.classification_head_name in model.classification_heads + ), "model must provide sentence classification head for --criterion=sentence_prediction" + + if not hasattr(sample, "lang_id"): + # If no language ID is given, we fall back to English + lang_id = ["en_XX"] * sample["nsentences"] + else: + lang_id = sample["lang_id"] + + logits, _ = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.classification_head_name, + lang_id=lang_id, + ) + targets = model.get_targets(sample, [logits]).view(-1) + sample_size = targets.numel() + + if not self.regression_target: + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + logits = logits.view(-1).float() + targets = targets.float() + loss = F.mse_loss(logits, targets, reduction="sum") + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + if not self.regression_target: + preds = logits.argmax(dim=1) + logging_output["ncorrect"] = (preds == targets).sum() + + return loss, sample_size, logging_output diff --git a/fairseq/criterions/sentence_ranking.py b/fairseq/criterions/sentence_ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb9f058f9208a1eb1218715df0f6f2183085dd9 --- /dev/null +++ b/fairseq/criterions/sentence_ranking.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +@register_criterion("sentence_ranking") +class SentenceRankingCriterion(FairseqCriterion): + def __init__(self, task, ranking_head_name, save_predictions, num_classes): + super().__init__(task) + self.ranking_head_name = ranking_head_name + if save_predictions is not None: + self.prediction_h = open(save_predictions, "w") + else: + self.prediction_h = None + self.num_classes = num_classes + + def __del__(self): + if self.prediction_h is not None: + self.prediction_h.close() + + @staticmethod + def add_args(parser): + # fmt: off + parser.add_argument('--save-predictions', metavar='FILE', + help='file to save predictions to') + parser.add_argument('--ranking-head-name', + default='sentence_classification_head', + help='name of the ranking head to use') + # fmt: on + + def forward(self, model, sample, reduce=True): + """Compute ranking loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + assert ( + hasattr(model, "classification_heads") + and self.ranking_head_name in model.classification_heads + ), "model must provide sentence ranking head for --criterion=sentence_ranking" + + scores = [] + for idx in range(self.num_classes): + score, _ = model( + **sample["net_input{idx}".format(idx=idx + 1)], + classification_head_name=self.ranking_head_name, + ) + scores.append(score) + + logits = torch.cat(scores, dim=1) + sample_size = logits.size(0) + + if "target" in sample: + targets = model.get_targets(sample, [logits]).view(-1) + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + loss = F.nll_loss(lprobs, targets, reduction="sum") + else: + targets = None + loss = torch.tensor(0.0, requires_grad=True) + + if self.prediction_h is not None: + preds = logits.argmax(dim=1) + for i, (id, pred) in enumerate(zip(sample["id"].tolist(), preds.tolist())): + if targets is not None: + label = targets[i].item() + print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h) + else: + print("{}\t{}".format(id, pred), file=self.prediction_h) + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample_size, + "sample_size": sample_size, + } + if targets is not None: + logging_output["ncorrect"] = (logits.argmax(dim=1) == targets).sum() + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + if len(logging_outputs) > 0 and "ncorrect" in logging_outputs[0]: + ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs) + metrics.log_scalar( + "accuracy", 100.0 * ncorrect / nsentences, nsentences, round=1 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/speech_dlm_criterion.py b/fairseq/criterions/speech_dlm_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..888818011408fd81e5cd3e3c9b074e5082702c79 --- /dev/null +++ b/fairseq/criterions/speech_dlm_criterion.py @@ -0,0 +1,335 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import Optional + +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class SpeechDLMCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + main_and_cross_weights: Optional[str] = field( + default="1,0", + metadata={ + "help": "Comma-separated list of weights of Main-channel vs Cross-channel Prediction Losses" + "(default: 1,0)" + }, + ) + general_unit_loss_weight: float = field( + default=0, + metadata={ + "help": "The weight of the General Prediction Loss (Next-step Unit Prediction Loss)" + "(default: 0)" + }, + ) + edge_unit_loss_weight: float = field( + default=1, + metadata={"help": "The weight of the Edge Unit Prediction Loss" "(default: 1)"}, + ) + duration_loss_weight: float = field( + default=1, + metadata={ + "help": "The weight of the Edge Unit Duration Prediction Loss" + "(default: 1)" + }, + ) + + +@register_criterion("speech_dlm_criterion", dataclass=SpeechDLMCriterionConfig) +class SpeechDLMCriterion(FairseqCriterion): + """Criteron for the SpeechDLM model as described in the paper: + https://arxiv.org/pdf/2203.16502.pdf + + There are 3 possible losses depending on the targets of the model: + - general_unit_loss : The next unit prediction loss, corresponding to + 'next' target + - edge_unit_loss : The edge unit prediction loss, corresponding to + 'edge' target + - duration_loss : The duration prediction loss, corresponding to + 'duration' target + """ + + def __init__( + self, + task, + sentence_avg, + main_and_cross_weights, + general_unit_loss_weight, + edge_unit_loss_weight, + duration_loss_weight, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + + self.channels = task.channels + self.targets = task.targets + self.delayed_duration_target = task.delayed_duration_target + + self.main_channel_weight = float(main_and_cross_weights.split(",")[0]) + self.cross_channel_weight = float(main_and_cross_weights.split(",")[1]) + assert self.main_channel_weight >= 0 and self.cross_channel_weight >= 0 + + self.channel_weights = { + channel: weight + for channel, weight in zip(self.channels, task.channel_weights) + } + + self.target_weights = {} + for t in self.targets: + if t == "next": + self.target_weights[t] = general_unit_loss_weight + assert ( + general_unit_loss_weight > 0 + ), "Expect a positive --general-unit-loss-weight for next unit prediction" + elif t == "edge": + self.target_weights[t] = edge_unit_loss_weight + assert ( + edge_unit_loss_weight > 0 + ), "Expect a positive --edge-unit-loss-weight for edge unit prediction" + elif t == "duration": + self.target_weights[t] = duration_loss_weight + assert ( + duration_loss_weight > 0 + ), "Expect a positive --duration-loss-weight for duration prediction" + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss_dict, stats_dict = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + nsentences = sample["net_input"]["src_tokens"][self.channels[0]].size(0) + + logging_output = { + "nsentences": nsentences, + } + logging_output["nsentences"] = nsentences + + loss_all = {t: 0 for t in self.targets} + correct_all = {t: 0 for t in self.targets} + count_all = {t: 0 for t in self.targets} + ntokens_all = 0 + sample_size_all = 0 + for channel in loss_dict: + for pred_channel in loss_dict[channel]: + # Get ntokens & sample_size + ntokens = sample["net_input"]["src_tokens"][channel].numel() + sample_size = nsentences if self.sentence_avg else ntokens + prefix = "[{}-{}]".format(channel, pred_channel) + log_keys = { + "next": "general_token", + "edge": "edge_token", + "duration": "edge_duration", + } + + # Log & Update the sizes + logging_output["{}ntokens".format(prefix)] = ntokens + logging_output["{}sample_size".format(prefix)] = sample_size + ntokens_all += ntokens + sample_size_all += sample_size + + for t in self.targets: + log_key = log_keys[t] + loss = loss_dict[channel][pred_channel][t] + correct, count = stats_dict[channel][pred_channel][t] + + # Log the statistics + logging_output["{}{}_loss".format(prefix, log_key)] = loss.data + logging_output["{}{}_correct".format(prefix, log_key)] = correct + logging_output["{}{}_count".format(prefix, log_key)] = count + + # Scale the training loss by weights + target_loss = loss * self.channel_weights[channel] + if pred_channel == channel: + target_loss = target_loss * self.main_channel_weight + else: + target_loss = target_loss * self.cross_channel_weight + # Normalize the losses in the training by the number of edges + if t in ["edge", "duration"]: + target_loss = target_loss / count * sample_size + + # Update the statistics + loss_all[t] += target_loss + correct_all[t] += correct + count_all[t] += count + + # Logging the average statistics + logging_output["ntokens"] = ntokens_all + logging_output["sample_size"] = sample_size_all + for t in self.targets: + log_key = { + "next": "general_token", + "edge": "edge_token", + "duration": "edge_duration", + }[t] + logging_output["{}_loss".format(log_key)] = loss_all[t].data + logging_output["{}_correct".format(log_key)] = correct_all[t] + logging_output["{}_count".format(log_key)] = count_all[t] + + # Define the training loss + training_loss = 0 + for t in self.targets: + training_loss += loss_all[t] * self.target_weights[t] + logging_output["loss"] = training_loss.data + + return training_loss, sample_size_all, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + # Get the model outputs and target + lprobs_dict = model.get_normalized_probs(net_output, log_probs=True) + target_dict = model.get_targets(sample, net_output) + + # Init the dictionaries + loss_dict, stats_dict = {}, {} + + for channel in lprobs_dict: + # Init the dictionaries + loss_dict[channel], stats_dict[channel] = {}, {} + + for pred_channel in lprobs_dict[channel]: + # Init the dictionaries + loss_dict[channel][pred_channel] = {} + stats_dict[channel][pred_channel] = {} + + # Get token & duration predictions + outputs = lprobs_dict[channel][pred_channel] + if not isinstance(outputs, dict): + token_lprobs = outputs + else: + token_lprobs = outputs["pred_token"] + dur_preds = outputs["pred_duration"] + dur_preds = dur_preds.view(-1) + token_lprobs = token_lprobs.view(-1, token_lprobs.size(-1)) + token_preds = token_lprobs.argmax(dim=-1) + + # Get edge indices + if "edge" in self.targets or "duration" in self.targets: + edge_indices = target_dict["edge_indices"][pred_channel] + + # Compute loss and statistics + for t in self.targets: + if t in ["next", "edge"]: + if t == "next": + target = target_dict["next"][pred_channel].view(-1) + lprobs = token_lprobs + preds = token_preds + elif t == "edge": + target = target_dict["edge"][pred_channel] + lprobs = token_lprobs[edge_indices] + preds = token_preds[edge_indices] + + loss = F.nll_loss( + lprobs, + target, + ignore_index=self.padding_idx, + reduction="sum" if reduce else "none", + ) + elif t == "duration": + target = target_dict["duration"][pred_channel] + if self.delayed_duration_target: + duration_indices = edge_indices + 1 + if duration_indices[-1] == len(dur_preds): + duration_indices = duration_indices[:-1] + target = target[:-1] + else: + duration_indices = edge_indices + preds = dur_preds[duration_indices] + + loss = F.l1_loss( + preds, + target, + reduction="sum" if reduce else "none", + ) + preds = preds.round() + + correct = (preds == target).sum().float().cpu().item() + count = float(target.size(0)) + + loss_dict[channel][pred_channel][t] = loss + stats_dict[channel][pred_channel][t] = (correct, count) + + return loss_dict, stats_dict + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + logging_keys = next(iter(logging_outputs)).keys() + channels = [item[:-7] for item in logging_keys if item.endswith("ntokens")] + target_prefixes = set( + [ + item[:-5].split("]")[-1] + for item in logging_keys + if item.endswith("_loss") + ] + ) + for channel_prefix in channels: + for target_prefix in target_prefixes: + prefix = "{}{}".format(channel_prefix, target_prefix) + count_sum = sum( + log.get("{}_count".format(prefix), 0) for log in logging_outputs + ) + correct_sum = sum( + log.get("{}_correct".format(prefix), 0) for log in logging_outputs + ) + loss_sum = sum( + log.get("{}_loss".format(prefix), 0) for log in logging_outputs + ) + + if "duration" not in target_prefix: + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "{}_loss".format(prefix), + loss_sum / count_sum / math.log(2), + count_sum, + round=3, + ) + metrics.log_derived( + "{}_ppl".format(prefix), + lambda meters, prefix=prefix: utils.get_perplexity( + meters["{}_loss".format(prefix)].avg + ), + ) + else: + # for duration we don't need to divide by log(2) + metrics.log_scalar( + "{}_loss".format(prefix), + loss_sum / count_sum, + count_sum, + round=3, + ) + + accuracy = 100 * correct_sum / count_sum + metrics.log_scalar("{}_pred_acc".format(prefix), accuracy, round=3) + + # Logging training loss + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..06a825214013bf9d7d39d683895d90166efbae3f --- /dev/null +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -0,0 +1,517 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from collections import OrderedDict + +import torch + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import register_criterion +from fairseq.criterions.ctc import CtcCriterion +from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import ( + RdropLabelSmoothedCrossEntropyCriterion, + RdropLabelSmoothedCrossEntropyCriterionConfig, + duplicate_input, +) +from fairseq.criterions.tacotron2_loss import ( + Tacotron2Criterion, + Tacotron2CriterionConfig, +) + +logger = logging.getLogger(__name__) + + +class MultitaskCriterion: + def __init__(self, multitask_tasks, rdrop_alpha=0.0): + self.rdrop_alpha = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha + + self.multitask_criterion = OrderedDict() + self.multitask_loss_weight = OrderedDict() + for task_name, task_obj in multitask_tasks.items(): + if task_obj.args.get_loss_weight(0) == 0: + logger.info(f"Skip {task_name} loss criterion") + continue + + rdrop_alpha_task = task_obj.args.rdrop_alpha + if rdrop_alpha_task is None: + rdrop_alpha_task = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha_task + logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}") + + if task_obj.args.decoder_type == "ctc": + self.multitask_criterion[task_name] = CtcCriterion( + task_obj.args.criterion_cfg, + task_obj, + rdrop_alpha=rdrop_alpha_task, + ) + else: + self.multitask_criterion[ + task_name + ] = RdropLabelSmoothedCrossEntropyCriterion( + task_obj, + task_obj.args.criterion_cfg.sentence_avg, + label_smoothing=task_obj.args.criterion_cfg.label_smoothing, + rdrop_alpha=rdrop_alpha_task, + ) + + def set_multitask_loss_weight(self, task_name, weight=0.0): + self.multitask_loss_weight[task_name] = weight + + def get_multitask_loss(self, model, sample, model_out): + logging_output = {} + loss = 0.0 + for task_name, task_criterion in self.multitask_criterion.items(): + layer_id = task_criterion.task.args.input_layer + if isinstance(task_criterion, CtcCriterion): + if task_criterion.task.args.input_from == "encoder": + if len(model_out["encoder_padding_mask"]) > 0: + non_padding_mask = ~model_out["encoder_padding_mask"][0] + input_lengths = non_padding_mask.long().sum(-1) + else: + out = model_out["encoder_states"][layer_id] + input_lengths = out.new_full( + (out.shape[1],), out.shape[0] + ).long() + + task_sample = { + "net_input": { + "src_tokens": model_out["encoder_states"][ + layer_id + ], # check batch idx + "src_lengths": input_lengths, + }, + "id": sample["id"], + } + else: + task_sample = { + "net_input": { + "src_tokens": model_out["inner_states"][layer_id], + "src_lengths": sample["target_lengths"], + }, + "id": sample["id"], + } + else: + task_sample = { + "net_input": { + "src_tokens": sample["multitask"][task_name]["net_input"][ + "prev_output_tokens" + ], + "encoder_out": { + "encoder_out": [model_out["encoder_states"][layer_id]], + "encoder_padding_mask": model_out["encoder_padding_mask"], + }, + } + } + + for key in ["target", "target_lengths", "ntokens"]: + task_sample[key] = sample["multitask"][task_name][key] + + if task_name == getattr(model, "mt_task_name", None): + decoder_out = model_out["mt_decoder_out"] + else: + decoder_out = None + task_loss, task_sample_size, task_logging_output = task_criterion( + model.multitask_decoders[task_name], task_sample, net_output=decoder_out + ) + + loss = loss + self.multitask_loss_weight[task_name] * task_loss + task_logging_output["loss_weight"] = self.multitask_loss_weight[task_name] + logging_output[task_name] = task_logging_output + return loss, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + for task_name in logging_outputs[0]["multitask"].keys(): + # different criterion may return different logging + # currently only reduce on loss, the most common one + # ideally the way that losses are reduced should also depend on the task type + loss_sum = sum( + log["multitask"][task_name].get("loss", 0) for log in logging_outputs + ) + sample_size = sum( + log["multitask"][task_name].get("sample_size", 0) + for log in logging_outputs + ) + + metrics.log_scalar( + f"multitask_{task_name}_loss", + loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + + loss_weight = logging_outputs[0]["multitask"][task_name].get( + "loss_weight", 0 + ) + metrics.log_scalar( + f"multitask_{task_name}_loss_weight", + loss_weight, + weight=0, + priority=250, + ) + + +@register_criterion( + "speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig +) +class SpeechToUnitMultitaskTaskCriterion( + RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + rdrop_alpha, + ) + MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha) + + def forward(self, model, sample, reduce=True): + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + "prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": True, + } + + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, [net_output], sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + + # inference metrics + if "targ_frames" in logging_outputs[0]: + n = sum(log.get("norm_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + if "multitask" not in logging_outputs[0]: + return + + MultitaskCriterion.reduce_metrics(logging_outputs) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False + + +@register_criterion( + "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig +) +class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + rdrop_alpha, + ) + + def forward(self, model, sample, reduce=True): + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + "prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][ + "net_input" + ]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": True, + } + if getattr(model, "asr_task_name", None) is not None: + net_input_concat["prev_output_tokens_asr"] = sample["multitask"][ + model.asr_task_name + ]["net_input"]["prev_output_tokens"] + + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, [net_output], sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + + return loss, sample_size, logging_output + + +@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig) +class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__( + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ) + MultitaskCriterion.__init__(self, task.multitask_tasks) + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + + feat_out, eos_out, extra = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + tgt_speaker=sample["net_input"]["tgt_speaker"], + target_lengths=sample["target_lengths"], + return_all_hiddens=True, + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + sample["target_lengths"], + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn( + extra["attn"], + sample["net_input"]["src_lengths"], + sample["target_lengths"], + reduction, + ) + loss = ( + l1_loss + mse_loss + eos_loss + attn_loss + ) # do not include ctc loss as there's no text target + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + } + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + + # inference metrics + if "targ_frames" in logging_outputs[0]: + n = sum(log.get("norm_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + if "multitask" not in logging_outputs[0]: + return + + MultitaskCriterion.reduce_metrics(logging_outputs) + + +@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig) +class SpeechToSpectrogram2passMultitaskTaskCriterion( + SpeechToSpectrogramMultitaskTaskCriterion +): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__( + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ) + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + + feat_out, eos_out, extra = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][ + "prev_output_tokens" + ], + tgt_speaker=sample["net_input"]["tgt_speaker"], + target_lengths=sample["target_lengths"], + return_all_hiddens=True, + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + sample["target_lengths"], + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn( + extra["attn"], + sample["net_input"]["src_lengths"], + sample["target_lengths"], + reduction, + ) + loss = ( + l1_loss + mse_loss + eos_loss + attn_loss + ) # do not include ctc loss as there's no text target + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + } + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + return loss, sample_size, logging_output diff --git a/fairseq/criterions/speech_ulm_criterion.py b/fairseq/criterions/speech_ulm_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..eea74bae2641e285c56e86feb6e9866464c9673f --- /dev/null +++ b/fairseq/criterions/speech_ulm_criterion.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from dataclasses import dataclass, field + +import torch.nn.functional as F +from fairseq.logging import metrics +from fairseq.tasks import FairseqTask +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class SpeechUnitLmCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + loss_weights: str = field( + default="1.;0.0;0.0", + metadata={ + "help": "Weights of the losses that correspond to token, duration, and F0 streams" + }, + ) + discrete_duration: bool = II("task.discrete_duration") + discrete_f0: bool = II("task.discrete_f0") + + +def mae_loss(pred, targ, mask, reduce=True): + if pred.ndim == 3: + pred = pred.squeeze(2) + else: + assert pred.ndim == 2 + loss = (pred.float() - targ.float()).abs() * (~mask).float() + loss = loss.sum() if reduce else loss.view(-1) + return loss + + +def nll_loss(pred, targ, mask, reduce=True): + lprob = F.log_softmax(pred, dim=-1) + loss = F.nll_loss(lprob.view(-1, lprob.size(-1)), targ.view(-1), reduction="none") + loss = loss * (~mask).float().view(-1) + loss = loss.sum() if reduce else loss.view(-1) + return loss + + +@register_criterion("speech_unit_lm_criterion", dataclass=SpeechUnitLmCriterionConfig) +class SpeechUnitLmCriterion(FairseqCriterion): + def __init__(self, cfg: SpeechUnitLmCriterionConfig, task: FairseqTask): + super().__init__(task) + self.sentence_avg = cfg.sentence_avg + self.weights = torch.tensor([float(w) for w in cfg.loss_weights.split(";")]) + assert self.weights.size(0) == 3 + assert (self.weights >= 0.0).all() + + self.dur_loss_fn = nll_loss if cfg.discrete_duration else mae_loss + self.f0_loss_fn = nll_loss if cfg.discrete_f0 else mae_loss + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + + token_loss = nll_loss( + net_output["token"], sample["target"], sample["mask"], reduce + ) + dur_loss = self.dur_loss_fn( + net_output["duration"], + sample["dur_target"], + sample["dur_mask"], + reduce, + ) + f0_loss = self.f0_loss_fn( + net_output["f0"], + sample["f0_target"], + sample["f0_mask"], + reduce, + ) + loss = self.weights.to(token_loss.device) * torch.stack( + [token_loss, dur_loss, f0_loss], dim=-1 + ) + loss = loss.sum() if reduce else loss.sum(-1) + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.detach().sum().item(), + "token_loss": token_loss.detach().sum().item(), + "dur_loss": dur_loss.detach().sum().item(), + "f0_loss": f0_loss.detach().sum().item(), + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + token_loss_sum = sum(log.get("token_loss", 0) for log in logging_outputs) + dur_loss_sum = sum(log.get("dur_loss", 0) for log in logging_outputs) + f0_loss_sum = sum(log.get("f0_loss", 0) for log in logging_outputs) + + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + + metrics.log_scalar( + "token_loss", token_loss_sum / sample_size, sample_size, round=3 + ) + + metrics.log_scalar("dur_loss", dur_loss_sum / sample_size, sample_size, round=3) + + metrics.log_scalar("f0_loss", f0_loss_sum / sample_size, sample_size, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return True diff --git a/fairseq/criterions/tacotron2_loss.py b/fairseq/criterions/tacotron2_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..4113fdc5489f1e0c787b60735086ae9d073c8e17 --- /dev/null +++ b/fairseq/criterions/tacotron2_loss.py @@ -0,0 +1,227 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Any, Dict, List + +import torch +import torch.nn.functional as F +from omegaconf import II + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.data.data_utils import lengths_to_mask +from fairseq.dataclass import FairseqDataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class Tacotron2CriterionConfig(FairseqDataclass): + bce_pos_weight: float = field( + default=1.0, + metadata={"help": "weight of positive examples for BCE loss"}, + ) + use_guided_attention_loss: bool = field( + default=False, + metadata={"help": "use guided attention loss"}, + ) + guided_attention_loss_sigma: float = field( + default=0.4, + metadata={"help": "weight of positive examples for BCE loss"}, + ) + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) + sentence_avg: bool = II("optimization.sentence_avg") + + +class GuidedAttentionLoss(torch.nn.Module): + """ + Efficiently Trainable Text-to-Speech System Based on Deep Convolutional + Networks with Guided Attention (https://arxiv.org/abs/1710.08969) + """ + + def __init__(self, sigma): + super().__init__() + self.sigma = sigma + + @staticmethod + @lru_cache(maxsize=8) + def _get_weight(s_len, t_len, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(t_len), torch.arange(s_len)) + grid_x = grid_x.to(s_len.device) + grid_y = grid_y.to(s_len.device) + w = (grid_y.float() / s_len - grid_x.float() / t_len) ** 2 + return 1.0 - torch.exp(-w / (2 * (sigma**2))) + + def _get_weights(self, src_lens, tgt_lens): + bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) + weights = torch.zeros((bsz, max_t_len, max_s_len)) + for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): + weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma) + return weights + + @staticmethod + def _get_masks(src_lens, tgt_lens): + in_masks = lengths_to_mask(src_lens) + out_masks = lengths_to_mask(tgt_lens) + return out_masks.unsqueeze(2) & in_masks.unsqueeze(1) + + def forward(self, attn, src_lens, tgt_lens, reduction="mean"): + weights = self._get_weights(src_lens, tgt_lens).to(attn.device) + masks = self._get_masks(src_lens, tgt_lens).to(attn.device) + loss = (weights * attn.transpose(1, 2)).masked_select(masks) + loss = torch.sum(loss) if reduction == "sum" else torch.mean(loss) + return loss + + +@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig) +class Tacotron2Criterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): + super().__init__(task) + self.sentence_avg = sentence_avg + self.bce_pos_weight = bce_pos_weight + + self.guided_attn = None + if use_guided_attention_loss: + self.guided_attn = GuidedAttentionLoss(guided_attention_loss_sigma) + self.ctc_weight = ctc_weight + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + tgt_lens = sample["target_lengths"] + + feat_out, eos_out, extra = model( + src_tokens=src_tokens, + src_lengths=src_lens, + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction) + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: + net_output = (feat_out, eos_out, extra) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + lprobs = lprobs.transpose(0, 1) # T x B x C + src_mask = lengths_to_mask(src_lens) + src_tokens_flat = src_tokens.masked_select(src_mask) + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) + loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + "ctc_loss": utils.item(ctc_loss.data), + } + return loss, sample_size, logging_output + + def compute_loss( + self, + feat_out, + feat_out_post, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction="mean", + ): + mask = lengths_to_mask(tgt_lens) + _eos_out = eos_out[mask].squeeze() + _eos_tgt = eos_tgt[mask] + _feat_tgt = feat_tgt[mask] + _feat_out = feat_out[mask] + _feat_out_post = feat_out_post[mask] + + l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss( + _feat_out_post, _feat_tgt, reduction=reduction + ) + mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss( + _feat_out_post, _feat_tgt, reduction=reduction + ) + eos_loss = F.binary_cross_entropy_with_logits( + _eos_out, + _eos_tgt, + pos_weight=torch.tensor(self.bce_pos_weight), + reduction=reduction, + ) + return l1_loss, mse_loss, eos_loss + + @classmethod + def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None: + ns = [log.get("sample_size", 0) for log in logging_outputs] + ntot = sum(ns) + ws = [n / (ntot + 1e-8) for n in ns] + for key in ["loss", "l1_loss", "mse_loss", "eos_loss", "attn_loss", "ctc_loss"]: + vals = [log.get(key, 0) for log in logging_outputs] + val = sum(val * w for val, w in zip(vals, ws)) + metrics.log_scalar(key, val, ntot, round=3) + metrics.log_scalar("sample_size", ntot, len(logging_outputs)) + + # inference metrics + if "targ_frames" not in logging_outputs[0]: + return + n = sum(log.get("targ_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + return False diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..3975468487704e053fe7634257f443ee2c396616 --- /dev/null +++ b/fairseq/criterions/wav2vec_criterion.py @@ -0,0 +1,231 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round +from fairseq.utils import is_xla_tensor + + +@dataclass +class Wav2VecCriterionConfig(FairseqDataclass): + infonce: bool = field( + default=False, + metadata={ + "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)" + }, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) +class Wav2vecCriterion(FairseqCriterion): + def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): + super().__init__(task) + self.infonce = infonce + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + logits = model.get_logits(net_output).float() + target = model.get_targets(sample, net_output) + self.xla = is_xla_tensor(logits) + + # XXX: handle weights on xla. + weights = None + if hasattr(model, "get_target_weights") and not self.infonce: + weights = model.get_target_weights(target, net_output) + if torch.is_tensor(weights): + weights = weights.float() + + losses = [] + + reduction = "none" if ((not reduce) or self.xla) else "sum" + if self.infonce: + loss = F.cross_entropy(logits, target, reduction=reduction) + else: + loss = F.binary_cross_entropy_with_logits( + logits, target.float(), weights, reduction=reduction + ) + + if self.xla: + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we use mask indices to adjust loss. + mi = ( + sample["net_input"]["mask_indices"] + .transpose(0, 1) # logits are transposed in `model.get_logits` + .reshape(logits.size(0)) + ) + loss = (loss * mi).sum() if reduce else (loss * mi) + + if "sample_size" in sample: + sample_size = sample["sample_size"] + elif "mask_indices" in sample["net_input"]: + sample_size = sample["net_input"]["mask_indices"].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum().item() + losses.append(loss.detach().clone()) + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, coef in zip(extra_losses, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + losses.append(p) + + logging_output = { + "loss": loss.item() if (reduce and not self.xla) else loss.detach(), + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + } + + for lk in self.log_keys: + # Only store "logits" and "target" for computing MAP and MAUC + # during validation + if lk == "logits": + if not self.training: + logging_output["logits"] = logits.cpu().numpy() + elif lk == "target": + if not self.training: + # If the targets have been mixed with the predictions of + # teacher models, find the original targets + if hasattr(model, "get_original_targets"): + original_target = model.get_original_targets(sample, net_output) + else: + original_target = target + logging_output["target"] = original_target.cpu().numpy() + elif lk in net_output: + value = net_output[lk] + if not is_xla_tensor(value): + value = float(value) + logging_output[lk] = value + + if len(losses) > 1: + for i, l in enumerate(losses): + logging_output[f"loss_{i}"] = l.item() if not self.xla else l.detach() + + if self.infonce: + with torch.no_grad(): + if logits.numel() == 0: + corr = 0 + count = 0 + else: + assert logits.dim() > 1, logits.shape + max = logits.argmax(-1) == 0 + min = logits.argmin(-1) == 0 + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) + + logging_output["correct"] = corr + logging_output["count"] = count + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + + correct = sum(log.get("correct", 0) for log in logging_outputs) + metrics.log_scalar("_correct", correct) + + total = sum(log.get("count", 0) for log in logging_outputs) + metrics.log_scalar("_total", total) + + if total > 0: + metrics.log_derived( + "accuracy", + lambda meters: safe_round( + meters["_correct"].sum / meters["_total"].sum, 5 + ) + if meters["_total"].sum > 0 + else float("nan"), + ) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "correct", + "count", + } + + for k in logging_outputs[0]: + if k not in builtin_keys: + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss"): + metrics.log_scalar( + k, val / (sample_size or 1) / math.log(2), sample_size, round=3 + ) + else: + metrics.log_scalar(k, val / len(logging_outputs), round=3) + + # FIXME: revert when gather based xla reduction is implemented + # @staticmethod + # def logging_outputs_can_be_summed() -> bool: + def logging_outputs_can_be_summed(self) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + # XXX: Gather based reduction not implemented for xla yet. + # So we fall to sum based reduction for xla. + return self.xla diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeaae2b2547dc830c95a6a4313a02d469d4f63cd --- /dev/null +++ b/fairseq/data/__init__.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .dictionary import Dictionary, TruncatedDictionary + +from .fairseq_dataset import FairseqDataset, FairseqIterableDataset + +from .base_wrapper_dataset import BaseWrapperDataset + +from .add_target_dataset import AddTargetDataset +from .append_token_dataset import AppendTokenDataset +from .audio.raw_audio_dataset import BinarizedAudioDataset, FileAudioDataset +from .audio.hubert_dataset import HubertDataset +from .backtranslation_dataset import BacktranslationDataset +from .bucket_pad_length_dataset import BucketPadLengthDataset +from .colorize_dataset import ColorizeDataset +from .concat_dataset import ConcatDataset +from .concat_sentences_dataset import ConcatSentencesDataset +from .denoising_dataset import DenoisingDataset +from .id_dataset import IdDataset +from .indexed_dataset import ( + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + MMapIndexedDataset, +) +from .language_pair_dataset import LanguagePairDataset +from .list_dataset import ListDataset +from .lm_context_window_dataset import LMContextWindowDataset +from .lru_cache_dataset import LRUCacheDataset +from .mask_tokens_dataset import MaskTokensDataset +from .monolingual_dataset import MonolingualDataset +from .multi_corpus_sampled_dataset import MultiCorpusSampledDataset +from .nested_dictionary_dataset import NestedDictionaryDataset +from .noising import NoisingDataset +from .numel_dataset import NumelDataset +from .num_samples_dataset import NumSamplesDataset +from .offset_tokens_dataset import OffsetTokensDataset +from .padding_mask_dataset import ( + LeftPaddingMaskDataset, + PaddingMaskDataset, + RightPaddingMaskDataset, +) +from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset +from .prepend_dataset import PrependDataset +from .prepend_token_dataset import PrependTokenDataset +from .raw_label_dataset import RawLabelDataset +from .replace_dataset import ReplaceDataset +from .resampling_dataset import ResamplingDataset +from .roll_dataset import RollDataset +from .round_robin_zip_datasets import RoundRobinZipDatasets +from .sort_dataset import SortDataset +from .speech_dlm_dataset import SpeechDLMDataset +from .strip_token_dataset import StripTokenDataset +from .subsample_dataset import SubsampleDataset +from .token_block_dataset import TokenBlockDataset +from .transform_eos_dataset import TransformEosDataset +from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset +from .shorten_dataset import TruncateDataset, RandomCropDataset +from .multilingual.sampled_multi_dataset import SampledMultiDataset +from .multilingual.sampled_multi_epoch_dataset import SampledMultiEpochDataset +from .fasta_dataset import FastaDataset, EncodedFastaDataset +from .transform_eos_concat_langpair_dataset import TransformEosConcatLangPairDataset + +from .iterators import ( + CountingIterator, + EpochBatchIterator, + GroupedIterator, + ShardedIterator, +) + +__all__ = [ + "AddTargetDataset", + "AppendTokenDataset", + "BacktranslationDataset", + "BaseWrapperDataset", + "BinarizedAudioDataset", + "BucketPadLengthDataset", + "ColorizeDataset", + "ConcatDataset", + "ConcatSentencesDataset", + "CountingIterator", + "DenoisingDataset", + "Dictionary", + "EncodedFastaDataset", + "EpochBatchIterator", + "FairseqDataset", + "FairseqIterableDataset", + "FastaDataset", + "FileAudioDataset", + "GroupedIterator", + "HubertDataset", + "IdDataset", + "IndexedCachedDataset", + "IndexedDataset", + "IndexedRawTextDataset", + "LanguagePairDataset", + "LeftPadDataset", + "ListDataset", + "LMContextWindowDataset", + "LRUCacheDataset", + "MaskTokensDataset", + "MMapIndexedDataset", + "MonolingualDataset", + "MultiCorpusSampledDataset", + "NestedDictionaryDataset", + "NoisingDataset", + "NumelDataset", + "NumSamplesDataset", + "OffsetTokensDataset", + "PadDataset", + "PrependDataset", + "PrependTokenDataset", + "RandomCropDataset", + "RawLabelDataset", + "ResamplingDataset", + "ReplaceDataset", + "RightPadDataset", + "RollDataset", + "RoundRobinZipDatasets", + "SampledMultiDataset", + "SampledMultiEpochDataset", + "ShardedIterator", + "SortDataset", + "SpeechDLMDataset", + "StripTokenDataset", + "SubsampleDataset", + "TokenBlockDataset", + "TransformEosDataset", + "TransformEosLangPairDataset", + "TransformEosConcatLangPairDataset", + "TruncateDataset", + "TruncatedDictionary", +] diff --git a/fairseq/data/__pycache__/__init__.cpython-310.pyc b/fairseq/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c937349a38af563d5b8e5be2599dbab791b366c4 Binary files /dev/null and b/fairseq/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc b/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa93b0ab9e80d3fead21753da496d02bad012ea7 Binary files /dev/null and b/fairseq/data/__pycache__/add_target_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc b/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7c95b0c2c523ae9955b4e41e14899229daab406 Binary files /dev/null and b/fairseq/data/__pycache__/append_token_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc b/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b7478269993f7683c72a643e78d235a01da156f Binary files /dev/null and b/fairseq/data/__pycache__/backtranslation_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc b/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0051b0ff4c324462de1b02dcad40ce412381926 Binary files /dev/null and b/fairseq/data/__pycache__/base_wrapper_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc b/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1945fb2c661e26ea5c08ae9a924a7fc7653e66c6 Binary files /dev/null and b/fairseq/data/__pycache__/bucket_pad_length_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/codedataset.cpython-310.pyc b/fairseq/data/__pycache__/codedataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dd59068f3187e4567f367e0ee1e3543e7b913f8 Binary files /dev/null and b/fairseq/data/__pycache__/codedataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc b/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de13f2bc0af3f876b2e2647549db970fa20da73d Binary files /dev/null and b/fairseq/data/__pycache__/colorize_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc b/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f307323414cecc289b1bd5d1218c67ac6b3b09ee Binary files /dev/null and b/fairseq/data/__pycache__/concat_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc b/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdf32050978051e147fb347778d3226af470a3c3 Binary files /dev/null and b/fairseq/data/__pycache__/concat_sentences_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/data_utils.cpython-310.pyc b/fairseq/data/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e1c5a8e7bdaac2e07a1edc4e04a60ba3f3d5405 Binary files /dev/null and b/fairseq/data/__pycache__/data_utils.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc b/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bac81c47f3268a3fcdfabeeb5f4ae063f906655 Binary files /dev/null and b/fairseq/data/__pycache__/denoising_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/dictionary.cpython-310.pyc b/fairseq/data/__pycache__/dictionary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a97ae08f9c35911db3ae7b85e7b5bdf52795803f Binary files /dev/null and b/fairseq/data/__pycache__/dictionary.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc b/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb6e2271c56a2cf70c04008e012145da409c24d3 Binary files /dev/null and b/fairseq/data/__pycache__/fairseq_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc b/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f6f14ead85283c6f204d37323a321cb4cecb34 Binary files /dev/null and b/fairseq/data/__pycache__/fasta_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/id_dataset.cpython-310.pyc b/fairseq/data/__pycache__/id_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd749d8180911def1c9ad95c6e9016f8732b82f8 Binary files /dev/null and b/fairseq/data/__pycache__/id_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc b/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb8412856ec98f02d42fc696ff4299a5afd478cf Binary files /dev/null and b/fairseq/data/__pycache__/indexed_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/iterators.cpython-310.pyc b/fairseq/data/__pycache__/iterators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9badb1155fc2bc873dfdb9eacff0001a3ecdb3 Binary files /dev/null and b/fairseq/data/__pycache__/iterators.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc b/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6d3d63eda3337a44117e27a59db11cc73d165f2 Binary files /dev/null and b/fairseq/data/__pycache__/language_pair_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/list_dataset.cpython-310.pyc b/fairseq/data/__pycache__/list_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a12df45be736a27acacba5a1571310dcfc10d2 Binary files /dev/null and b/fairseq/data/__pycache__/list_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc b/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2d784dda135912f0db2249435d132fbb7ec312 Binary files /dev/null and b/fairseq/data/__pycache__/lm_context_window_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc b/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f76fc913ec5c694cb5552bf4e2c46591926babc Binary files /dev/null and b/fairseq/data/__pycache__/lru_cache_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc b/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca6b5d2c03252acc302b3e1af0f2914eaf397dbd Binary files /dev/null and b/fairseq/data/__pycache__/mask_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc b/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..590a546be31a157a0ce495e2bc2c9c8006ca0228 Binary files /dev/null and b/fairseq/data/__pycache__/monolingual_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc b/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdaaefc53e3859a9fdf9f5df9d2d869911a7d2a6 Binary files /dev/null and b/fairseq/data/__pycache__/multi_corpus_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc b/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef4ffbb3a22c2873a799eb467ddb64bc2012e9e7 Binary files /dev/null and b/fairseq/data/__pycache__/multi_corpus_sampled_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc b/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3c644c3ffff34c22e435369b79caa17fc244444 Binary files /dev/null and b/fairseq/data/__pycache__/nested_dictionary_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/noising.cpython-310.pyc b/fairseq/data/__pycache__/noising.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1067732221a9fe97336e6ea82364a394bf5d3cea Binary files /dev/null and b/fairseq/data/__pycache__/noising.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc b/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8793cbae91b92f22c4c095f8d32147a29295bdf Binary files /dev/null and b/fairseq/data/__pycache__/num_samples_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc b/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0c3be5d68f885f96bd9a73fe1c2b450cf5c605b Binary files /dev/null and b/fairseq/data/__pycache__/numel_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc b/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6653f9b4e83831ca3308f4d803f2eff41f1c95ca Binary files /dev/null and b/fairseq/data/__pycache__/offset_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc b/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a765eb805d96c73960916d80041c1023ca36374 Binary files /dev/null and b/fairseq/data/__pycache__/pad_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc b/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c3497e728fac6b7ce15d9614ae3f2dff470e12a Binary files /dev/null and b/fairseq/data/__pycache__/padding_mask_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc b/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10812dd01718ba6360480337b44d4382c2860f2c Binary files /dev/null and b/fairseq/data/__pycache__/plasma_utils.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc b/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a2e404df46038216332c7cf70e583bb8dad0e11 Binary files /dev/null and b/fairseq/data/__pycache__/prepend_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc b/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f052b07749996529aae83623975fd57fbab2902 Binary files /dev/null and b/fairseq/data/__pycache__/prepend_token_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc b/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d722b6839faa97788bc9dab418a59601b641b7f Binary files /dev/null and b/fairseq/data/__pycache__/raw_label_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc b/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f7cd2c3ab5cb848591970167736f0a522930d9f Binary files /dev/null and b/fairseq/data/__pycache__/replace_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc b/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eea16503da7fc388fa8885d4c872fe19e23cd8f Binary files /dev/null and b/fairseq/data/__pycache__/resampling_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc b/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1cbd5234ad7cd84b658fa1eca80f95b35abb272 Binary files /dev/null and b/fairseq/data/__pycache__/roll_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc b/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..712e6f487b2d36050869175aa175f24724405867 Binary files /dev/null and b/fairseq/data/__pycache__/round_robin_zip_datasets.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc b/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c299899c97e2ab2fc276ff468c51c726b77bc19 Binary files /dev/null and b/fairseq/data/__pycache__/shorten_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc b/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fb23f80e11816c0222f011aeb56bdc67c359792 Binary files /dev/null and b/fairseq/data/__pycache__/sort_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc b/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e4bee5c76b4d0d725ad07b6cdec701a321ecf58 Binary files /dev/null and b/fairseq/data/__pycache__/span_mask_tokens_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc b/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02398a6337569d2b2b266349586332e3e648ae7c Binary files /dev/null and b/fairseq/data/__pycache__/speech_dlm_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc b/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37435df057d9fa584f99f60e3498e2019418e787 Binary files /dev/null and b/fairseq/data/__pycache__/strip_token_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc b/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90074999a1955aba77ec88c987b376553a4bd17c Binary files /dev/null and b/fairseq/data/__pycache__/subsample_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/text_compressor.cpython-310.pyc b/fairseq/data/__pycache__/text_compressor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab7b20dce0e3d1566c9cfee668cabd926dcc9eee Binary files /dev/null and b/fairseq/data/__pycache__/text_compressor.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc b/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..104f57c5fadec66f2c43c6968621e737f78c3a15 Binary files /dev/null and b/fairseq/data/__pycache__/token_block_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc b/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef54fb6ca82c5794ddf334b30acd9f8c92651c01 Binary files /dev/null and b/fairseq/data/__pycache__/transform_eos_concat_langpair_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc b/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f416533a994bcb0967b40bfd423bab009c3453b6 Binary files /dev/null and b/fairseq/data/__pycache__/transform_eos_dataset.cpython-310.pyc differ diff --git a/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc b/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13039fd6ecd1af75d5a1fc14a53380bfa1595cf8 Binary files /dev/null and b/fairseq/data/__pycache__/transform_eos_lang_pair_dataset.cpython-310.pyc differ diff --git a/fairseq/data/add_class_target_dataset.py b/fairseq/data/add_class_target_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bf89f2565662e42adeb7455fb07f5c81b209b93c --- /dev/null +++ b/fairseq/data/add_class_target_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + label_len_fn=None, + add_to_input=False, + text_compression_level=TextCompressionLevel.none, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.label_len_fn = label_len_fn + self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) + + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index, process_fn=self.process_label) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + + if self.add_to_input: + eos = target.new_full((target.size(0), 1), self.eos) + collated["target"] = torch.cat([target, eos], dim=-1).long() + collated["net_input"]["prev_output_tokens"] = torch.cat( + [eos, target], dim=-1 + ).long() + collated["ntokens"] += target.size(0) + return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..978a5b1903cf51683fa03fcbb68b44cf23f93a08 --- /dev/null +++ b/fairseq/data/add_target_dataset.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset, data_utils +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +class AddTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + pad, + eos, + batch_targets, + process_label=None, + label_len_fn=None, + add_to_input=False, + text_compression_level=TextCompressionLevel.none, + ): + super().__init__(dataset) + self.labels = labels + self.batch_targets = batch_targets + self.pad = pad + self.eos = eos + self.process_label = process_label + self.label_len_fn = label_len_fn + self.add_to_input = add_to_input + self.text_compressor = TextCompressor(level=text_compression_level) + + def get_label(self, index, process_fn=None): + lbl = self.labels[index] + lbl = self.text_compressor.decompress(lbl) + return lbl if process_fn is None else process_fn(lbl) + + def __getitem__(self, index): + item = self.dataset[index] + item["label"] = self.get_label(index, process_fn=self.process_label) + return item + + def size(self, index): + sz = self.dataset.size(index) + own_sz = self.label_len_fn(self.get_label(index)) + return sz, own_sz + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + + if self.add_to_input: + eos = torch.LongTensor([self.eos]) + prev_output_tokens = [torch.cat([eos, t], axis=-1) for t in target] + target = [torch.cat([t, eos], axis=-1) for t in target] + collated["net_input"]["prev_output_tokens"] = prev_output_tokens + + if self.batch_targets: + collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) + target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) + collated["ntokens"] = collated["target_lengths"].sum().item() + if getattr(collated["net_input"], "prev_output_tokens", None): + collated["net_input"]["prev_output_tokens"] = data_utils.collate_tokens( + collated["net_input"]["prev_output_tokens"], + pad_idx=self.pad, + left_pad=False, + ) + else: + collated["ntokens"] = sum([len(t) for t in target]) + + collated["target"] = target + return collated + + def filter_indices_by_size(self, indices, max_sizes): + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored diff --git a/fairseq/data/append_token_dataset.py b/fairseq/data/append_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..87695bd0f5fcb6b10247e3b743340623e6438cc1 --- /dev/null +++ b/fairseq/data/append_token_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class AppendTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + self.token = token + if token is not None: + self._sizes = np.array(dataset.sizes) + 1 + else: + self._sizes = dataset.sizes + + def __getitem__(self, idx): + item = self.dataset[idx] + if self.token is not None: + item = torch.cat([item, item.new([self.token])]) + return item + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + n = self.dataset.num_tokens(index) + if self.token is not None: + n += 1 + return n + + def size(self, index): + n = self.dataset.size(index) + if self.token is not None: + n += 1 + return n diff --git a/fairseq/data/audio/__init__.py b/fairseq/data/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dff90fadfc3248e3ca6d5771002cea8e0e767c1a --- /dev/null +++ b/fairseq/data/audio/__init__.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional +import importlib +import os +import numpy as np + + +class AudioTransform(ABC): + @classmethod + @abstractmethod + def from_config_dict(cls, config: Optional[Dict] = None): + pass + + +class CompositeAudioTransform(AudioTransform): + def _from_config_dict( + cls, + transform_type, + get_audio_transform, + composite_cls, + config=None, + return_empty=False, + ): + _config = {} if config is None else config + _transforms = _config.get(f"{transform_type}_transforms") + + if _transforms is None: + if return_empty: + _transforms = [] + else: + return None + + transforms = [ + get_audio_transform(_t).from_config_dict(_config.get(_t)) + for _t in _transforms + ] + return composite_cls(transforms) + + def __init__(self, transforms): + self.transforms = [t for t in transforms if t is not None] + + def __call__(self, x): + for t in self.transforms: + x = t(x) + return x + + def __repr__(self): + format_string = ( + [self.__class__.__name__ + "("] + + [f" {t.__repr__()}" for t in self.transforms] + + [")"] + ) + return "\n".join(format_string) + + +def register_audio_transform(name, cls_type, registry, class_names): + def register_audio_transform_cls(cls): + if name in registry: + raise ValueError(f"Cannot register duplicate transform ({name})") + if not issubclass(cls, cls_type): + raise ValueError( + f"Transform ({name}: {cls.__name__}) must extend " + f"{cls_type.__name__}" + ) + if cls.__name__ in class_names: + raise ValueError( + f"Cannot register audio transform with duplicate " + f"class name ({cls.__name__})" + ) + registry[name] = cls + class_names.add(cls.__name__) + return cls + + return register_audio_transform_cls + + +def import_transforms(transforms_dir, transform_type): + for file in os.listdir(transforms_dir): + path = os.path.join(transforms_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module( + f"fairseq.data.audio.{transform_type}_transforms." + name + ) + + +# Utility fn for uniform numbers in transforms +def rand_uniform(a, b): + return np.random.uniform() * (b - a) + a diff --git a/fairseq/data/audio/__pycache__/__init__.cpython-310.pyc b/fairseq/data/audio/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f5bad027c75792d17d8d872e392c4dd045879be Binary files /dev/null and b/fairseq/data/audio/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/audio_utils.cpython-310.pyc b/fairseq/data/audio/__pycache__/audio_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..598c0ade356f513e5317b4286b431a99f600cf29 Binary files /dev/null and b/fairseq/data/audio/__pycache__/audio_utils.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/data_cfg.cpython-310.pyc b/fairseq/data/audio/__pycache__/data_cfg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15d2f5ebd00138a307af100aed2b73f65eaa645b Binary files /dev/null and b/fairseq/data/audio/__pycache__/data_cfg.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/frm_text_to_speech_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/frm_text_to_speech_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e25603c3babe2cbc096ee6ddbfc5f17a87397bb Binary files /dev/null and b/fairseq/data/audio/__pycache__/frm_text_to_speech_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/hubert_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/hubert_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54443e49d56dab6870620e503d877b9547d17d87 Binary files /dev/null and b/fairseq/data/audio/__pycache__/hubert_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..969fa3df6ba9870f64b08921cdf8bb507733482c Binary files /dev/null and b/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/speech_to_speech_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/speech_to_speech_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c3c4f6653aba7f519567c82e6c7a6309c86eb8 Binary files /dev/null and b/fairseq/data/audio/__pycache__/speech_to_speech_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/speech_to_text_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/speech_to_text_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..060b50bf9accf079c18e373ce00e5ae8f8d02cdb Binary files /dev/null and b/fairseq/data/audio/__pycache__/speech_to_text_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/__pycache__/text_to_speech_dataset.cpython-310.pyc b/fairseq/data/audio/__pycache__/text_to_speech_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..905abcf4e1c45cd9a71d741118e54c80ec8596be Binary files /dev/null and b/fairseq/data/audio/__pycache__/text_to_speech_dataset.cpython-310.pyc differ diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..590a7493aec1a88a84c173cc6d4c57828705da5f --- /dev/null +++ b/fairseq/data/audio/audio_utils.py @@ -0,0 +1,389 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import mmap +from pathlib import Path +import io +from typing import BinaryIO, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform + +SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} +FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} + + +def convert_waveform( + waveform: Union[np.ndarray, torch.Tensor], + sample_rate: int, + normalize_volume: bool = False, + to_mono: bool = False, + to_sample_rate: Optional[int] = None, +) -> Tuple[Union[np.ndarray, torch.Tensor], int]: + """convert a waveform: + - to a target sample rate + - from multi-channel to mono channel + - volume normalization + + Args: + waveform (numpy.ndarray or torch.Tensor): 2D original waveform + (channels x length) + sample_rate (int): original sample rate + normalize_volume (bool): perform volume normalization + to_mono (bool): convert to mono channel if having multiple channels + to_sample_rate (Optional[int]): target sample rate + Returns: + waveform (numpy.ndarray): converted 2D waveform (channels x length) + sample_rate (float): target sample rate + """ + try: + import torchaudio.sox_effects as ta_sox + except ImportError: + raise ImportError("Please install torchaudio: pip install torchaudio") + + effects = [] + if normalize_volume: + effects.append(["gain", "-n"]) + if to_sample_rate is not None and to_sample_rate != sample_rate: + effects.append(["rate", f"{to_sample_rate}"]) + if to_mono and waveform.shape[0] > 1: + effects.append(["channels", "1"]) + if len(effects) > 0: + is_np_input = isinstance(waveform, np.ndarray) + _waveform = torch.from_numpy(waveform) if is_np_input else waveform + converted, converted_sample_rate = ta_sox.apply_effects_tensor( + _waveform, sample_rate, effects + ) + if is_np_input: + converted = converted.numpy() + return converted, converted_sample_rate + return waveform, sample_rate + + +def get_waveform( + path_or_fp: Union[str, BinaryIO], + normalization: bool = True, + mono: bool = True, + frames: int = -1, + start: int = 0, + always_2d: bool = True, + output_sample_rate: Optional[int] = None, + normalize_volume: bool = False, + waveform_transforms: Optional[CompositeAudioWaveformTransform] = None, +) -> Tuple[np.ndarray, int]: + """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. + + Args: + path_or_fp (str or BinaryIO): the path or file-like object + normalization (bool): normalize values to [-1, 1] (Default: True) + mono (bool): convert multi-channel audio to mono-channel one + frames (int): the number of frames to read. (-1 for reading all) + start (int): Where to start reading. A negative value counts from the end. + always_2d (bool): always return 2D array even for mono-channel audios + output_sample_rate (Optional[int]): output sample rate + normalize_volume (bool): normalize volume + Returns: + waveform (numpy.ndarray): 1D or 2D waveform (channels x length) + sample_rate (float): sample rate + """ + if isinstance(path_or_fp, str): + ext = Path(path_or_fp).suffix + if ext not in SF_AUDIO_FILE_EXTENSIONS: + raise ValueError(f"Unsupported audio format: {ext}") + + try: + import soundfile as sf + except ImportError: + raise ImportError("Please install soundfile: pip install soundfile") + + waveform, sample_rate = sf.read( + path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start + ) + waveform = waveform.T # T x C -> C x T + waveform, sample_rate = convert_waveform( + waveform, + sample_rate, + normalize_volume=normalize_volume, + to_mono=mono, + to_sample_rate=output_sample_rate, + ) + + if not normalization: + waveform *= 2**15 # denormalized to 16-bit signed integers + + if waveform_transforms is not None: + waveform, sample_rate = waveform_transforms(waveform, sample_rate) + + if not always_2d: + waveform = waveform.squeeze(axis=0) + + return waveform, sample_rate + + +def get_features_from_npy_or_audio(path, waveform_transforms=None): + ext = Path(path).suffix + if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: + raise ValueError(f'Unsupported file format for "{path}"') + return ( + np.load(path) + if ext == ".npy" + else get_fbank(path, waveform_transforms=waveform_transforms) + ) + + +def get_features_or_waveform_from_stored_zip( + path, + byte_offset, + byte_size, + need_waveform=False, + use_sample_rate=None, + waveform_transforms=None, +): + assert path.endswith(".zip") + data = read_from_stored_zip(path, byte_offset, byte_size) + f = io.BytesIO(data) + if is_npy_data(data): + features_or_waveform = np.load(f) + elif is_sf_audio_data(data): + features_or_waveform = ( + get_waveform( + f, + always_2d=False, + output_sample_rate=use_sample_rate, + waveform_transforms=waveform_transforms, + )[0] + if need_waveform + else get_fbank(f, waveform_transforms=waveform_transforms) + ) + else: + raise ValueError(f'Unknown file format for "{path}"') + return features_or_waveform + + +def get_features_or_waveform( + path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None +): + """Get speech features from .npy file or waveform from .wav/.flac file. + The file may be inside an uncompressed ZIP file and is accessed via byte + offset and length. + + Args: + path (str): File path in the format of "<.npy/.wav/.flac path>" or + "::". + need_waveform (bool): return waveform instead of features. + use_sample_rate (int): change sample rate for the input wave file + + Returns: + features_or_waveform (numpy.ndarray): speech features or waveform. + """ + _path, slice_ptr = parse_path(path) + if len(slice_ptr) == 0: + if need_waveform: + return get_waveform( + _path, + always_2d=False, + output_sample_rate=use_sample_rate, + waveform_transforms=waveform_transforms, + )[0] + return get_features_from_npy_or_audio( + _path, waveform_transforms=waveform_transforms + ) + elif len(slice_ptr) == 2: + features_or_waveform = get_features_or_waveform_from_stored_zip( + _path, + slice_ptr[0], + slice_ptr[1], + need_waveform=need_waveform, + use_sample_rate=use_sample_rate, + waveform_transforms=waveform_transforms, + ) + else: + raise ValueError(f"Invalid path: {path}") + + return features_or_waveform + + +def _get_kaldi_fbank( + waveform: np.ndarray, sample_rate: int, n_bins=80 +) -> Optional[np.ndarray]: + """Get mel-filter bank features via PyKaldi.""" + try: + from kaldi.feat.fbank import Fbank, FbankOptions + from kaldi.feat.mel import MelBanksOptions + from kaldi.feat.window import FrameExtractionOptions + from kaldi.matrix import Vector + + mel_opts = MelBanksOptions() + mel_opts.num_bins = n_bins + frame_opts = FrameExtractionOptions() + frame_opts.samp_freq = sample_rate + opts = FbankOptions() + opts.mel_opts = mel_opts + opts.frame_opts = frame_opts + fbank = Fbank(opts=opts) + features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() + return features + except ImportError: + return None + + +def _get_torchaudio_fbank( + waveform: np.ndarray, sample_rate, n_bins=80 +) -> Optional[np.ndarray]: + """Get mel-filter bank features via TorchAudio.""" + try: + import torchaudio.compliance.kaldi as ta_kaldi + + waveform = torch.from_numpy(waveform) + features = ta_kaldi.fbank( + waveform, num_mel_bins=n_bins, sample_frequency=sample_rate + ) + return features.numpy() + except ImportError: + return None + + +def get_fbank( + path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None +) -> np.ndarray: + """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi + (faster CPP implementation) to TorchAudio (Python implementation). Note that + Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the + waveform should not be normalized.""" + waveform, sample_rate = get_waveform( + path_or_fp, normalization=False, waveform_transforms=waveform_transforms + ) + + features = _get_kaldi_fbank(waveform, sample_rate, n_bins) + if features is None: + features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) + if features is None: + raise ImportError( + "Please install pyKaldi or torchaudio to enable " + "online filterbank feature extraction" + ) + + return features + + +def is_npy_data(data: bytes) -> bool: + return data[0] == 147 and data[1] == 78 + + +def is_sf_audio_data(data: bytes) -> bool: + is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70 + is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97 + is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103 + return is_wav or is_flac or is_ogg + + +def mmap_read(path: str, offset: int, length: int) -> bytes: + with open(path, "rb") as f: + with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o: + data = mmap_o[offset : offset + length] + return data + + +def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes: + return mmap_read(zip_path, offset, length) + + +def parse_path(path: str) -> Tuple[str, List[int]]: + """Parse data path which is either a path to + 1. a .npy/.wav/.flac/.ogg file + 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]" + + Args: + path (str): the data path to parse + + Returns: + file_path (str): the file path + slice_ptr (list of int): empty in case 1; + byte offset and length for the slice in case 2 + """ + + if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: + _path, slice_ptr = path, [] + else: + _path, *slice_ptr = path.split(":") + if not Path(_path).is_file(): + raise FileNotFoundError(f"File not found: {_path}") + assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}" + slice_ptr = [int(i) for i in slice_ptr] + return _path, slice_ptr + + +def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor: + padding = n_fft - win_length + assert padding >= 0 + return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2)) + + +def get_fourier_basis(n_fft: int) -> torch.Tensor: + basis = np.fft.fft(np.eye(n_fft)) + basis = np.vstack( + [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])] + ) + return torch.from_numpy(basis).float() + + +def get_mel_filters( + sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float +) -> torch.Tensor: + try: + import librosa + except ImportError: + raise ImportError("Please install librosa: pip install librosa") + basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max) + return torch.from_numpy(basis).float() + + +class TTSSpectrogram(torch.nn.Module): + def __init__( + self, + n_fft: int, + win_length: int, + hop_length: int, + window_fn: callable = torch.hann_window, + return_phase: bool = False, + ) -> None: + super(TTSSpectrogram, self).__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.return_phase = return_phase + + basis = get_fourier_basis(n_fft).unsqueeze(1) + basis *= get_window(window_fn, n_fft, win_length) + self.register_buffer("basis", basis) + + def forward( + self, waveform: torch.Tensor + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + padding = (self.n_fft // 2, self.n_fft // 2) + x = F.pad(waveform.unsqueeze(1), padding, mode="reflect") + x = F.conv1d(x, self.basis, stride=self.hop_length) + real_part = x[:, : self.n_fft // 2 + 1, :] + imag_part = x[:, self.n_fft // 2 + 1 :, :] + magnitude = torch.sqrt(real_part**2 + imag_part**2) + if self.return_phase: + phase = torch.atan2(imag_part, real_part) + return magnitude, phase + return magnitude + + +class TTSMelScale(torch.nn.Module): + def __init__( + self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int + ) -> None: + super(TTSMelScale, self).__init__() + basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max) + self.register_buffer("basis", basis) + + def forward(self, specgram: torch.Tensor) -> torch.Tensor: + return torch.matmul(self.basis, specgram) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..6be6f6521c96a5d33466bcce0ecfeaf6f899bf59 --- /dev/null +++ b/fairseq/data/audio/data_cfg.py @@ -0,0 +1,387 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from argparse import Namespace +from copy import deepcopy +from pathlib import Path +from typing import Dict, Optional + +from fairseq.data import Dictionary + +logger = logging.getLogger(__name__) + + +def get_config_from_yaml(yaml_path: Path): + try: + import yaml + except ImportError: + print("Please install PyYAML: pip install PyYAML") + config = {} + if yaml_path.is_file(): + try: + with open(yaml_path) as f: + config = yaml.load(f, Loader=yaml.FullLoader) + except Exception as e: + raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}") + else: + raise FileNotFoundError(f"{yaml_path.as_posix()} not found") + + return config + + +class S2TDataConfig(object): + """Wrapper class for data config YAML""" + + def __init__(self, yaml_path: Path): + self.config = get_config_from_yaml(yaml_path) + self.root = yaml_path.parent + + def _auto_convert_to_abs_path(self, x): + if isinstance(x, str): + if not Path(x).exists() and (self.root / x).exists(): + return (self.root / x).as_posix() + elif isinstance(x, dict): + return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()} + return x + + @property + def vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("vocab_filename", "dict.txt") + + @property + def speaker_set_filename(self): + """speaker set file under data root""" + return self.config.get("speaker_set_filename", None) + + @property + def shuffle(self) -> bool: + """Shuffle dataset samples before batching""" + return self.config.get("shuffle", False) + + @property + def pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None}) + return self._auto_convert_to_abs_path(tokenizer) + + @property + def bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply after pre-tokenization. Returning + a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + tokenizer = self.config.get("bpe_tokenizer", {"bpe": None}) + return self._auto_convert_to_abs_path(tokenizer) + + @property + def prepend_tgt_lang_tag(self) -> bool: + """Prepend target lang ID token as the target BOS (e.g. for to-many + multilingual setting). During inference, this requires `--prefix-size 1` + to force BOS to be lang ID token.""" + return self.config.get("prepend_tgt_lang_tag", False) + + @property + def prepend_bos_and_append_tgt_lang_tag(self) -> bool: + """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining).""" + return self.config.get("prepend_bos_and_append_tgt_lang_tag", False) + + @property + def input_feat_per_channel(self): + """The dimension of input features (per audio channel)""" + return self.config.get("input_feat_per_channel", 80) + + @property + def input_channels(self): + """The number of channels in the input audio""" + return self.config.get("input_channels", 1) + + @property + def sample_rate(self): + return self.config.get("sample_rate", 16_000) + + @property + def sampling_alpha(self): + """Hyper-parameter alpha = 1/T for temperature-based resampling. + (alpha = 1 for no resampling)""" + return self.config.get("sampling_alpha", 1.0) + + @property + def use_audio_input(self): + """Needed by the dataset loader to see if the model requires + raw audio as inputs.""" + return self.config.get("use_audio_input", False) + + def standardize_audio(self) -> bool: + return self.use_audio_input and self.config.get("standardize_audio", False) + + @property + def use_sample_rate(self): + """Needed by the dataset loader to see if the model requires + raw audio with specific sample rate as inputs.""" + return self.config.get("use_sample_rate", 16000) + + @property + def audio_root(self): + """Audio paths in the manifest TSV can be relative and this provides + the root path. Set this to empty string when using absolute paths.""" + return self.config.get("audio_root", "") + + def get_transforms(self, transform_type, split, is_train): + """Split-specific feature transforms. Allowing train set + wildcard `_train`, evaluation set wildcard `_eval` and general + wildcard `*` for matching.""" + from copy import deepcopy + + cfg = deepcopy(self.config) + _cur = cfg.get(f"{transform_type}transforms", {}) + cur = _cur.get(split) + cur = _cur.get("_train") if cur is None and is_train else cur + cur = _cur.get("_eval") if cur is None and not is_train else cur + cur = _cur.get("*") if cur is None else cur + return cur + + def get_feature_transforms(self, split, is_train): + cfg = deepcopy(self.config) + # TODO: deprecate transforms + cur = self.get_transforms("", split, is_train) + if cur is not None: + logger.warning( + "Auto converting transforms into feature_transforms, " + "but transforms will be deprecated in the future. Please " + "update this in the config." + ) + ft_transforms = self.get_transforms("feature_", split, is_train) + if ft_transforms: + cur.extend(ft_transforms) + else: + cur = self.get_transforms("feature_", split, is_train) + cfg["feature_transforms"] = cur + return cfg + + def get_waveform_transforms(self, split, is_train): + cfg = deepcopy(self.config) + cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train) + return cfg + + def get_dataset_transforms(self, split, is_train): + cfg = deepcopy(self.config) + cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train) + return cfg + + @property + def global_cmvn_stats_npz(self) -> Optional[str]: + path = self.config.get("global_cmvn", {}).get("stats_npz_path", None) + return self._auto_convert_to_abs_path(path) + + @property + def vocoder(self) -> Dict[str, str]: + vocoder = self.config.get("vocoder", {"type": "griffin_lim"}) + return self._auto_convert_to_abs_path(vocoder) + + @property + def hub(self) -> Dict[str, str]: + return self.config.get("hub", {}) + + +class S2SDataConfig(S2TDataConfig): + """Wrapper class for data config YAML""" + + @property + def vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("vocab_filename", None) + + @property + def pre_tokenizer(self) -> Dict: + return None + + @property + def bpe_tokenizer(self) -> Dict: + return None + + @property + def input_transformed_channels(self): + """The number of channels in the audio after feature transforms""" + # TODO: move this into individual transforms + # TODO: deprecate transforms + _cur = self.config.get("transforms", {}) + ft_transforms = self.config.get("feature_transforms", {}) + if _cur and ft_transforms: + _cur.update(ft_transforms) + else: + _cur = self.config.get("feature_transforms", {}) + cur = _cur.get("_train", []) + + _channels = self.input_channels + if "delta_deltas" in cur: + _channels *= 3 + + return _channels + + @property + def output_sample_rate(self): + """The audio sample rate of output target speech""" + return self.config.get("output_sample_rate", 22050) + + @property + def target_speaker_embed(self): + """Target speaker embedding file (one line per target audio sample)""" + return self.config.get("target_speaker_embed", None) + + @property + def prepend_tgt_lang_tag_as_bos(self) -> bool: + """Prepend target lang ID token as the target BOS.""" + return self.config.get("prepend_tgt_lang_tag_as_bos", False) + + +class MultitaskConfig(object): + """Wrapper class for data config YAML""" + + def __init__(self, yaml_path: Path): + config = get_config_from_yaml(yaml_path) + self.config = {} + for k, v in config.items(): + self.config[k] = SingleTaskConfig(k, v) + + def get_all_tasks(self): + return self.config + + def get_single_task(self, name): + assert name in self.config, f"multitask '{name}' does not exist!" + return self.config[name] + + @property + def first_pass_decoder_task_index(self): + """Return the task index of the first-pass text decoder. + If there are multiple 'is_first_pass_decoder: True' in the config file, + the last task is used for the first-pass decoder. + If there is no 'is_first_pass_decoder: True' in the config file, + the last task whose task_name includes 'target' and decoder_type is not ctc. + """ + idx = -1 + for i, (k, v) in enumerate(self.config.items()): + if v.is_first_pass_decoder: + idx = i + if idx < 0: + for i, (k, v) in enumerate(self.config.items()): + if k.startswith("target") and v.decoder_type == "transformer": + idx = i + return idx + + +class SingleTaskConfig(object): + def __init__(self, name, config): + self.task_name = name + self.config = config + dict_path = config.get("dict", "") + self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None + + @property + def data(self): + return self.config.get("data", "") + + @property + def decoder_type(self): + return self.config.get("decoder_type", "transformer") + + @property + def decoder_args(self): + """Decoder arch related args""" + args = self.config.get("decoder_args", {}) + return Namespace(**args) + + @property + def criterion_cfg(self): + """cfg for the multitask criterion""" + if self.decoder_type == "ctc": + from fairseq.criterions.ctc import CtcCriterionConfig + + cfg = CtcCriterionConfig + cfg.zero_infinity = self.config.get("zero_infinity", True) + else: + from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterionConfig, + ) + + cfg = LabelSmoothedCrossEntropyCriterionConfig + cfg.label_smoothing = self.config.get("label_smoothing", 0.2) + return cfg + + @property + def input_from(self): + """Condition on encoder/decoder of the main model""" + return "decoder" if "decoder_layer" in self.config else "encoder" + + @property + def input_layer(self): + if self.input_from == "decoder": + return self.config["decoder_layer"] - 1 + else: + # default using the output from the last encoder layer (-1) + return self.config.get("encoder_layer", 0) - 1 + + @property + def loss_weight_schedule(self): + return ( + "decay" + if "loss_weight_max" in self.config + and "loss_weight_decay_steps" in self.config + else "fixed" + ) + + def get_loss_weight(self, num_updates): + if self.loss_weight_schedule == "fixed": + weight = self.config.get("loss_weight", 1.0) + else: # "decay" + assert ( + self.config.get("loss_weight_decay_steps", 0) > 0 + ), "loss_weight_decay_steps must be greater than 0 for a decay schedule" + loss_weight_min = self.config.get("loss_weight_min", 0.0001) + loss_weight_decay_stepsize = ( + self.config["loss_weight_max"] - loss_weight_min + ) / self.config["loss_weight_decay_steps"] + weight = max( + self.config["loss_weight_max"] + - loss_weight_decay_stepsize * num_updates, + loss_weight_min, + ) + return weight + + @property + def prepend_bos_and_append_tgt_lang_tag(self) -> bool: + """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining).""" + return self.config.get("prepend_bos_and_append_tgt_lang_tag", False) + + @property + def eos_token(self): + """EOS token during generation""" + return self.config.get("eos_token", "") + + @property + def rdrop_alpha(self): + return self.config.get("rdrop_alpha", 0.0) + + @property + def is_first_pass_decoder(self): + flag = self.config.get("is_first_pass_decoder", False) + if flag: + if self.decoder_type == "ctc": + raise ValueError( + "First-pass decoder in the multi-decoder model must not be CTC." + ) + if "target" not in self.task_name: + raise Warning( + 'The name of the first-pass decoder does not include "target".' + ) + return flag + + @property + def get_lang_tag_mapping(self): + return self.config.get("lang_tag_mapping", {}) diff --git a/fairseq/data/audio/dataset_transforms/__init__.py b/fairseq/data/audio/dataset_transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b24c6f731f29875bb7570c72bf140ade9a52e19c --- /dev/null +++ b/fairseq/data/audio/dataset_transforms/__init__.py @@ -0,0 +1,53 @@ +import os +from fairseq.data.audio import ( + AudioTransform, + CompositeAudioTransform, + import_transforms, + register_audio_transform, +) + + +class AudioDatasetTransform(AudioTransform): + pass + + +AUDIO_DATASET_TRANSFORM_REGISTRY = {} +AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set() + + +def get_audio_dataset_transform(name): + return AUDIO_DATASET_TRANSFORM_REGISTRY[name] + + +def register_audio_dataset_transform(name): + return register_audio_transform( + name, + AudioDatasetTransform, + AUDIO_DATASET_TRANSFORM_REGISTRY, + AUDIO_DATASET_TRANSFORM_CLASS_NAMES, + ) + + +import_transforms(os.path.dirname(__file__), "dataset") + + +class CompositeAudioDatasetTransform(CompositeAudioTransform): + @classmethod + def from_config_dict(cls, config=None): + return super()._from_config_dict( + cls, + "dataset", + get_audio_dataset_transform, + CompositeAudioDatasetTransform, + config, + return_empty=True, + ) + + def get_transform(self, cls): + for t in self.transforms: + if isinstance(t, cls): + return t + return None + + def has_transform(self, cls): + return self.get_transform(cls) is not None diff --git a/fairseq/data/audio/dataset_transforms/__pycache__/__init__.cpython-310.pyc b/fairseq/data/audio/dataset_transforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f9721aff0b5df8efa87cf31bf6d716994a81ad Binary files /dev/null and b/fairseq/data/audio/dataset_transforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc b/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ced0fea59fd55a8e23f34336066112cd1a5723a8 Binary files /dev/null and b/fairseq/data/audio/dataset_transforms/__pycache__/concataugment.cpython-310.pyc differ diff --git a/fairseq/data/audio/dataset_transforms/__pycache__/noisyoverlapaugment.cpython-310.pyc b/fairseq/data/audio/dataset_transforms/__pycache__/noisyoverlapaugment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aecc48b99aa6dc23cae2fa389f0bc1d8968f4c18 Binary files /dev/null and b/fairseq/data/audio/dataset_transforms/__pycache__/noisyoverlapaugment.cpython-310.pyc differ diff --git a/fairseq/data/audio/dataset_transforms/concataugment.py b/fairseq/data/audio/dataset_transforms/concataugment.py new file mode 100644 index 0000000000000000000000000000000000000000..0b632ccf2b7adaeaacb784ff41df641f46303fac --- /dev/null +++ b/fairseq/data/audio/dataset_transforms/concataugment.py @@ -0,0 +1,61 @@ +from typing import List +import numpy as np + +from fairseq.data.audio.dataset_transforms import ( + AudioDatasetTransform, + register_audio_dataset_transform, +) + +_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5} + + +@register_audio_dataset_transform("concataugment") +class ConcatAugment(AudioDatasetTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return ConcatAugment( + _config.get("rate", _DEFAULTS["rate"]), + _config.get("max_tokens", _DEFAULTS["max_tokens"]), + _config.get("attempts", _DEFAULTS["attempts"]), + ) + + def __init__( + self, + rate=_DEFAULTS["rate"], + max_tokens=_DEFAULTS["max_tokens"], + attempts=_DEFAULTS["attempts"], + ): + self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"rate={self.rate}", + f"max_tokens={self.max_tokens}", + f"attempts={self.attempts}", + ] + ) + + ")" + ) + + def find_indices(self, index: int, n_frames: List[int], n_samples: int): + # skip conditions: application rate, max_tokens limit exceeded + if np.random.random() > self.rate: + return [index] + if self.max_tokens and n_frames[index] > self.max_tokens: + return [index] + + # pick second sample to concatenate + for _ in range(self.attempts): + index2 = np.random.randint(0, n_samples) + if index2 != index and ( + not self.max_tokens + or n_frames[index] + n_frames[index2] < self.max_tokens + ): + return [index, index2] + + return [index] diff --git a/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py b/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ebec2388985abc66c2ce0952d3c5684169450a --- /dev/null +++ b/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py @@ -0,0 +1,105 @@ +import numpy as np +import torch + +from fairseq.data.audio import rand_uniform +from fairseq.data.audio.dataset_transforms import ( + AudioDatasetTransform, + register_audio_dataset_transform, +) +from fairseq.data.audio.waveform_transforms.noiseaugment import ( + NoiseAugmentTransform, +) + +_DEFAULTS = { + "rate": 0.25, + "mixing_noise_rate": 0.1, + "noise_path": "", + "noise_snr_min": -5, + "noise_snr_max": 5, + "utterance_snr_min": -5, + "utterance_snr_max": 5, +} + + +@register_audio_dataset_transform("noisyoverlapaugment") +class NoisyOverlapAugment(AudioDatasetTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return NoisyOverlapAugment( + _config.get("rate", _DEFAULTS["rate"]), + _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]), + _config.get("noise_path", _DEFAULTS["noise_path"]), + _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]), + _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]), + _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]), + _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]), + ) + + def __init__( + self, + rate=_DEFAULTS["rate"], + mixing_noise_rate=_DEFAULTS["mixing_noise_rate"], + noise_path=_DEFAULTS["noise_path"], + noise_snr_min=_DEFAULTS["noise_snr_min"], + noise_snr_max=_DEFAULTS["noise_snr_max"], + utterance_snr_min=_DEFAULTS["utterance_snr_min"], + utterance_snr_max=_DEFAULTS["utterance_snr_max"], + ): + self.rate = rate + self.mixing_noise_rate = mixing_noise_rate + self.noise_shaper = NoiseAugmentTransform(noise_path) + self.noise_snr_min = noise_snr_min + self.noise_snr_max = noise_snr_max + self.utterance_snr_min = utterance_snr_min + self.utterance_snr_max = utterance_snr_max + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"rate={self.rate}", + f"mixing_noise_rate={self.mixing_noise_rate}", + f"noise_snr_min={self.noise_snr_min}", + f"noise_snr_max={self.noise_snr_max}", + f"utterance_snr_min={self.utterance_snr_min}", + f"utterance_snr_max={self.utterance_snr_max}", + ] + ) + + ")" + ) + + def __call__(self, sources): + for i, source in enumerate(sources): + if np.random.random() > self.rate: + continue + + pri = source.numpy() + + if np.random.random() > self.mixing_noise_rate: + sec = sources[np.random.randint(0, len(sources))].numpy() + snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max) + else: + sec = self.noise_shaper.pick_sample(source.shape) + snr = rand_uniform(self.noise_snr_min, self.noise_snr_max) + + L1 = pri.shape[-1] + L2 = sec.shape[-1] + l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len + s_source = np.random.randint(0, L1 - l) + s_sec = np.random.randint(0, L2 - l) + + get_power = lambda x: np.mean(x**2) + if get_power(sec) == 0: + continue + + scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec))) + + pri[s_source : s_source + l] = np.add( + pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l]) + ) + sources[i] = torch.from_numpy(pri).float() + + return sources diff --git a/fairseq/data/audio/feature_transforms/__init__.py b/fairseq/data/audio/feature_transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d295013b9013b7cfb2082261418e7fd730b915e6 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/__init__.py @@ -0,0 +1,43 @@ +import os +from fairseq.data.audio import ( + AudioTransform, + CompositeAudioTransform, + import_transforms, + register_audio_transform, +) + + +class AudioFeatureTransform(AudioTransform): + pass + + +AUDIO_FEATURE_TRANSFORM_REGISTRY = {} +AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set() + + +def get_audio_feature_transform(name): + return AUDIO_FEATURE_TRANSFORM_REGISTRY[name] + + +def register_audio_feature_transform(name): + return register_audio_transform( + name, + AudioFeatureTransform, + AUDIO_FEATURE_TRANSFORM_REGISTRY, + AUDIO_FEATURE_TRANSFORM_CLASS_NAMES, + ) + + +import_transforms(os.path.dirname(__file__), "feature") + + +class CompositeAudioFeatureTransform(CompositeAudioTransform): + @classmethod + def from_config_dict(cls, config=None): + return super()._from_config_dict( + cls, + "feature", + get_audio_feature_transform, + CompositeAudioFeatureTransform, + config, + ) diff --git a/fairseq/data/audio/feature_transforms/__pycache__/__init__.cpython-310.pyc b/fairseq/data/audio/feature_transforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7543b4c7ef45c67809f2acfc51571b12dad5917 Binary files /dev/null and b/fairseq/data/audio/feature_transforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc b/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a55ea068afc9962e322f714ac447218ed7512795 Binary files /dev/null and b/fairseq/data/audio/feature_transforms/__pycache__/delta_deltas.cpython-310.pyc differ diff --git a/fairseq/data/audio/feature_transforms/__pycache__/global_cmvn.cpython-310.pyc b/fairseq/data/audio/feature_transforms/__pycache__/global_cmvn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4f0a9d1f8d162b777e6a22f24d25c1e9e668676 Binary files /dev/null and b/fairseq/data/audio/feature_transforms/__pycache__/global_cmvn.cpython-310.pyc differ diff --git a/fairseq/data/audio/feature_transforms/__pycache__/specaugment.cpython-310.pyc b/fairseq/data/audio/feature_transforms/__pycache__/specaugment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c2d6ec466a6a64072a1a7e403eb691a1971786 Binary files /dev/null and b/fairseq/data/audio/feature_transforms/__pycache__/specaugment.cpython-310.pyc differ diff --git a/fairseq/data/audio/feature_transforms/__pycache__/utterance_cmvn.cpython-310.pyc b/fairseq/data/audio/feature_transforms/__pycache__/utterance_cmvn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19509bb6e82660033407ec06c8833e368296c875 Binary files /dev/null and b/fairseq/data/audio/feature_transforms/__pycache__/utterance_cmvn.cpython-310.pyc differ diff --git a/fairseq/data/audio/feature_transforms/delta_deltas.py b/fairseq/data/audio/feature_transforms/delta_deltas.py new file mode 100644 index 0000000000000000000000000000000000000000..49d090b11e5b31562e0aedc9b4e2b8d0d510eeda --- /dev/null +++ b/fairseq/data/audio/feature_transforms/delta_deltas.py @@ -0,0 +1,37 @@ +import numpy as np +import torch +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("delta_deltas") +class DeltaDeltas(AudioFeatureTransform): + """Expand delta-deltas features from spectrum.""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return DeltaDeltas(_config.get("win_length", 5)) + + def __init__(self, win_length=5): + self.win_length = win_length + + def __repr__(self): + return self.__class__.__name__ + + def __call__(self, spectrogram): + from torchaudio.functional import compute_deltas + + assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." + # spectrogram is T x F, while compute_deltas takes (…, F, T) + spectrogram = torch.from_numpy(spectrogram).transpose(0, 1) + delta = compute_deltas(spectrogram) + delta_delta = compute_deltas(delta) + + out_feat = np.concatenate( + [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0 + ) + out_feat = np.transpose(out_feat) + return out_feat diff --git a/fairseq/data/audio/feature_transforms/global_cmvn.py b/fairseq/data/audio/feature_transforms/global_cmvn.py new file mode 100644 index 0000000000000000000000000000000000000000..e457ff176fee3b996da11f47e7dc61b81c445ba3 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/global_cmvn.py @@ -0,0 +1,29 @@ +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("global_cmvn") +class GlobalCMVN(AudioFeatureTransform): + """Global CMVN (cepstral mean and variance normalization). The global mean + and variance need to be pre-computed and stored in NumPy format (.npz).""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return GlobalCMVN(_config.get("stats_npz_path")) + + def __init__(self, stats_npz_path): + self.stats_npz_path = stats_npz_path + stats = np.load(stats_npz_path) + self.mean, self.std = stats["mean"], stats["std"] + + def __repr__(self): + return self.__class__.__name__ + f'(stats_npz_path="{self.stats_npz_path}")' + + def __call__(self, x): + x = np.subtract(x, self.mean) + x = np.divide(x, self.std) + return x diff --git a/fairseq/data/audio/feature_transforms/specaugment.py b/fairseq/data/audio/feature_transforms/specaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5802b41a903ea8f3e3e8a169d5048b4e908f99 --- /dev/null +++ b/fairseq/data/audio/feature_transforms/specaugment.py @@ -0,0 +1,131 @@ +import math +import numbers +from typing import Optional + +import numpy as np +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("specaugment") +class SpecAugmentTransform(AudioFeatureTransform): + """SpecAugment (https://arxiv.org/abs/1904.08779)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return SpecAugmentTransform( + _config.get("time_warp_W", 0), + _config.get("freq_mask_N", 0), + _config.get("freq_mask_F", 0), + _config.get("time_mask_N", 0), + _config.get("time_mask_T", 0), + _config.get("time_mask_p", 0.0), + _config.get("mask_value", None), + ) + + def __init__( + self, + time_warp_w: int = 0, + freq_mask_n: int = 0, + freq_mask_f: int = 0, + time_mask_n: int = 0, + time_mask_t: int = 0, + time_mask_p: float = 0.0, + mask_value: Optional[float] = 0.0, + ): + # Sanity checks + assert mask_value is None or isinstance( + mask_value, numbers.Number + ), f"mask_value (type: {type(mask_value)}) must be None or a number" + if freq_mask_n > 0: + assert freq_mask_f > 0, ( + f"freq_mask_F ({freq_mask_f}) " + f"must be larger than 0 when doing freq masking." + ) + if time_mask_n > 0: + assert time_mask_t > 0, ( + f"time_mask_T ({time_mask_t}) must be larger than 0 when " + f"doing time masking." + ) + + self.time_warp_w = time_warp_w + self.freq_mask_n = freq_mask_n + self.freq_mask_f = freq_mask_f + self.time_mask_n = time_mask_n + self.time_mask_t = time_mask_t + self.time_mask_p = time_mask_p + self.mask_value = mask_value + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"time_warp_w={self.time_warp_w}", + f"freq_mask_n={self.freq_mask_n}", + f"freq_mask_f={self.freq_mask_f}", + f"time_mask_n={self.time_mask_n}", + f"time_mask_t={self.time_mask_t}", + f"time_mask_p={self.time_mask_p}", + ] + ) + + ")" + ) + + def __call__(self, spectrogram): + assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor." + + distorted = spectrogram.copy() # make a copy of input spectrogram. + num_frames = spectrogram.shape[0] # or 'tau' in the paper. + num_freqs = spectrogram.shape[1] # or 'miu' in the paper. + mask_value = self.mask_value + + if mask_value is None: # if no value was specified, use local mean. + mask_value = spectrogram.mean() + + if num_frames == 0: + return spectrogram + + if num_freqs < self.freq_mask_f: + return spectrogram + + if self.time_warp_w > 0: + if 2 * self.time_warp_w < num_frames: + import cv2 + + w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w) + w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w) + upper, lower = distorted[:w0, :], distorted[w0:, :] + upper = cv2.resize( + upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR + ) + lower = cv2.resize( + lower, + dsize=(num_freqs, num_frames - w0 - w), + interpolation=cv2.INTER_LINEAR, + ) + distorted = np.concatenate((upper, lower), axis=0) + + for _i in range(self.freq_mask_n): + f = np.random.randint(0, self.freq_mask_f) + f0 = np.random.randint(0, num_freqs - f) + if f != 0: + distorted[:, f0 : f0 + f] = mask_value + + max_time_mask_t = min( + self.time_mask_t, math.floor(num_frames * self.time_mask_p) + ) + if max_time_mask_t < 1: + return distorted + + for _i in range(self.time_mask_n): + t = np.random.randint(0, max_time_mask_t) + t0 = np.random.randint(0, num_frames - t) + if t != 0: + distorted[t0 : t0 + t, :] = mask_value + + return distorted diff --git a/fairseq/data/audio/feature_transforms/utterance_cmvn.py b/fairseq/data/audio/feature_transforms/utterance_cmvn.py new file mode 100644 index 0000000000000000000000000000000000000000..37637bc09a893789665d57a3675aa2362413e13b --- /dev/null +++ b/fairseq/data/audio/feature_transforms/utterance_cmvn.py @@ -0,0 +1,41 @@ +import numpy as np + +from fairseq.data.audio.feature_transforms import ( + AudioFeatureTransform, + register_audio_feature_transform, +) + + +@register_audio_feature_transform("utterance_cmvn") +class UtteranceCMVN(AudioFeatureTransform): + """Utterance-level CMVN (cepstral mean and variance normalization)""" + + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return UtteranceCMVN( + _config.get("norm_means", True), + _config.get("norm_vars", True), + ) + + def __init__(self, norm_means=True, norm_vars=True): + self.norm_means, self.norm_vars = norm_means, norm_vars + + def __repr__(self): + return ( + self.__class__.__name__ + + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})" + ) + + def __call__(self, x): + mean = x.mean(axis=0) + square_sums = (x**2).sum(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + if self.norm_vars: + var = square_sums / x.shape[0] - mean**2 + std = np.sqrt(np.maximum(var, 1e-10)) + x = np.divide(x, std) + + return x diff --git a/fairseq/data/audio/frm_text_to_speech_dataset.py b/fairseq/data/audio/frm_text_to_speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b54654d49250605255b375e094457062e53fa55c --- /dev/null +++ b/fairseq/data/audio/frm_text_to_speech_dataset.py @@ -0,0 +1,205 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory.abs + +import csv +import logging +import os.path as op +from typing import List, Optional + +import numpy as np +import torch +from fairseq.data import Dictionary +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig +from fairseq.data.audio.text_to_speech_dataset import ( + TextToSpeechDataset, + TextToSpeechDatasetCreator, +) + +logger = logging.getLogger(__name__) + + +class FrmTextToSpeechDataset(TextToSpeechDataset): + def __init__( + self, + split: str, + is_train_split: bool, + data_cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + do_chunk=False, + chunk_bound=-1, + chunk_init=50, + chunk_incr=5, + add_eos=True, + dedup=True, + ref_fpu=-1, + ): + # It assumes texts are encoded at a fixed frame-rate + super().__init__( + split=split, + is_train_split=is_train_split, + data_cfg=data_cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) + + self.do_chunk = do_chunk + self.chunk_bound = chunk_bound + self.chunk_init = chunk_init + self.chunk_incr = chunk_incr + self.add_eos = add_eos + self.dedup = dedup + self.ref_fpu = ref_fpu + + self.chunk_size = -1 + + if do_chunk: + assert self.chunk_incr >= 0 + assert self.pre_tokenizer is None + + def __getitem__(self, index): + index, source, target, speaker_id, _, _, _ = super().__getitem__(index) + if target[-1].item() == self.tgt_dict.eos_index: + target = target[:-1] + + fpu = source.size(0) / target.size(0) # frame-per-unit + fps = self.n_frames_per_step + assert ( + self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1 + ), f"{fpu*fps} != {self.ref_fpu}" + + # only chunk training split + if self.is_train_split and self.do_chunk and self.chunk_size > 0: + lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)] + text = target[int(self.data_cfg.prepend_tgt_lang_tag) :] + size = len(text) + chunk_size = min(self.chunk_size, size) + chunk_start = np.random.randint(size - chunk_size + 1) + text = text[chunk_start : chunk_start + chunk_size] + target = torch.cat((lang, text), 0) + + f_size = int(np.floor(chunk_size * fpu)) + f_start = int(np.floor(chunk_start * fpu)) + assert f_size > 0 + source = source[f_start : f_start + f_size, :] + + if self.dedup: + target = torch.unique_consecutive(target) + + if self.add_eos: + eos_idx = self.tgt_dict.eos_index + target = torch.cat((target, torch.LongTensor([eos_idx])), 0) + + return index, source, target, speaker_id + + def set_epoch(self, epoch): + if self.is_train_split and self.do_chunk: + old = self.chunk_size + self.chunk_size = self.chunk_init + epoch * self.chunk_incr + if self.chunk_bound > 0: + self.chunk_size = min(self.chunk_size, self.chunk_bound) + logger.info( + ( + f"{self.split}: setting chunk size " + f"from {old} to {self.chunk_size}" + ) + ) + + +class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): + # inherit for key names + @classmethod + def from_tsv( + cls, + root: str, + data_cfg: S2TDataConfig, + split: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + n_frames_per_step: int, + speaker_to_id, + do_chunk: bool = False, + chunk_bound: int = -1, + chunk_init: int = 50, + chunk_incr: int = 5, + add_eos: bool = True, + dedup: bool = True, + ref_fpu: float = -1, + ) -> FrmTextToSpeechDataset: + tsv_path = op.join(root, f"{split}.tsv") + if not op.isfile(tsv_path): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + s = [dict(e) for e in reader] + assert len(s) > 0 + + ids = [ss[cls.KEY_ID] for ss in s] + audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s] + n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s] + tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s] + src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] + speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s] + src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s] + tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s] + + return FrmTextToSpeechDataset( + split=split, + is_train_split=is_train_split, + data_cfg=data_cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + do_chunk=do_chunk, + chunk_bound=chunk_bound, + chunk_init=chunk_init, + chunk_incr=chunk_incr, + add_eos=add_eos, + dedup=dedup, + ref_fpu=ref_fpu, + ) diff --git a/fairseq/data/audio/hubert_dataset.py b/fairseq/data/audio/hubert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f09b065fdc9f3b0ff9586c13ca378626f1d4b604 --- /dev/null +++ b/fairseq/data/audio/hubert_dataset.py @@ -0,0 +1,356 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +import sys +from typing import Any, List, Optional, Union + +import numpy as np + +import torch +import torch.nn.functional as F +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset +from fairseq.data.audio.audio_utils import ( + parse_path, + read_from_stored_zip, +) +import io + +logger = logging.getLogger(__name__) + + +def load_audio(manifest_path, max_keep, min_keep): + n_long, n_short = 0, 0 + names, inds, sizes = [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + code_lengths = [len(line.encode("utf-8")) for line in f] + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 0: + logger.info(f"{label_path} is sequence label. skipped") + return + + with open(label_path) as f: + lengths = [len(line.rstrip().split()) for line in f] + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + + +class HubertDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + random_crop: bool = False, + single_target: bool = False, + ): + self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio( + manifest_path, max_keep_sample_size, min_keep_sample_size + ) + self.sample_rate = sample_rate + self.shuffle = shuffle + self.random_crop = random_crop + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, float) + else label_rates + ) + self.store_labels = store_labels + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert label_processors is None or len(label_processors) == self.num_labels + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths( + self.sizes, sample_rate, label_path, label_rate, inds, tot + ) + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}" + ) + + def get_audio(self, index): + import soundfile as sf + + wav_path = os.path.join(self.audio_root, self.audio_names[index]) + _path, slice_ptr = parse_path(wav_path) + if len(slice_ptr) == 0: + wav, cur_sample_rate = sf.read(_path) + else: + assert _path.endswith(".zip") + data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + f = io.BytesIO(data) + wav, cur_sample_rate = sf.read(f) + wav = torch.from_numpy(wav).float() + wav = self.postprocess(wav, cur_sample_rate) + return wav + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def __getitem__(self, index): + wav = self.get_audio(index) + labels = self.get_labels(index) + return {"id": index, "source": wav, "label_list": labels} + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end], start + + def collater(self, samples): + # target = max(sizes) -> random_crop not used + # target = max_sample_size -> random_crop used for long + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + collated_audios, padding_mask, audio_starts = self.collater_audio( + audios, audio_size + ) + + targets_by_label = [ + [s["label_list"][i] for s in samples] for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label( + targets_by_label, audio_size, audio_starts + ) + + net_input = {"source": collated_audios, "padding_mask": padding_mask} + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + } + + if self.single_target: + batch["target_lengths"] = lengths_list[0] + batch["ntokens"] = ntokens_list[0] + batch["target"] = targets_list[0] + else: + batch["target_lengths_list"] = lengths_list + batch["ntokens_list"] = ntokens_list + batch["target_list"] = targets_list + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + return collated_audios, padding_mask, audio_starts + + def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad): + assert label_rate > 0 + s2f = label_rate / self.sample_rate + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1.0: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0a42c10611000ee33590690de123729300fac18e --- /dev/null +++ b/fairseq/data/audio/multi_modality_dataset.py @@ -0,0 +1,267 @@ +# Copyright (c) 2021-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import math +from typing import List, Optional, NamedTuple + +import numpy as np +import torch +from fairseq.data import ( + ConcatDataset, + LanguagePairDataset, + FileAudioDataset, + data_utils, +) +from fairseq.data import FairseqDataset + +logger = logging.getLogger(__name__) + + +class ModalityDatasetItem(NamedTuple): + datasetname: str + dataset: any + max_positions: List[int] + max_tokens: Optional[int] = None + max_sentences: Optional[int] = None + + +# MultiModalityDataset: it concate multiple datasets with different modalities. +# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets +# 2) it adds mode to indicate what type of the data samples come from. +# It will be used with GroupedEpochBatchIterator together to generate mini-batch with samples +# from the same type of dataset +# If only one dataset is used, it will perform like the original dataset with mode added +class MultiModalityDataset(ConcatDataset): + def __init__(self, datasets: List[ModalityDatasetItem]): + id_to_mode = [] + dsets = [] + max_tokens = [] + max_sentences = [] + max_positions = [] + for dset in datasets: + id_to_mode.append(dset.datasetname) + dsets.append(dset.dataset) + max_tokens.append(dset.max_tokens) + max_positions.append(dset.max_positions) + max_sentences.append(dset.max_sentences) + weights = [1.0 for s in dsets] + super().__init__(dsets, weights) + self.max_tokens = max_tokens + self.max_positions = max_positions + self.max_sentences = max_sentences + self.id_to_mode = id_to_mode + self.raw_sub_batch_samplers = [] + self._cur_epoch = 0 + + def set_epoch(self, epoch): + super().set_epoch(epoch) + self._cur_epoch = epoch + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + sample = self.datasets[dataset_idx][sample_idx] + return (dataset_idx, sample) + + def collater(self, samples): + if len(samples) == 0: + return {} + dataset_idx = samples[0][0] + # make sure all samples in samples are from same dataset + assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0 + samples = self.datasets[dataset_idx].collater([x[1] for x in samples]) + # add mode + samples["net_input"]["mode"] = self.id_to_mode[dataset_idx] + + return samples + + def size(self, index: int): + if len(self.datasets) == 1: + return self.datasets[0].size(index) + return super().size(index) + + @property + def sizes(self): + if len(self.datasets) == 1: + return self.datasets[0].sizes + return super().sizes + + def ordered_indices(self): + """ + Returns indices sorted by length. So less padding is needed. + """ + if len(self.datasets) == 1: + return [self.datasets[0].ordered_indices()] + indices_group = [] + for d_idx, ds in enumerate(self.datasets): + sample_num = self.cumulative_sizes[d_idx] + if d_idx > 0: + sample_num = sample_num - self.cumulative_sizes[d_idx - 1] + assert sample_num == len(ds) + indices_group.append(ds.ordered_indices()) + return indices_group + + def get_raw_batch_samplers(self, required_batch_size_multiple, seed): + if len(self.raw_sub_batch_samplers) > 0: + logger.info(" raw_sub_batch_samplers exists. No action is taken") + return + with data_utils.numpy_seed(seed): + indices = self.ordered_indices() + + for i, ds in enumerate(self.datasets): + indices[i] = ds.filter_indices_by_size( + indices[i], + self.max_positions[i], + )[0] + sub_batch_sampler = ds.batch_by_size( + indices[i], + max_tokens=self.max_tokens[i], + max_sentences=self.max_sentences[i], + required_batch_size_multiple=required_batch_size_multiple, + ) + self.raw_sub_batch_samplers.append(sub_batch_sampler) + + def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed): + self.get_raw_batch_samplers(required_batch_size_multiple, seed) + batch_samplers = [] + for i, _ in enumerate(self.datasets): + if i > 0: + sub_batch_sampler = [ + [y + self.cumulative_sizes[i - 1] for y in x] + for x in self.raw_sub_batch_samplers[i] + ] + else: + sub_batch_sampler = list(self.raw_sub_batch_samplers[i]) + smp_r = mult_ratios[i] + if smp_r != 1: + is_increase = "increased" if smp_r > 1 else "decreased" + logger.info( + "number of batch for the dataset {} is {} from {} to {}".format( + self.id_to_mode[i], + is_increase, + len(sub_batch_sampler), + int(len(sub_batch_sampler) * smp_r), + ) + ) + mul_samplers = [] + for _ in range(math.floor(smp_r)): + mul_samplers = mul_samplers + sub_batch_sampler + if math.floor(smp_r) != smp_r: + with data_utils.numpy_seed(seed + self._cur_epoch): + np.random.shuffle(sub_batch_sampler) + smp_num = int( + (smp_r - math.floor(smp_r)) * len(sub_batch_sampler) + ) + mul_samplers = mul_samplers + sub_batch_sampler[:smp_num] + sub_batch_sampler = mul_samplers + else: + logger.info( + "dataset {} batch number is {} ".format( + self.id_to_mode[i], len(sub_batch_sampler) + ) + ) + batch_samplers.append(sub_batch_sampler) + + return batch_samplers + + +class LangPairMaskDataset(FairseqDataset): + def __init__( + self, + dataset: LanguagePairDataset, + src_eos: int, + src_bos: Optional[int] = None, + noise_id: Optional[int] = -1, + mask_ratio: Optional[float] = 0, + mask_type: Optional[str] = "random", + ): + self.dataset = dataset + self.src_eos = src_eos + self.src_bos = src_bos + self.noise_id = noise_id + self.mask_ratio = mask_ratio + self.mask_type = mask_type + assert mask_type in ("random", "tail") + + @property + def src_sizes(self): + return self.dataset.src_sizes + + @property + def tgt_sizes(self): + return self.dataset.tgt_sizes + + @property + def sizes(self): + # dataset.sizes can be a dynamically computed sizes: + return self.dataset.sizes + + def get_batch_shapes(self): + if hasattr(self.dataset, "get_batch_shapes"): + return self.dataset.get_batch_shapes() + return self.dataset.buckets + + def num_tokens_vec(self, indices): + return self.dataset.num_tokens_vec(indices) + + def __len__(self): + return len(self.dataset) + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) + + def mask_src_tokens(self, sample): + src_item = sample["source"] + mask = None + if self.mask_type == "random": + mask = torch.rand(len(src_item)).le(self.mask_ratio) + else: + mask = torch.ones(len(src_item)) + mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0 + mask = mask.eq(1) + if src_item[0] == self.src_bos: + mask[0] = False + if src_item[-1] == self.src_eos: + mask[-1] = False + mask_src_item = src_item.masked_fill(mask, self.noise_id) + smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]} + return smp + + def __getitem__(self, index): + sample = self.dataset[index] + if self.mask_ratio > 0: + sample = self.mask_src_tokens(sample) + return sample + + def collater(self, samples, pad_to_length=None): + return self.dataset.collater(samples, pad_to_length) + + +class FileAudioDatasetWrapper(FileAudioDataset): + def collater(self, samples): + samples = super().collater(samples) + if len(samples) == 0: + return {} + samples["net_input"]["src_tokens"] = samples["net_input"]["source"] + samples["net_input"]["prev_output_tokens"] = None + del samples["net_input"]["source"] + samples["net_input"]["src_lengths"] = None + samples["net_input"]["alignment"] = None + return samples diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec202d5574e15ba5608e33e96f0a2b8de08331f1 --- /dev/null +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -0,0 +1,431 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import sys +import time +import io + +import numpy as np +import torch +import torch.nn.functional as F + +from .. import FairseqDataset +from ..data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes +from fairseq.data.audio.audio_utils import ( + parse_path, + read_from_stored_zip, + is_sf_audio_data, +) +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + + +logger = logging.getLogger(__name__) + + +class RawAudioDataset(FairseqDataset): + def __init__( + self, + sample_rate, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + compute_mask=False, + feature_encoder_spec: str = "None", + mask_prob: float = 0.75, + mask_prob_adjust: float = 0, + mask_length: int = 1, + inverse_mask: bool = False, + require_same_masks: bool = True, + clone_batch: int = 1, + expand_adjacent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, + corpus_key=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.sizes = [] + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.min_sample_size = min_sample_size + self.pad = pad + self.shuffle = shuffle + self.normalize = normalize + + self.is_compute_mask = compute_mask + self.feature_encoder_spec = eval(feature_encoder_spec) + self._features_size_map = {} + self.mask_prob = mask_prob + self.mask_prob_adjust = mask_prob_adjust + self.mask_length = mask_length + self.inverse_mask = inverse_mask + self.require_same_masks = require_same_masks + self.clone_batch = clone_batch + self.expand_adjacent = expand_adjacent + self.mask_dropout = mask_dropout + self.non_overlapping = non_overlapping + self.corpus_key = corpus_key + + def __getitem__(self, index): + raise NotImplementedError() + + def __len__(self): + return len(self.sizes) + + def postprocess(self, feats, curr_sample_rate): + if feats.dim() == 2: + feats = feats.mean(-1) + + if curr_sample_rate != self.sample_rate: + raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}") + + assert feats.dim() == 1, feats.dim() + + if self.normalize: + with torch.no_grad(): + feats = F.layer_norm(feats, feats.shape) + return feats + + def crop_to_max_size(self, t, target_size, dim=0): + size = t.size(dim) + diff = size - target_size + if diff <= 0: + return t + + start = np.random.randint(0, diff + 1) + end = size - diff + start + + slices = [] + for d in range(dim): + slices.append(slice(None)) + slices.append(slice(start, end)) + + return t[slices] + + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + sources = [s["source"] for s in samples] + sizes = [len(s) for s in sources] + + if self.pad: + target_size = min(max(sizes), self.max_sample_size) + else: + target_size = min(min(sizes), self.max_sample_size) + + collated_sources = sources[0].new_zeros(len(sources), target_size) + padding_mask = ( + torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None + ) + for i, (source, size) in enumerate(zip(sources, sizes)): + diff = size - target_size + if diff == 0: + collated_sources[i] = source + elif diff < 0: + assert self.pad + collated_sources[i] = torch.cat( + [source, source.new_full((-diff,), 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_sources[i] = self.crop_to_max_size(source, target_size) + + input = {"source": collated_sources} + if self.corpus_key is not None: + input["corpus_key"] = [self.corpus_key] * len(sources) + out = {"id": torch.LongTensor([s["id"] for s in samples])} + if self.pad: + input["padding_mask"] = padding_mask + + if hasattr(self, "num_buckets") and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s["id"]] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input["source"] = self._bucket_tensor(collated_sources, num_pad, 0) + input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True) + + if "precomputed_mask" in samples[0]: + target_size = self._get_mask_indices_dims(target_size) + collated_mask = torch.cat( + [ + self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1) + for s in samples + ], + dim=0, + ) + input["precomputed_mask"] = collated_mask + + out["net_input"] = input + return out + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self.feature_encoder_spec: + L_in = size + for (_, kernel_size, stride) in self.feature_encoder_spec: + L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + if self.pad: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + + if self.shuffle: + order = [np.random.permutation(len(self))] + order.append( + np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + ) + return np.lexsort(order)[::-1] + else: + return np.arange(len(self)) + + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, + self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + + def filter_indices_by_size(self, indices, max_sizes): + return indices, [] + + +class FileAudioDataset(RawAudioDataset): + def __init__( + self, + manifest_path, + sample_rate, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + num_buckets=0, + compute_mask=False, + text_compression_level=TextCompressionLevel.none, + **mask_compute_kwargs, + ): + super().__init__( + sample_rate=sample_rate, + max_sample_size=max_sample_size, + min_sample_size=min_sample_size, + shuffle=shuffle, + pad=pad, + normalize=normalize, + compute_mask=compute_mask, + **mask_compute_kwargs, + ) + + self.text_compressor = TextCompressor(level=text_compression_level) + + skipped = 0 + self.fnames = [] + sizes = [] + self.skipped_indices = set() + + with open(manifest_path, "r") as f: + self.root_dir = f.readline().strip() + for i, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 2, line + sz = int(items[1]) + if min_sample_size is not None and sz < min_sample_size: + skipped += 1 + self.skipped_indices.add(i) + continue + self.fnames.append(self.text_compressor.compress(items[0])) + sizes.append(sz) + logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + + self.sizes = np.array(sizes, dtype=np.int64) + + try: + import pyarrow + + self.fnames = pyarrow.array(self.fnames) + except: + logger.debug( + "Could not create a pyarrow array. Please install pyarrow for better performance" + ) + pass + + self.set_bucket_info(num_buckets) + + def __getitem__(self, index): + import soundfile as sf + + fn = self.fnames[index] + fn = fn if isinstance(self.fnames, list) else fn.as_py() + fn = self.text_compressor.decompress(fn) + path_or_fp = os.path.join(self.root_dir, fn) + _path, slice_ptr = parse_path(path_or_fp) + if len(slice_ptr) == 2: + byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + assert is_sf_audio_data(byte_data) + path_or_fp = io.BytesIO(byte_data) + + retry = 3 + wav = None + for i in range(retry): + try: + wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32") + break + except Exception as e: + logger.warning( + f"Failed to read {path_or_fp}: {e}. Sleeping for {1 * i}" + ) + time.sleep(1 * i) + + if wav is None: + raise Exception(f"Failed to load {path_or_fp}") + + feats = torch.from_numpy(wav).float() + feats = self.postprocess(feats, curr_sample_rate) + + v = {"id": index, "source": feats} + + if self.is_compute_mask: + T = self._get_mask_indices_dims(feats.size(-1)) + mask = compute_block_mask_1d( + shape=(self.clone_batch, T), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v + + +class BinarizedAudioDataset(RawAudioDataset): + def __init__( + self, + data_dir, + split, + sample_rate, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + num_buckets=0, + compute_mask=False, + **mask_compute_kwargs, + ): + super().__init__( + sample_rate=sample_rate, + max_sample_size=max_sample_size, + min_sample_size=min_sample_size, + shuffle=shuffle, + pad=pad, + normalize=normalize, + compute_mask=compute_mask, + **mask_compute_kwargs, + ) + + from fairseq.data import data_utils, Dictionary + + self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt")) + + root_path = os.path.join(data_dir, f"{split}.root") + if os.path.exists(root_path): + with open(root_path, "r") as f: + self.root_dir = next(f).strip() + else: + self.root_dir = None + + fnames_path = os.path.join(data_dir, split) + self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict) + lengths_path = os.path.join(data_dir, f"{split}.lengths") + + with open(lengths_path, "r") as f: + for line in f: + sz = int(line.rstrip()) + assert ( + sz >= min_sample_size + ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}" + self.sizes.append(sz) + + self.sizes = np.array(self.sizes, dtype=np.int64) + + self.set_bucket_info(num_buckets) + logger.info(f"loaded {len(self.fnames)} samples") + + def __getitem__(self, index): + import soundfile as sf + + fname = self.fnames_dict.string(self.fnames[index], separator="") + if self.root_dir: + fname = os.path.join(self.root_dir, fname) + + wav, curr_sample_rate = sf.read(fname) + feats = torch.from_numpy(wav).float() + feats = self.postprocess(feats, curr_sample_rate) + v = {"id": index, "source": feats} + + if self.is_compute_mask: + T = self._get_mask_indices_dims(feats.size(-1)) + mask = compute_block_mask_1d( + shape=(self.clone_batch, T), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v diff --git a/fairseq/data/audio/speech_to_speech_dataset.py b/fairseq/data/audio/speech_to_speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4b61f83112f8cdb8f8b30830f1db875488996c --- /dev/null +++ b/fairseq/data/audio/speech_to_speech_dataset.py @@ -0,0 +1,379 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch + +from fairseq.data import ConcatDataset, Dictionary +from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import S2SDataConfig +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + SpeechToTextDatasetCreator, + TextTargetMultitaskData, + _collate_frames, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechToSpeechDatasetItem(object): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + target_speaker: Optional[torch.Tensor] = None + tgt_lang_tag: Optional[int] = None + + +class SpeechToSpeechDataset(SpeechToTextDataset): + def __init__( + self, + split: str, + is_train_split: bool, + data_cfg: S2SDataConfig, + src_audio_paths: List[str], + src_n_frames: List[int], + tgt_audio_paths: List[str], + tgt_n_frames: List[int], + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + target_is_code: bool = False, + tgt_dict: Dictionary = None, + n_frames_per_step: int = 1, + ): + tgt_texts = tgt_audio_paths if target_is_code else None + super().__init__( + split=split, + is_train_split=is_train_split, + cfg=data_cfg, + audio_paths=src_audio_paths, + n_frames=src_n_frames, + ids=ids, + tgt_dict=tgt_dict, + tgt_texts=tgt_texts, + src_langs=src_langs, + tgt_langs=tgt_langs, + n_frames_per_step=n_frames_per_step, + ) + + self.tgt_audio_paths = tgt_audio_paths + self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames] + + assert not target_is_code or tgt_dict is not None + self.target_is_code = target_is_code + + assert len(tgt_audio_paths) == self.n_samples + assert len(tgt_n_frames) == self.n_samples + + self.tgt_speakers = None + if self.cfg.target_speaker_embed: + samples = SpeechToTextDatasetCreator._load_samples_from_tsv( + self.cfg.target_speaker_embed, split + ) + spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples} + self.tgt_speakers = [spk_emb_dict[id] for id in self.ids] + assert len(self.tgt_speakers) == self.n_samples + + logger.info(self.__repr__()) + + def pack_units(self, input: torch.Tensor) -> torch.Tensor: + if self.n_frames_per_step <= 1: + return input + + offset = 4 + vocab_size = ( + len(self.tgt_dict) - offset + ) # remove offset from , , , , which is specific to fairseq dictionary + + assert input.dim() == 1 + stacked_input = ( + input[:-1].view(-1, self.n_frames_per_step) - offset + ) # remove + scale = [ + pow(vocab_size, self.n_frames_per_step - 1 - i) + for i in range(self.n_frames_per_step) + ] + scale = torch.LongTensor(scale).squeeze(0) + res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1]) + res[:-1] = (stacked_input * scale).sum(dim=1) + offset + + return res + + def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem: + source = self._get_source_audio(index) + + tgt_lang_tag = None + if self.cfg.prepend_tgt_lang_tag_as_bos: + # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target + tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict) + + if not self.target_is_code: + target = get_features_or_waveform(self.tgt_audio_paths[index]) + target = torch.from_numpy(target).float() + target = self.pack_frames(target) + else: + target = self.tgt_dict.encode_line( + self.tgt_audio_paths[index], + add_if_not_exist=False, + append_eos=True, + ).long() + if self.n_frames_per_step > 1: + n_tgt_frame = target.size(0) - 1 # exclude + keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step + target = torch.cat( + ( + target[:keep_n_tgt_frame], + target.new_full((1,), self.tgt_dict.eos()), + ), + dim=0, + ) + + if self.tgt_speakers: + tgt_spk = get_features_or_waveform(self.tgt_speakers[index]) + tgt_spk = torch.from_numpy(tgt_spk).float() + else: + tgt_spk = torch.FloatTensor([]) + + return SpeechToSpeechDatasetItem( + index=index, + source=source, + target=target, + target_speaker=tgt_spk, + tgt_lang_tag=tgt_lang_tag, + ) + + def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor: + if self.target_is_code: + target = fairseq_data_utils.collate_tokens( + [x.target for x in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + # convert stacked units to a single id + pack_targets = [self.pack_units(x.target) for x in samples] + prev_output_tokens = fairseq_data_utils.collate_tokens( + pack_targets, + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, + ) + target_lengths = torch.tensor( + [x.size(0) for x in pack_targets], dtype=torch.long + ) + else: + target = _collate_frames([x.target for x in samples], is_audio_input=False) + bsz, _, d = target.size() + prev_output_tokens = torch.cat( + (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1 + ) + target_lengths = torch.tensor( + [x.target.size(0) for x in samples], dtype=torch.long + ) + + return target, prev_output_tokens, target_lengths + + def collater( + self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False + ) -> Dict: + if len(samples) == 0: + return {} + indices = torch.tensor([x.index for x in samples], dtype=torch.long) + frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input) + # sort samples by descending number of frames + n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long) + n_frames, order = n_frames.sort(descending=True) + indices = indices.index_select(0, order) + frames = frames.index_select(0, order) + + target, prev_output_tokens, target_lengths = self._collate_target(samples) + target = target.index_select(0, order) + target_lengths = target_lengths.index_select(0, order) + prev_output_tokens = prev_output_tokens.index_select(0, order) + ntokens = sum(x.target.size(0) for x in samples) + + tgt_speakers = None + if self.cfg.target_speaker_embed: + tgt_speakers = _collate_frames( + [x.target_speaker for x in samples], is_audio_input=True + ).index_select(0, order) + + net_input = { + "src_tokens": frames, + "src_lengths": n_frames, + "prev_output_tokens": prev_output_tokens, + "tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker" + } + if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None: + for i in range(len(samples)): + net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag + out = { + "id": indices, + "net_input": net_input, + "speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model + "target": target, + "target_lengths": target_lengths, + "ntokens": ntokens, + "nsentences": len(samples), + } + if return_order: + out["order"] = order + return out + + +class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.multitask_data = {} + + def add_multitask_dataset(self, task_name, task_data): + self.multitask_data[task_name] = task_data + + def __getitem__( + self, index: int + ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]: + s2s_data = super().__getitem__(index) + + multitask_target = {} + sample_id = self.ids[index] + tgt_lang = self.tgt_langs[index] + for task_name, task_dataset in self.multitask_data.items(): + multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang) + + return s2s_data, multitask_target + + def collater( + self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]] + ) -> Dict: + if len(samples) == 0: + return {} + + out = super().collater([s for s, _ in samples], return_order=True) + order = out["order"] + del out["order"] + + for task_name, task_dataset in self.multitask_data.items(): + if "multitask" not in out: + out["multitask"] = {} + d = [s[task_name] for _, s in samples] + task_target = task_dataset.collater(d) + out["multitask"][task_name] = { + "target": task_target["target"].index_select(0, order), + "target_lengths": task_target["target_lengths"].index_select(0, order), + "ntokens": task_target["ntokens"], + } + out["multitask"][task_name]["net_input"] = { + "prev_output_tokens": task_target["prev_output_tokens"].index_select( + 0, order + ), + } + + return out + + +class SpeechToSpeechDatasetCreator(object): + # mandatory columns + KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames" + KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames" + # optional columns + KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" + # default values + DEFAULT_LANG = "" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + data_cfg: S2SDataConfig, + target_is_code: bool = False, + tgt_dict: Dictionary = None, + n_frames_per_step: int = 1, + multitask: Optional[Dict] = None, + ) -> SpeechToSpeechDataset: + audio_root = Path(data_cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + src_audio_paths = [ + (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples + ] + tgt_audio_paths = [ + s[cls.KEY_TGT_AUDIO] + if target_is_code + else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix() + for s in samples + ] + src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples] + tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + + has_multitask = multitask is not None and len(multitask.keys()) > 0 + dataset_cls = ( + SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset + ) + + ds = dataset_cls( + split=split_name, + is_train_split=is_train_split, + data_cfg=data_cfg, + src_audio_paths=src_audio_paths, + src_n_frames=src_n_frames, + tgt_audio_paths=tgt_audio_paths, + tgt_n_frames=tgt_n_frames, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + target_is_code=target_is_code, + tgt_dict=tgt_dict, + n_frames_per_step=n_frames_per_step, + ) + + if has_multitask: + for task_name, task_obj in multitask.items(): + task_data = TextTargetMultitaskData( + task_obj.args, split_name, task_obj.target_dictionary + ) + ds.add_multitask_dataset(task_name, task_data) + return ds + + @classmethod + def from_tsv( + cls, + root: str, + data_cfg: S2SDataConfig, + splits: str, + is_train_split: bool, + epoch: int, + seed: int, + target_is_code: bool = False, + tgt_dict: Dictionary = None, + n_frames_per_step: int = 1, + multitask: Optional[Dict] = None, + ) -> SpeechToSpeechDataset: + datasets = [] + for split in splits.split(","): + samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split) + ds = cls._from_list( + split_name=split, + is_train_split=is_train_split, + samples=samples, + data_cfg=data_cfg, + target_is_code=target_is_code, + tgt_dict=tgt_dict, + n_frames_per_step=n_frames_per_step, + multitask=multitask, + ) + datasets.append(ds) + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf71558fdf4c9e79c3a5b271363d96b4244216d --- /dev/null +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -0,0 +1,733 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import csv +import logging +import re +from argparse import Namespace +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset +from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data import encoders +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import S2TDataConfig +from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform +from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment +from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import ( + NoisyOverlapAugment, +) +from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform +from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform + +logger = logging.getLogger(__name__) + + +def _collate_frames( + frames: List[torch.Tensor], is_audio_input: bool = False +) -> torch.Tensor: + """ + Convert a list of 2D frames into a padded 3D tensor + Args: + frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is + length of i-th frame and f_dim is static dimension of features + Returns: + 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] + """ + max_len = max(frame.size(0) for frame in frames) + if is_audio_input: + out = frames[0].new_zeros((len(frames), max_len)) + else: + out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) + for i, v in enumerate(frames): + out[i, : v.size(0)] = v + return out + + +def _is_int_or_np_int(n): + return isinstance(n, int) or ( + isinstance(n, np.generic) and isinstance(n.item(), int) + ) + + +@dataclass +class SpeechToTextDatasetItem(object): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + speaker_id: Optional[int] = None + + +class SpeechToTextDataset(FairseqDataset): + LANG_TAG_TEMPLATE = "" + + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + append_eos=True, + ): + self.split, self.is_train_split = split, is_train_split + self.cfg = cfg + self.audio_paths, self.n_frames = audio_paths, n_frames + self.n_samples = len(audio_paths) + assert len(n_frames) == self.n_samples > 0 + assert src_texts is None or len(src_texts) == self.n_samples + assert tgt_texts is None or len(tgt_texts) == self.n_samples + assert speakers is None or len(speakers) == self.n_samples + assert src_langs is None or len(src_langs) == self.n_samples + assert tgt_langs is None or len(tgt_langs) == self.n_samples + assert ids is None or len(ids) == self.n_samples + assert (tgt_dict is None and tgt_texts is None) or ( + tgt_dict is not None and tgt_texts is not None + ) + self.src_texts, self.tgt_texts = src_texts, tgt_texts + self.src_langs, self.tgt_langs = src_langs, tgt_langs + self.speakers = speakers + self.tgt_dict = tgt_dict + self.check_tgt_lang_tag() + self.ids = ids + self.shuffle = cfg.shuffle if is_train_split else False + + self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( + self.cfg.get_feature_transforms(split, is_train_split) + ) + self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict( + self.cfg.get_waveform_transforms(split, is_train_split) + ) + # TODO: add these to data_cfg.py + self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict( + self.cfg.get_dataset_transforms(split, is_train_split) + ) + + # check proper usage of transforms + if self.feature_transforms and self.cfg.use_audio_input: + logger.warning( + "Feature transforms will not be applied. To use feature transforms, " + "set use_audio_input as False in config." + ) + + self.pre_tokenizer = pre_tokenizer + self.bpe_tokenizer = bpe_tokenizer + self.n_frames_per_step = n_frames_per_step + self.speaker_to_id = speaker_to_id + + self.tgt_lens = self.get_tgt_lens_and_check_oov() + self.append_eos = append_eos + + logger.info(self.__repr__()) + + def get_tgt_lens_and_check_oov(self): + if self.tgt_texts is None: + return [0 for _ in range(self.n_samples)] + tgt_lens = [] + n_tokens, n_oov_tokens = 0, 0 + for i in range(self.n_samples): + tokenized = self.get_tokenized_tgt_text(i).split(" ") + oov_tokens = [ + t + for t in tokenized + if self.tgt_dict.index(t) == self.tgt_dict.unk_index + ] + n_tokens += len(tokenized) + n_oov_tokens += len(oov_tokens) + tgt_lens.append(len(tokenized)) + logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") + return tgt_lens + + def __repr__(self): + return ( + self.__class__.__name__ + + f'(split="{self.split}", n_samples={self.n_samples:_}, ' + f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " + f"n_frames_per_step={self.n_frames_per_step}, " + f"shuffle={self.shuffle}, " + f"feature_transforms={self.feature_transforms}, " + f"waveform_transforms={self.waveform_transforms}, " + f"dataset_transforms={self.dataset_transforms})" + ) + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") + return re.match(pattern, token) + + def check_tgt_lang_tag(self): + if self.cfg.prepend_tgt_lang_tag: + assert self.tgt_langs is not None and self.tgt_dict is not None + tgt_lang_tags = [ + self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) + ] + assert all(t in self.tgt_dict for t in tgt_lang_tags) + + @classmethod + def tokenize(cls, tokenizer, text: str): + return text if tokenizer is None else tokenizer.encode(text) + + def get_tokenized_tgt_text(self, index: Union[int, List[int]]): + if _is_int_or_np_int(index): + text = self.tgt_texts[index] + else: + text = " ".join([self.tgt_texts[i] for i in index]) + + text = self.tokenize(self.pre_tokenizer, text) + text = self.tokenize(self.bpe_tokenizer, text) + return text + + def pack_frames(self, feature: torch.Tensor): + if self.n_frames_per_step == 1: + return feature + n_packed_frames = feature.shape[0] // self.n_frames_per_step + feature = feature[: self.n_frames_per_step * n_packed_frames] + return feature.reshape(n_packed_frames, -1) + + @classmethod + def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): + lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) + assert lang_tag_idx != dictionary.unk() + return lang_tag_idx + + def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor: + """ + Gives source audio for given index with any relevant transforms + applied. For ConcatAug, source audios for given indices are + concatenated in given order. + Args: + index (int or List[int]): index—or in the case of ConcatAug, + indices—to pull the source audio for + Returns: + source audios concatenated for given indices with + relevant transforms appplied + """ + if _is_int_or_np_int(index): + source = get_features_or_waveform( + self.audio_paths[index], + need_waveform=self.cfg.use_audio_input, + use_sample_rate=self.cfg.use_sample_rate, + waveform_transforms=self.waveform_transforms, + ) + else: + source = np.concatenate( + [ + get_features_or_waveform( + self.audio_paths[i], + need_waveform=self.cfg.use_audio_input, + use_sample_rate=self.cfg.use_sample_rate, + waveform_transforms=self.waveform_transforms, + ) + for i in index + ] + ) + if self.cfg.use_audio_input: + source = torch.from_numpy(source).float() + if self.cfg.standardize_audio: + with torch.no_grad(): + source = F.layer_norm(source, source.shape) + else: + if self.feature_transforms is not None: + source = self.feature_transforms(source) + source = torch.from_numpy(source).float() + return source + + def __getitem__(self, index: int) -> SpeechToTextDatasetItem: + has_concat = self.dataset_transforms.has_transform(ConcatAugment) + if has_concat: + concat = self.dataset_transforms.get_transform(ConcatAugment) + indices = concat.find_indices(index, self.n_frames, self.n_samples) + + source = self._get_source_audio(indices if has_concat else index) + source = self.pack_frames(source) + + target = None + if self.tgt_texts is not None: + tokenized = self.get_tokenized_tgt_text(indices if has_concat else index) + target = self.tgt_dict.encode_line( + tokenized, add_if_not_exist=False, append_eos=self.append_eos + ).long() + if self.cfg.prepend_tgt_lang_tag: + lang_tag_idx = self.get_lang_tag_idx( + self.tgt_langs[index], self.tgt_dict + ) + target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) + + if self.cfg.prepend_bos_and_append_tgt_lang_tag: + bos = torch.LongTensor([self.tgt_dict.bos()]) + lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict) + assert lang_tag_idx != self.tgt_dict.unk() + lang_tag_idx = torch.LongTensor([lang_tag_idx]) + target = torch.cat((bos, target, lang_tag_idx), 0) + + speaker_id = None + if self.speaker_to_id is not None: + speaker_id = self.speaker_to_id[self.speakers[index]] + return SpeechToTextDatasetItem( + index=index, source=source, target=target, speaker_id=speaker_id + ) + + def __len__(self): + return self.n_samples + + def collater( + self, samples: List[SpeechToTextDatasetItem], return_order: bool = False + ) -> Dict: + if len(samples) == 0: + return {} + indices = torch.tensor([x.index for x in samples], dtype=torch.long) + + sources = [x.source for x in samples] + has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment) + if has_NOAug and self.cfg.use_audio_input: + NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment) + sources = NOAug(sources) + + frames = _collate_frames(sources, self.cfg.use_audio_input) + # sort samples by descending number of frames + n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long) + n_frames, order = n_frames.sort(descending=True) + indices = indices.index_select(0, order) + frames = frames.index_select(0, order) + + target, target_lengths = None, None + prev_output_tokens = None + ntokens = None + if self.tgt_texts is not None: + target = fairseq_data_utils.collate_tokens( + [x.target for x in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + target = target.index_select(0, order) + target_lengths = torch.tensor( + [x.target.size(0) for x in samples], dtype=torch.long + ).index_select(0, order) + prev_output_tokens = fairseq_data_utils.collate_tokens( + [x.target for x in samples], + self.tgt_dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=True, + ) + prev_output_tokens = prev_output_tokens.index_select(0, order) + ntokens = sum(x.target.size(0) for x in samples) + + speaker = None + if self.speaker_to_id is not None: + speaker = ( + torch.tensor([s.speaker_id for s in samples], dtype=torch.long) + .index_select(0, order) + .view(-1, 1) + ) + + net_input = { + "src_tokens": frames, + "src_lengths": n_frames, + "prev_output_tokens": prev_output_tokens, + } + out = { + "id": indices, + "net_input": net_input, + "speaker": speaker, + "target": target, + "target_lengths": target_lengths, + "ntokens": ntokens, + "nsentences": len(samples), + } + if return_order: + out["order"] = order + return out + + def num_tokens(self, index): + return self.n_frames[index] + + def size(self, index): + return self.n_frames[index], self.tgt_lens[index] + + @property + def sizes(self): + return np.array(self.n_frames) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + # first by descending order of # of frames then by original/random order + order.append([-n for n in self.n_frames]) + return np.lexsort(order) + + def prefetch(self, indices): + raise False + + +class TextTargetMultitaskData(object): + # mandatory columns + KEY_ID, KEY_TEXT = "id", "tgt_text" + LANG_TAG_TEMPLATE = "" + + def __init__(self, args, split, tgt_dict): + samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split) + self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples} + self.dict = tgt_dict + self.append_eos = args.decoder_type != "ctc" + self.pre_tokenizer = self.build_tokenizer(args) + self.bpe_tokenizer = self.build_bpe(args) + self.prepend_bos_and_append_tgt_lang_tag = ( + args.prepend_bos_and_append_tgt_lang_tag + ) + self.eos_token = args.eos_token + self.lang_tag_mapping = args.get_lang_tag_mapping + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") + return re.match(pattern, token) + + @classmethod + def tokenize(cls, tokenizer, text: str): + return text if tokenizer is None else tokenizer.encode(text) + + def get_tokenized_tgt_text(self, index: int): + text = self.tokenize(self.pre_tokenizer, self.data[index]) + text = self.tokenize(self.bpe_tokenizer, text) + return text + + def get_lang_tag_idx(self, lang: str, dictionary: Dictionary): + lang_tag = self.LANG_TAG_TEMPLATE.format(lang) + lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag) + lang_tag_idx = dictionary.index(lang_tag) + assert lang_tag_idx != dictionary.unk(), (lang, lang_tag) + return lang_tag_idx + + def build_tokenizer(self, args): + pre_tokenizer = args.config.get("pre_tokenizer") + if pre_tokenizer is not None: + logger.info(f"pre-tokenizer: {pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**pre_tokenizer)) + else: + return None + + def build_bpe(self, args): + bpe_tokenizer = args.config.get("bpe_tokenizer") + if bpe_tokenizer is not None: + logger.info(f"tokenizer: {bpe_tokenizer}") + return encoders.build_bpe(Namespace(**bpe_tokenizer)) + else: + return None + + def get(self, sample_id, tgt_lang=None): + if sample_id in self.data: + tokenized = self.get_tokenized_tgt_text(sample_id) + target = self.dict.encode_line( + tokenized, + add_if_not_exist=False, + append_eos=self.append_eos, + ) + if self.prepend_bos_and_append_tgt_lang_tag: + bos = torch.LongTensor([self.dict.bos()]) + lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict) + assert lang_tag_idx != self.dict.unk() + lang_tag_idx = torch.LongTensor([lang_tag_idx]) + target = torch.cat((bos, target, lang_tag_idx), 0) + return target + else: + logger.warning(f"no target for {sample_id}") + return torch.IntTensor([]) + + def collater(self, samples: List[torch.Tensor]) -> torch.Tensor: + out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + ).long() + + prev_out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=True, + ).long() + + target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long) + ntokens = sum(t.size(0) for t in samples) + + output = { + "prev_output_tokens": prev_out, + "target": out, + "target_lengths": target_lengths, + "ntokens": ntokens, + } + + return output + + +class SpeechToTextMultitaskDataset(SpeechToTextDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.multitask_data = {} + + def add_multitask_dataset(self, task_name, task_data): + self.multitask_data[task_name] = task_data + + def __getitem__( + self, index: int + ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]: + s2t_data = super().__getitem__(index) + + multitask_target = {} + sample_id = self.ids[index] + tgt_lang = self.tgt_langs[index] + for task_name, task_dataset in self.multitask_data.items(): + multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang) + + return s2t_data, multitask_target + + def collater( + self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]] + ) -> Dict: + if len(samples) == 0: + return {} + + out = super().collater([s for s, _ in samples], return_order=True) + order = out["order"] + del out["order"] + + for task_name, task_dataset in self.multitask_data.items(): + if "multitask" not in out: + out["multitask"] = {} + d = [s[task_name] for _, s in samples] + task_target = task_dataset.collater(d) + out["multitask"][task_name] = { + "target": task_target["target"].index_select(0, order), + "target_lengths": task_target["target_lengths"].index_select(0, order), + "ntokens": task_target["ntokens"], + } + out["multitask"][task_name]["net_input"] = { + "prev_output_tokens": task_target["prev_output_tokens"].index_select( + 0, order + ), + } + + return out + + +class SpeechToTextDatasetCreator(object): + # mandatory columns + KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" + KEY_TGT_TEXT = "tgt_text" + # optional columns + KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" + KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" + # default values + DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + + has_multitask = multitask is not None and len(multitask.keys()) > 0 + dataset_cls = ( + SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset + ) + + ds = dataset_cls( + split=split_name, + is_train_split=is_train_split, + cfg=cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) + + if has_multitask: + for task_name, task_obj in multitask.items(): + task_data = TextTargetMultitaskData( + task_obj.args, split_name, task_obj.target_dictionary + ) + ds.add_multitask_dataset(task_name, task_data) + return ds + + @classmethod + def get_size_ratios( + cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0 + ) -> List[float]: + """Size ratios for temperature-based sampling + (https://arxiv.org/abs/1907.05019)""" + + id_to_lp, lp_to_sz = {}, defaultdict(int) + for ds in datasets: + lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)} + assert len(lang_pairs) == 1 + lang_pair = list(lang_pairs)[0] + id_to_lp[ds.split] = lang_pair + lp_to_sz[lang_pair] += sum(ds.n_frames) + + sz_sum = sum(v for v in lp_to_sz.values()) + lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()} + lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()} + prob_sum = sum(v for v in lp_to_tgt_prob.values()) + lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()} + lp_to_sz_ratio = { + k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items() + } + size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets] + + p_formatted = { + k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz + } + logger.info(f"sampling probability balancing: {p_formatted}") + sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)} + logger.info(f"balanced sampling size ratio: {sr_formatted}") + return size_ratio + + @classmethod + def _load_samples_from_tsv(cls, root: str, split: str): + tsv_path = Path(root) / f"{split}.tsv" + if not tsv_path.is_file(): + raise FileNotFoundError(f"Dataset not found: {tsv_path}") + with open(tsv_path) as f: + reader = csv.DictReader( + f, + delimiter="\t", + quotechar=None, + doublequote=False, + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + samples = [dict(e) for e in reader] + if len(samples) == 0: + raise ValueError(f"Empty manifest: {tsv_path}") + return samples + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + split: str, + tgt_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + samples = cls._load_samples_from_tsv(root, split) + return cls._from_list( + split, + is_train_split, + samples, + cfg, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask, + ) + + @classmethod + def from_tsv( + cls, + root: str, + cfg: S2TDataConfig, + splits: str, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + n_frames_per_step: int = 1, + speaker_to_id=None, + multitask: Optional[Dict] = None, + ) -> SpeechToTextDataset: + datasets = [ + cls._from_tsv( + root=root, + cfg=cfg, + split=split, + tgt_dict=tgt_dict, + is_train_split=is_train_split, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + multitask=multitask, + ) + for split in splits.split(",") + ] + + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for r, d in zip(size_ratios, datasets) + ] + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/data/audio/speech_to_text_joint_dataset.py b/fairseq/data/audio/speech_to_text_joint_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..06922ea083714aa7b9b7b7a284db5eaf0a79a37a --- /dev/null +++ b/fairseq/data/audio/speech_to_text_joint_dataset.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Dict, List, NamedTuple, Optional + +import torch + +from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset +from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig, + SpeechToTextDataset, + SpeechToTextDatasetCreator, +) + +logger = logging.getLogger(__name__) + + +class S2TJointDataConfig(S2TDataConfig): + """Wrapper class for data config YAML""" + + @property + def src_vocab_filename(self): + """fairseq vocabulary file under data root""" + return self.config.get("src_vocab_filename", "src_dict.txt") + + @property + def src_pre_tokenizer(self) -> Dict: + """Pre-tokenizer to apply before subword tokenization. Returning + a dictionary with `tokenizer` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("src_pre_tokenizer", {"tokenizer": None}) + + @property + def src_bpe_tokenizer(self) -> Dict: + """Subword tokenizer to apply on source text after pre-tokenization. + Returning a dictionary with `bpe` providing the tokenizer name and + the other items providing the tokenizer-specific arguments. + Tokenizers are defined in `fairseq.data.encoders.*`""" + return self.config.get("src_bpe_tokenizer", {"bpe": None}) + + @property + def prepend_tgt_lang_tag_no_change(self) -> bool: + """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for + to-many multilingual setting). No change needed during inference. + This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos. + """ + value = self.config.get("prepend_tgt_lang_tag_no_change", None) + if value is None: + return self.config.get("prepend_tgt_lang_tag_as_bos", False) + return value + + @property + def sampling_text_alpha(self): + """Hyper-parameter alpha = 1/T for temperature-based resampling. (text + input only) (alpha = 1 for no resampling)""" + return self.config.get("sampling_text_alpha", 1.0) + + +class SpeechToTextJointDatasetItem(NamedTuple): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + src_txt_tokens: Optional[torch.Tensor] = None + tgt_lang_tag: Optional[int] = None + src_lang_tag: Optional[int] = None + tgt_alignment: Optional[torch.Tensor] = None + + +# use_src_lang_id: +# 0: don't use src_lang_id +# 1: attach src_lang_id to the src_txt_tokens as eos +class SpeechToTextJointDataset(SpeechToTextDataset): + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TJointDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + src_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + append_eos: Optional[bool] = True, + alignment: Optional[List[str]] = None, + use_src_lang_id: Optional[int] = 0, + ): + super().__init__( + split, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + append_eos=append_eos, + ) + + self.src_dict = src_dict + self.src_pre_tokenizer = src_pre_tokenizer + self.src_bpe_tokenizer = src_bpe_tokenizer + self.alignment = None + self.use_src_lang_id = use_src_lang_id + if alignment is not None: + self.alignment = [ + [float(s) for s in sample.split()] for sample in alignment + ] + + def get_tokenized_src_text(self, index: int): + text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index]) + text = self.tokenize(self.src_bpe_tokenizer, text) + return text + + def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem: + s2t_dataset_item = super().__getitem__(index) + src_tokens = None + src_lang_tag = None + if self.src_texts is not None and self.src_dict is not None: + src_tokens = self.get_tokenized_src_text(index) + src_tokens = self.src_dict.encode_line( + src_tokens, add_if_not_exist=False, append_eos=True + ).long() + if self.use_src_lang_id > 0: + src_lang_tag = self.get_lang_tag_idx( + self.src_langs[index], self.src_dict + ) + tgt_lang_tag = None + if self.cfg.prepend_tgt_lang_tag_no_change: + # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead + tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict) + ali = None + if self.alignment is not None: + ali = torch.Tensor(self.alignment[index]).float() + + return SpeechToTextJointDatasetItem( + index=index, + source=s2t_dataset_item.source, + target=s2t_dataset_item.target, + src_txt_tokens=src_tokens, + tgt_lang_tag=tgt_lang_tag, + src_lang_tag=src_lang_tag, + tgt_alignment=ali, + ) + + def __len__(self): + return self.n_samples + + def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict: + s2t_out = super().collater(samples, return_order=True) + if s2t_out == {}: + return s2t_out + net_input, order = s2t_out["net_input"], s2t_out["order"] + + if self.src_texts is not None and self.src_dict is not None: + src_txt_tokens = fairseq_data_utils.collate_tokens( + [x.src_txt_tokens for x in samples], + self.src_dict.pad(), + self.src_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ) + src_txt_lengths = torch.tensor( + [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long + ) + if self.use_src_lang_id > 0: + src_lang_idxs = torch.tensor( + [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype + ) + if self.use_src_lang_id == 1: # replace eos with lang_id + eos_idx = src_txt_lengths - 1 + src_txt_tokens.scatter_( + 1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1) + ) + else: + raise NotImplementedError("Implementation is required") + + src_txt_tokens = src_txt_tokens.index_select(0, order) + src_txt_lengths = src_txt_lengths.index_select(0, order) + net_input["src_txt_tokens"] = src_txt_tokens + net_input["src_txt_lengths"] = src_txt_lengths + + net_input["alignment"] = None + if self.alignment is not None: + max_len = max([s.tgt_alignment.size(0) for s in samples]) + alignment = torch.ones(len(samples), max_len).float() + for i, s in enumerate(samples): + cur_len = s.tgt_alignment.size(0) + alignment[i][:cur_len].copy_(s.tgt_alignment) + net_input["alignment"] = alignment.index_select(0, order) + + if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None: + for i in range(len(samples)): + net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag + + out = { + "id": s2t_out["id"], + "net_input": net_input, + "target": s2t_out["target"], + "target_lengths": s2t_out["target_lengths"], + "ntokens": s2t_out["ntokens"], + "nsentences": len(samples), + } + return out + + +class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator): + KEY_ALIGN = "align" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TJointDataConfig, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + append_eos, + use_src_lang_id, + ) -> SpeechToTextJointDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_alignment = None + if cls.KEY_ALIGN in samples[0].keys(): + tgt_alignment = [s[cls.KEY_ALIGN] for s in samples] + return SpeechToTextJointDataset( + split_name, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=src_pre_tokenizer, + src_bpe_tokenizer=src_bpe_tokenizer, + append_eos=append_eos, + alignment=tgt_alignment, + use_src_lang_id=use_src_lang_id, + ) + + @classmethod + def _from_tsv( + cls, + root: str, + cfg: S2TJointDataConfig, + split: str, + tgt_dict, + src_dict, + is_train_split: bool, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + append_eos: bool, + use_src_lang_id: int, + ) -> SpeechToTextJointDataset: + samples = cls._load_samples_from_tsv(root, split) + return cls._from_list( + split, + is_train_split, + samples, + cfg, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + append_eos, + use_src_lang_id, + ) + + @classmethod + def from_tsv( + cls, + root: str, + cfg: S2TJointDataConfig, + splits: str, + tgt_dict, + src_dict, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + is_train_split: bool, + epoch: int, + seed: int, + append_eos: Optional[bool] = True, + use_src_lang_id: Optional[int] = 0, + ) -> SpeechToTextJointDataset: + datasets = [ + cls._from_tsv( + root, + cfg, + split, + tgt_dict, + src_dict, + is_train_split, + pre_tokenizer, + bpe_tokenizer, + src_pre_tokenizer, + src_bpe_tokenizer, + append_eos=append_eos, + use_src_lang_id=use_src_lang_id, + ) + for split in splits.split(",") + ] + + if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: + # temperature-based sampling + size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) + datasets = [ + ResamplingDataset( + d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) + ) + for r, d in zip(size_ratios, datasets) + ] + + return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] diff --git a/fairseq/data/audio/text_to_speech_dataset.py b/fairseq/data/audio/text_to_speech_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..13612b458b3635551ba2bf7e99160675a1ff4f9b --- /dev/null +++ b/fairseq/data/audio/text_to_speech_dataset.py @@ -0,0 +1,250 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory.abs + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import torch + +from fairseq.data import Dictionary +from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig, + SpeechToTextDataset, + SpeechToTextDatasetCreator, + _collate_frames, +) + + +@dataclass +class TextToSpeechDatasetItem(object): + index: int + source: torch.Tensor + target: Optional[torch.Tensor] = None + speaker_id: Optional[int] = None + duration: Optional[torch.Tensor] = None + pitch: Optional[torch.Tensor] = None + energy: Optional[torch.Tensor] = None + + +class TextToSpeechDataset(SpeechToTextDataset): + def __init__( + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + durations: Optional[List[List[int]]] = None, + pitches: Optional[List[str]] = None, + energies: Optional[List[str]] = None, + ): + super(TextToSpeechDataset, self).__init__( + split, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) + self.durations = durations + self.pitches = pitches + self.energies = energies + + def __getitem__(self, index: int) -> TextToSpeechDatasetItem: + s2t_item = super().__getitem__(index) + + duration, pitch, energy = None, None, None + if self.durations is not None: + duration = torch.tensor( + self.durations[index] + [0], dtype=torch.long # pad 0 for EOS + ) + if self.pitches is not None: + pitch = get_features_or_waveform(self.pitches[index]) + pitch = torch.from_numpy( + np.concatenate((pitch, [0])) # pad 0 for EOS + ).float() + if self.energies is not None: + energy = get_features_or_waveform(self.energies[index]) + energy = torch.from_numpy( + np.concatenate((energy, [0])) # pad 0 for EOS + ).float() + return TextToSpeechDatasetItem( + index=index, + source=s2t_item.source, + target=s2t_item.target, + speaker_id=s2t_item.speaker_id, + duration=duration, + pitch=pitch, + energy=energy, + ) + + def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]: + if len(samples) == 0: + return {} + + src_lengths, order = torch.tensor( + [s.target.shape[0] for s in samples], dtype=torch.long + ).sort(descending=True) + id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select( + 0, order + ) + feat = _collate_frames( + [s.source for s in samples], self.cfg.use_audio_input + ).index_select(0, order) + target_lengths = torch.tensor( + [s.source.shape[0] for s in samples], dtype=torch.long + ).index_select(0, order) + + src_tokens = fairseq_data_utils.collate_tokens( + [s.target for s in samples], + self.tgt_dict.pad(), + self.tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=False, + ).index_select(0, order) + + speaker = None + if self.speaker_to_id is not None: + speaker = ( + torch.tensor([s.speaker_id for s in samples], dtype=torch.long) + .index_select(0, order) + .view(-1, 1) + ) + + bsz, _, d = feat.size() + prev_output_tokens = torch.cat( + (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1 + ) + + durations, pitches, energies = None, None, None + if self.durations is not None: + durations = fairseq_data_utils.collate_tokens( + [s.duration for s in samples], 0 + ).index_select(0, order) + assert src_tokens.shape[1] == durations.shape[1] + if self.pitches is not None: + pitches = _collate_frames([s.pitch for s in samples], True) + pitches = pitches.index_select(0, order) + assert src_tokens.shape[1] == pitches.shape[1] + if self.energies is not None: + energies = _collate_frames([s.energy for s in samples], True) + energies = energies.index_select(0, order) + assert src_tokens.shape[1] == energies.shape[1] + src_texts = [self.tgt_dict.string(samples[i].target) for i in order] + + return { + "id": id_, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "prev_output_tokens": prev_output_tokens, + }, + "speaker": speaker, + "target": feat, + "durations": durations, + "pitches": pitches, + "energies": energies, + "target_lengths": target_lengths, + "ntokens": sum(target_lengths).item(), + "nsentences": len(samples), + "src_texts": src_texts, + } + + +class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator): + KEY_DURATION = "duration" + KEY_PITCH = "pitch" + KEY_ENERGY = "energy" + + @classmethod + def _from_list( + cls, + split_name: str, + is_train_split, + samples: List[Dict], + cfg: S2TDataConfig, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + multitask=None, + ) -> TextToSpeechDataset: + audio_root = Path(cfg.audio_root) + ids = [s[cls.KEY_ID] for s in samples] + audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] + n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] + tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] + src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] + speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] + src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] + tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] + + durations = [s.get(cls.KEY_DURATION, None) for s in samples] + durations = [ + None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations + ] + durations = None if any(dd is None for dd in durations) else durations + + pitches = [s.get(cls.KEY_PITCH, None) for s in samples] + pitches = [ + None if pp is None else (audio_root / pp).as_posix() for pp in pitches + ] + pitches = None if any(pp is None for pp in pitches) else pitches + + energies = [s.get(cls.KEY_ENERGY, None) for s in samples] + energies = [ + None if ee is None else (audio_root / ee).as_posix() for ee in energies + ] + energies = None if any(ee is None for ee in energies) else energies + + return TextToSpeechDataset( + split_name, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts, + tgt_texts, + speakers, + src_langs, + tgt_langs, + ids, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + durations, + pitches, + energies, + ) diff --git a/fairseq/data/audio/waveform_transforms/__init__.py b/fairseq/data/audio/waveform_transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57f8bb571b9a5be072b49a8e8f86e8b5de6b0fb9 --- /dev/null +++ b/fairseq/data/audio/waveform_transforms/__init__.py @@ -0,0 +1,48 @@ +import os +from fairseq.data.audio import ( + AudioTransform, + CompositeAudioTransform, + import_transforms, + register_audio_transform, +) + + +class AudioWaveformTransform(AudioTransform): + pass + + +AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {} +AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set() + + +def get_audio_waveform_transform(name): + return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name] + + +def register_audio_waveform_transform(name): + return register_audio_transform( + name, + AudioWaveformTransform, + AUDIO_WAVEFORM_TRANSFORM_REGISTRY, + AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES, + ) + + +import_transforms(os.path.dirname(__file__), "waveform") + + +class CompositeAudioWaveformTransform(CompositeAudioTransform): + @classmethod + def from_config_dict(cls, config=None): + return super()._from_config_dict( + cls, + "waveform", + get_audio_waveform_transform, + CompositeAudioWaveformTransform, + config, + ) + + def __call__(self, x, sample_rate): + for t in self.transforms: + x, sample_rate = t(x, sample_rate) + return x, sample_rate diff --git a/fairseq/data/audio/waveform_transforms/__pycache__/__init__.cpython-310.pyc b/fairseq/data/audio/waveform_transforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50a06c45b964bbac947a6e835221f6b013d5bc23 Binary files /dev/null and b/fairseq/data/audio/waveform_transforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/audio/waveform_transforms/__pycache__/noiseaugment.cpython-310.pyc b/fairseq/data/audio/waveform_transforms/__pycache__/noiseaugment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c0466ba9fd98fd40e36a89b2f13322e4cd30dad Binary files /dev/null and b/fairseq/data/audio/waveform_transforms/__pycache__/noiseaugment.cpython-310.pyc differ diff --git a/fairseq/data/audio/waveform_transforms/noiseaugment.py b/fairseq/data/audio/waveform_transforms/noiseaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..401ce30943ca89f0b35e12e1db7504361bf8b329 --- /dev/null +++ b/fairseq/data/audio/waveform_transforms/noiseaugment.py @@ -0,0 +1,201 @@ +from pathlib import Path +import numpy as np +from math import ceil + +from fairseq.data.audio import rand_uniform +from fairseq.data.audio.waveform_transforms import ( + AudioWaveformTransform, + register_audio_waveform_transform, +) + +SNR_MIN = 5.0 +SNR_MAX = 15.0 +RATE = 0.25 + +NOISE_RATE = 1.0 +NOISE_LEN_MEAN = 0.2 +NOISE_LEN_STD = 0.05 + + +class NoiseAugmentTransform(AudioWaveformTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return cls( + _config.get("samples_path", None), + _config.get("snr_min", SNR_MIN), + _config.get("snr_max", SNR_MAX), + _config.get("rate", RATE), + ) + + def __init__( + self, + samples_path: str, + snr_min: float = SNR_MIN, + snr_max: float = SNR_MAX, + rate: float = RATE, + ): + # Sanity checks + assert ( + samples_path + ), "need to provide path to audio samples for noise augmentation" + assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})" + assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1" + + self.paths = list(Path(samples_path).glob("**/*.wav")) # load music + self.n_samples = len(self.paths) + assert self.n_samples > 0, f"no audio files found in {samples_path}" + + self.snr_min = snr_min + self.snr_max = snr_max + self.rate = rate + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + ", ".join( + [ + f"n_samples={self.n_samples}", + f"snr={self.snr_min}-{self.snr_max}dB", + f"rate={self.rate}", + ] + ) + + ")" + ) + + def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None): + from fairseq.data.audio.audio_utils import get_waveform + + path = self.paths[np.random.randint(0, self.n_samples)] + sample = get_waveform( + path, always_2d=always_2d, output_sample_rate=use_sample_rate + )[0] + + # Check dimensions match, else silently skip adding noise to sample + # NOTE: SHOULD THIS QUIT WITH AN ERROR? + is_2d = len(goal_shape) == 2 + if len(goal_shape) != sample.ndim or ( + is_2d and goal_shape[0] != sample.shape[0] + ): + return np.zeros(goal_shape) + + # Cut/repeat sample to size + len_dim = len(goal_shape) - 1 + n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim]) + repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat) + start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1) + return ( + repeated[:, start : start + goal_shape[len_dim]] + if is_2d + else repeated[start : start + goal_shape[len_dim]] + ) + + def _mix(self, source, noise, snr): + get_power = lambda x: np.mean(x**2) + if get_power(noise): + scl = np.sqrt( + get_power(source) / (np.power(10, snr / 10) * get_power(noise)) + ) + else: + scl = 0 + return 1 * source + scl * noise + + def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None): + return self.pick_sample(goal_shape, always_2d, use_sample_rate) + + def __call__(self, source, sample_rate): + if np.random.random() > self.rate: + return source, sample_rate + + noise = self._get_noise( + source.shape, always_2d=True, use_sample_rate=sample_rate + ) + + return ( + self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)), + sample_rate, + ) + + +@register_audio_waveform_transform("musicaugment") +class MusicAugmentTransform(NoiseAugmentTransform): + pass + + +@register_audio_waveform_transform("backgroundnoiseaugment") +class BackgroundNoiseAugmentTransform(NoiseAugmentTransform): + pass + + +@register_audio_waveform_transform("babbleaugment") +class BabbleAugmentTransform(NoiseAugmentTransform): + def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None): + for i in range(np.random.randint(3, 8)): + speech = self.pick_sample(goal_shape, always_2d, use_sample_rate) + if i == 0: + agg_noise = speech + else: # SNR scaled by i (how many noise signals already in agg_noise) + agg_noise = self._mix(agg_noise, speech, i) + return agg_noise + + +@register_audio_waveform_transform("sporadicnoiseaugment") +class SporadicNoiseAugmentTransform(NoiseAugmentTransform): + @classmethod + def from_config_dict(cls, config=None): + _config = {} if config is None else config + return cls( + _config.get("samples_path", None), + _config.get("snr_min", SNR_MIN), + _config.get("snr_max", SNR_MAX), + _config.get("rate", RATE), + _config.get("noise_rate", NOISE_RATE), + _config.get("noise_len_mean", NOISE_LEN_MEAN), + _config.get("noise_len_std", NOISE_LEN_STD), + ) + + def __init__( + self, + samples_path: str, + snr_min: float = SNR_MIN, + snr_max: float = SNR_MAX, + rate: float = RATE, + noise_rate: float = NOISE_RATE, # noises per second + noise_len_mean: float = NOISE_LEN_MEAN, # length of noises in seconds + noise_len_std: float = NOISE_LEN_STD, + ): + super().__init__(samples_path, snr_min, snr_max, rate) + self.noise_rate = noise_rate + self.noise_len_mean = noise_len_mean + self.noise_len_std = noise_len_std + + def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None): + agg_noise = np.zeros(goal_shape) + len_dim = len(goal_shape) - 1 + is_2d = len(goal_shape) == 2 + + n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate) + start_pointers = [ + round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises) + ] + + for start_pointer in start_pointers: + noise_shape = list(goal_shape) + len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std) + noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate) + end_pointer = start_pointer + noise_shape[len_dim] + if end_pointer >= goal_shape[len_dim]: + continue + + noise = self.pick_sample(noise_shape, always_2d, use_sample_rate) + if is_2d: + agg_noise[:, start_pointer:end_pointer] = ( + agg_noise[:, start_pointer:end_pointer] + noise + ) + else: + agg_noise[start_pointer:end_pointer] = ( + agg_noise[start_pointer:end_pointer] + noise + ) + + return agg_noise diff --git a/fairseq/data/backtranslation_dataset.py b/fairseq/data/backtranslation_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8f70c90df3d237077537993e125d366c95292f1a --- /dev/null +++ b/fairseq/data/backtranslation_dataset.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils + +from . import FairseqDataset + + +def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True): + """Backtranslate a list of samples. + + Given an input (*samples*) of the form: + + [{'id': 1, 'source': 'hallo welt'}] + + this will return: + + [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}] + + Args: + samples (List[dict]): samples to backtranslate. Individual samples are + expected to have a 'source' key, which will become the 'target' + after backtranslation. + collate_fn (callable): function to collate samples into a mini-batch + generate_fn (callable): function to generate backtranslations + cuda (bool): use GPU for generation (default: ``True``) + + Returns: + List[dict]: an updated list of samples with a backtranslated source + """ + collated_samples = collate_fn(samples) + s = utils.move_to_cuda(collated_samples) if cuda else collated_samples + generated_sources = generate_fn(s) + + id_to_src = {sample["id"]: sample["source"] for sample in samples} + + # Go through each tgt sentence in batch and its corresponding best + # generated hypothesis and create a backtranslation data pair + # {id: id, source: generated backtranslation, target: original tgt} + return [ + { + "id": id.item(), + "target": id_to_src[id.item()], + "source": hypos[0]["tokens"].cpu(), + } + for id, hypos in zip(collated_samples["id"], generated_sources) + ] + + +class BacktranslationDataset(FairseqDataset): + """ + Sets up a backtranslation dataset which takes a tgt batch, generates + a src using a tgt-src backtranslation function (*backtranslation_fn*), + and returns the corresponding `{generated src, input tgt}` batch. + + Args: + tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be + backtranslated. Only the source side of this dataset will be used. + After backtranslation, the source sentences in this dataset will be + returned as the targets. + src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated + sentences. + tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of + sentences to be backtranslated. + backtranslation_fn (callable, optional): function to call to generate + backtranslations. This is typically the `generate` method of a + :class:`~fairseq.sequence_generator.SequenceGenerator` object. + Pass in None when it is not available at initialization time, and + use set_backtranslation_fn function to set it when available. + output_collater (callable, optional): function to call on the + backtranslated samples to create the final batch + (default: ``tgt_dataset.collater``). + cuda: use GPU for generation + """ + + def __init__( + self, + tgt_dataset, + src_dict, + tgt_dict=None, + backtranslation_fn=None, + output_collater=None, + cuda=True, + **kwargs + ): + self.tgt_dataset = tgt_dataset + self.backtranslation_fn = backtranslation_fn + self.output_collater = ( + output_collater if output_collater is not None else tgt_dataset.collater + ) + self.cuda = cuda if torch.cuda.is_available() else False + self.src_dict = src_dict + self.tgt_dict = tgt_dict + + def __getitem__(self, index): + """ + Returns a single sample from *tgt_dataset*. Note that backtranslation is + not applied in this step; use :func:`collater` instead to backtranslate + a batch of samples. + """ + return self.tgt_dataset[index] + + def __len__(self): + return len(self.tgt_dataset) + + def set_backtranslation_fn(self, backtranslation_fn): + self.backtranslation_fn = backtranslation_fn + + def collater(self, samples): + """Merge and backtranslate a list of samples to form a mini-batch. + + Using the samples from *tgt_dataset*, load a collated target sample to + feed to the backtranslation model. Then take the backtranslation with + the best score as the source and the original input as the target. + + Note: we expect *tgt_dataset* to provide a function `collater()` that + will collate samples into the format expected by *backtranslation_fn*. + After backtranslation, we will feed the new list of samples (i.e., the + `(backtranslated source, original source)` pairs) to *output_collater* + and return the result. + + Args: + samples (List[dict]): samples to backtranslate and collate + + Returns: + dict: a mini-batch with keys coming from *output_collater* + """ + if samples[0].get("is_dummy", False): + return samples + samples = backtranslate_samples( + samples=samples, + collate_fn=self.tgt_dataset.collater, + generate_fn=(lambda net_input: self.backtranslation_fn(net_input)), + cuda=self.cuda, + ) + return self.output_collater(samples) + + def num_tokens(self, index): + """Just use the tgt dataset num_tokens""" + return self.tgt_dataset.num_tokens(index) + + def ordered_indices(self): + """Just use the tgt dataset ordered_indices""" + return self.tgt_dataset.ordered_indices() + + def size(self, index): + """Return an example's size as a float or tuple. This value is used + when filtering a dataset with ``--max-positions``. + + Note: we use *tgt_dataset* to approximate the length of the source + sentence, since we do not know the actual length until after + backtranslation. + """ + tgt_size = self.tgt_dataset.size(index)[0] + return (tgt_size, tgt_size) + + @property + def supports_prefetch(self): + return getattr(self.tgt_dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.tgt_dataset.prefetch(indices) diff --git a/fairseq/data/base_wrapper_dataset.py b/fairseq/data/base_wrapper_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..134d398b47dc73c8807759188504aee205b3b34d --- /dev/null +++ b/fairseq/data/base_wrapper_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class BaseWrapperDataset(FairseqDataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, "collater"): + return self.dataset.collater(samples) + else: + return default_collate(samples) + + @property + def sizes(self): + return self.dataset.sizes + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def attr(self, attr: str, index: int): + return self.dataset.attr(attr, index) + + def prefetch(self, indices): + self.dataset.prefetch(indices) + + def get_batch_shapes(self): + return self.dataset.get_batch_shapes() + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + return self.dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + def filter_indices_by_size(self, indices, max_sizes): + return self.dataset.filter_indices_by_size(indices, max_sizes) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return self.dataset.can_reuse_epoch_itr_across_epochs + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0f9410014845873bb0344fca6478c231c88e9dea --- /dev/null +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch.nn.functional as F +from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes + + +class BucketPadLengthDataset(BaseWrapperDataset): + """ + Bucket and pad item lengths to the nearest bucket size. This can be used to + reduce the number of unique batch shapes, which is important on TPUs since + each new batch shape requires a recompilation. + + Args: + dataset (FairseqDatset): dataset to bucket + sizes (List[int]): all item sizes + num_buckets (int): number of buckets to create + pad_idx (int): padding symbol + left_pad (bool): if True, pad on the left; otherwise right pad + """ + + def __init__( + self, + dataset, + sizes, + num_buckets, + pad_idx, + left_pad, + tensor_key=None, + ): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + assert num_buckets > 0 + self.buckets = get_buckets(sizes, num_buckets) + self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + self._tensor_key = tensor_key + + def _set_tensor(self, item, val): + if self._tensor_key is None: + return val + item[self._tensor_key] = val + return item + + def _get_tensor(self, item): + if self._tensor_key is None: + return item + return item[self._tensor_key] + + def _pad(self, tensor, bucket_size, dim=-1): + num_pad = bucket_size - tensor.size(dim) + return F.pad( + tensor, + (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), + value=self.pad_idx, + ) + + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + tensor = self._get_tensor(item) + padded = self._pad(tensor, bucket_size) + return self._set_tensor(item, padded) + + @property + def sizes(self): + return self._bucketed_sizes + + def num_tokens(self, index): + return self._bucketed_sizes[index] + + def size(self, index): + return self._bucketed_sizes[index] diff --git a/fairseq/data/codedataset.py b/fairseq/data/codedataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a433091956ed2449f628028ea8da7be4c4895307 --- /dev/null +++ b/fairseq/data/codedataset.py @@ -0,0 +1,576 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import json +import logging +import os +import random +from pathlib import Path + +import numpy as np +import torch +import torch.utils.data + +from . import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset + +F0_FRAME_SPACE = 0.005 # sec + + +logger = logging.getLogger(__name__) + + +class ExpressiveCodeDataConfig(object): + def __init__(self, json_path): + with open(json_path, "r") as f: + self.config = json.load(f) + self._manifests = self.config["manifests"] + + @property + def manifests(self): + return self._manifests + + @property + def n_units(self): + return self.config["n_units"] + + @property + def sampling_rate(self): + return self.config["sampling_rate"] + + @property + def code_hop_size(self): + return self.config["code_hop_size"] + + @property + def f0_stats(self): + """pre-computed f0 statistics path""" + return self.config.get("f0_stats", None) + + @property + def f0_vq_type(self): + """naive or precomp""" + return self.config["f0_vq_type"] + + @property + def f0_vq_name(self): + return self.config["f0_vq_name"] + + def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std): + key = "log" if log else "linear" + if norm_mean and norm_std: + key += "_mean_std_norm" + elif norm_mean: + key += "_mean_norm" + else: + key += "_none_norm" + return self.config["f0_vq_naive_quantizer"][key] + + @property + def f0_vq_n_units(self): + return self.config["f0_vq_n_units"] + + @property + def multispkr(self): + """how to parse speaker label from audio path""" + return self.config.get("multispkr", None) + + +def get_f0(audio, rate=16000): + try: + import amfm_decompy.basic_tools as basic + import amfm_decompy.pYAAPT as pYAAPT + from librosa.util import normalize + except ImportError: + raise "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)." + + assert audio.ndim == 1 + frame_length = 20.0 # ms + to_pad = int(frame_length / 1000 * rate) // 2 + + audio = normalize(audio) * 0.95 + audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0) + audio = basic.SignalObj(audio, rate) + pitch = pYAAPT.yaapt( + audio, + frame_length=frame_length, + frame_space=F0_FRAME_SPACE * 1000, + nccf_thresh1=0.25, + tda_frame_length=25.0, + ) + f0 = pitch.samp_values + return f0 + + +def interpolate_f0(f0): + try: + from scipy.interpolate import interp1d + except ImportError: + raise "Please install scipy (`pip install scipy`)" + + orig_t = np.arange(f0.shape[0]) + f0_interp = f0[:] + ii = f0_interp != 0 + if ii.sum() > 1: + f0_interp = interp1d( + orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0 + )(orig_t) + f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device) + return f0_interp + + +def naive_quantize(x, edges): + bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1) + return bin_idx + + +def load_wav(full_path): + try: + import soundfile as sf + except ImportError: + raise "Please install soundfile (`pip install SoundFile`)" + data, sampling_rate = sf.read(full_path) + return data, sampling_rate + + +def parse_code(code_str, dictionary, append_eos): + code, duration = torch.unique_consecutive( + torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True + ) + code = " ".join(map(str, code.tolist())) + code = dictionary.encode_line(code, append_eos).short() + + if append_eos: + duration = torch.cat((duration, duration.new_zeros((1,))), dim=0) # eos + duration = duration.short() + return code, duration + + +def parse_manifest(manifest, dictionary): + audio_files = [] + codes = [] + durations = [] + speakers = [] + + with open(manifest) as info: + for line in info.readlines(): + sample = eval(line.strip()) + if "cpc_km100" in sample: + k = "cpc_km100" + elif "hubert_km100" in sample: + k = "hubert_km100" + elif "phone" in sample: + k = "phone" + else: + assert False, "unknown format" + code = sample[k] + code, duration = parse_code(code, dictionary, append_eos=True) + + codes.append(code) + durations.append(duration) + audio_files.append(sample["audio"]) + speakers.append(sample.get("speaker", None)) + + return audio_files, codes, durations, speakers + + +def parse_speaker(path, method): + if type(path) == str: + path = Path(path) + + if method == "parent_name": + return path.parent.name + elif method == "parent_parent_name": + return path.parent.parent.name + elif method == "_": + return path.name.split("_")[0] + elif method == "single": + return "A" + elif callable(method): + return method(path) + else: + raise NotImplementedError() + + +def get_f0_by_filename(filename, tgt_sampling_rate): + audio, sampling_rate = load_wav(filename) + if sampling_rate != tgt_sampling_rate: + raise ValueError( + "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate) + ) + + # compute un-interpolated f0, and use Ann's interp in __getitem__ if set + f0 = get_f0(audio, rate=tgt_sampling_rate) + f0 = torch.from_numpy(f0.astype(np.float32)) + return f0 + + +def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1): + code_len = durations.sum() + targ_len = int(f0_code_ratio * code_len) + diff = f0.size(0) - targ_len + assert abs(diff) <= tol, ( + f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|" + f" > {tol} (dur=\n{durations})" + ) + if diff > 0: + f0 = f0[:targ_len] + elif diff < 0: + f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0) + + f0_offset = 0.0 + seg_f0s = [] + for dur in durations: + f0_dur = dur.item() * f0_code_ratio + seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)] + seg_f0 = seg_f0[seg_f0 != 0] + if len(seg_f0) == 0: + seg_f0 = torch.tensor(0).type(seg_f0.type()) + else: + seg_f0 = seg_f0.mean() + seg_f0s.append(seg_f0) + f0_offset += f0_dur + + assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}" + return torch.tensor(seg_f0s) + + +class Paddings(object): + def __init__(self, code_val, dur_val=0, f0_val=-2.0): + self.code = code_val + self.dur = dur_val + self.f0 = f0_val + + +class Shifts(object): + def __init__(self, shifts_str, pads): + self._shifts = list(map(int, shifts_str.split(","))) + assert len(self._shifts) == 2, self._shifts + assert all(s >= 0 for s in self._shifts) + self.extra_length = max(s for s in self._shifts) + self.pads = pads + + @property + def dur(self): + return self._shifts[0] + + @property + def f0(self): + return self._shifts[1] + + @staticmethod + def shift_one(seq, left_pad_num, right_pad_num, pad): + assert seq.ndim == 1 + bos = seq.new_full((left_pad_num,), pad) + eos = seq.new_full((right_pad_num,), pad) + seq = torch.cat([bos, seq, eos]) + mask = torch.ones_like(seq).bool() + mask[left_pad_num : len(seq) - right_pad_num] = 0 + return seq, mask + + def __call__(self, code, dur, f0): + if self.extra_length == 0: + code_mask = torch.zeros_like(code).bool() + dur_mask = torch.zeros_like(dur).bool() + f0_mask = torch.zeros_like(f0).bool() + return code, code_mask, dur, dur_mask, f0, f0_mask + + code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code) + dur, dur_mask = self.shift_one( + dur, self.dur, self.extra_length - self.dur, self.pads.dur + ) + f0, f0_mask = self.shift_one( + f0, self.f0, self.extra_length - self.f0, self.pads.f0 + ) + return code, code_mask, dur, dur_mask, f0, f0_mask + + +class CodeDataset(FairseqDataset): + def __init__( + self, + manifest, + dictionary, + dur_dictionary, + f0_dictionary, + config, + discrete_dur, + discrete_f0, + log_f0, + normalize_f0_mean, + normalize_f0_std, + interpolate_f0, + return_filename=False, + strip_filename=True, + shifts="0,0", + return_continuous_f0=False, + ): + random.seed(1234) + self.dictionary = dictionary + self.dur_dictionary = dur_dictionary + self.f0_dictionary = f0_dictionary + self.config = config + + # duration config + self.discrete_dur = discrete_dur + + # pitch config + self.discrete_f0 = discrete_f0 + self.log_f0 = log_f0 + self.normalize_f0_mean = normalize_f0_mean + self.normalize_f0_std = normalize_f0_std + self.interpolate_f0 = interpolate_f0 + + self.return_filename = return_filename + self.strip_filename = strip_filename + self.f0_code_ratio = config.code_hop_size / ( + config.sampling_rate * F0_FRAME_SPACE + ) + + # use lazy loading to avoid sharing file handlers across workers + self.manifest = manifest + self._codes = None + self._durs = None + self._f0s = None + with open(f"{manifest}.leng.txt", "r") as f: + lengs = [int(line.rstrip()) for line in f] + edges = np.cumsum([0] + lengs) + self.starts, self.ends = edges[:-1], edges[1:] + with open(f"{manifest}.path.txt", "r") as f: + self.file_names = [line.rstrip() for line in f] + logger.info(f"num entries: {len(self.starts)}") + + if os.path.exists(f"{manifest}.f0_stat.pt"): + self.f0_stats = torch.load(f"{manifest}.f0_stat.pt") + elif config.f0_stats: + self.f0_stats = torch.load(config.f0_stats) + + self.multispkr = config.multispkr + if config.multispkr: + with open(f"{manifest}.speaker.txt", "r") as f: + self.spkrs = [line.rstrip() for line in f] + self.id_to_spkr = sorted(self.spkrs) + self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)} + + self.pads = Paddings( + dictionary.pad(), + 0, # use 0 for duration padding + f0_dictionary.pad() if discrete_f0 else -5.0, + ) + self.shifts = Shifts(shifts, pads=self.pads) + self.return_continuous_f0 = return_continuous_f0 + + def get_data_handlers(self): + logging.info(f"loading data for {self.manifest}") + self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r") + self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r") + + if self.discrete_f0: + if self.config.f0_vq_type == "precomp": + self._f0s = np.load( + f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r" + ) + elif self.config.f0_vq_type == "naive": + self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r") + quantizers_path = self.config.get_f0_vq_naive_quantizer( + self.log_f0, self.normalize_f0_mean, self.normalize_f0_std + ) + quantizers = torch.load(quantizers_path) + n_units = self.config.f0_vq_n_units + self._f0_quantizer = torch.from_numpy(quantizers[n_units]) + else: + raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported") + else: + self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r") + + def preprocess_f0(self, f0, stats): + """ + 1. interpolate + 2. log transform (keep unvoiced frame 0) + """ + # TODO: change this to be dependent on config for naive quantizer + f0 = f0.clone() + if self.interpolate_f0: + f0 = interpolate_f0(f0) + + mask = f0 != 0 # only process voiced frames + if self.log_f0: + f0[mask] = f0[mask].log() + if self.normalize_f0_mean: + mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"] + f0[mask] = f0[mask] - mean + if self.normalize_f0_std: + std = stats["logf0_std"] if self.log_f0 else stats["f0_std"] + f0[mask] = f0[mask] / std + return f0 + + def _get_raw_item(self, index): + start, end = self.starts[index], self.ends[index] + if self._codes is None: + self.get_data_handlers() + code = torch.from_numpy(np.array(self._codes[start:end])).long() + dur = torch.from_numpy(np.array(self._durs[start:end])) + f0 = torch.from_numpy(np.array(self._f0s[start:end])) + return code, dur, f0 + + def __getitem__(self, index): + code, dur, f0 = self._get_raw_item(index) + code = torch.cat([code.new([self.dictionary.bos()]), code]) + + # use 0 for eos and bos + dur = torch.cat([dur.new([0]), dur]) + if self.discrete_dur: + dur = self.dur_dictionary.encode_line( + " ".join(map(str, dur.tolist())), append_eos=False + ).long() + else: + dur = dur.float() + + # TODO: find a more elegant approach + raw_f0 = None + if self.discrete_f0: + if self.config.f0_vq_type == "precomp": + f0 = self.f0_dictionary.encode_line( + " ".join(map(str, f0.tolist())), append_eos=False + ).long() + else: + f0 = f0.float() + f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]]) + if self.return_continuous_f0: + raw_f0 = f0 + raw_f0 = torch.cat([raw_f0.new([self.f0_dictionary.bos()]), raw_f0]) + f0 = naive_quantize(f0, self._f0_quantizer) + f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0]) + else: + f0 = f0.float() + if self.multispkr: + f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]]) + else: + f0 = self.preprocess_f0(f0, self.f0_stats) + f0 = torch.cat([f0.new([0]), f0]) + + if raw_f0 is not None: + *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0) + else: + raw_f0_mask = None + + code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0) + if raw_f0_mask is not None: + assert (raw_f0_mask == f0_mask).all() + + # is a padded frame if either input or output is padded + feats = { + "source": code[:-1], + "target": code[1:], + "mask": code_mask[1:].logical_or(code_mask[:-1]), + "dur_source": dur[:-1], + "dur_target": dur[1:], + "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]), + "f0_source": f0[:-1], + "f0_target": f0[1:], + "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]), + } + + if raw_f0 is not None: + feats["raw_f0"] = raw_f0[1:] + + if self.return_filename: + fname = self.file_names[index] + feats["filename"] = ( + fname if not self.strip_filename else Path(fname).with_suffix("").name + ) + return feats + + def __len__(self): + return len(self.starts) + + def size(self, index): + return self.ends[index] - self.starts[index] + self.shifts.extra_length + + def num_tokens(self, index): + return self.size(index) + + def collater(self, samples): + pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos() + if len(samples) == 0: + return {} + + src_tokens = data_utils.collate_tokens( + [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False + ) + + tgt_tokens = data_utils.collate_tokens( + [s["target"] for s in samples], + pad_idx=pad_idx, + eos_idx=pad_idx, # appending padding, eos is there already + left_pad=False, + ) + + src_durs, tgt_durs = [ + data_utils.collate_tokens( + [s[k] for s in samples], + pad_idx=self.pads.dur, + eos_idx=self.pads.dur, + left_pad=False, + ) + for k in ["dur_source", "dur_target"] + ] + + src_f0s, tgt_f0s = [ + data_utils.collate_tokens( + [s[k] for s in samples], + pad_idx=self.pads.f0, + eos_idx=self.pads.f0, + left_pad=False, + ) + for k in ["f0_source", "f0_target"] + ] + + mask, dur_mask, f0_mask = [ + data_utils.collate_tokens( + [s[k] for s in samples], + pad_idx=1, + eos_idx=1, + left_pad=False, + ) + for k in ["mask", "dur_mask", "f0_mask"] + ] + + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + n_tokens = sum(len(s["source"]) for s in samples) + + result = { + "nsentences": len(samples), + "ntokens": n_tokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "dur_src": src_durs, + "f0_src": src_f0s, + }, + "target": tgt_tokens, + "dur_target": tgt_durs, + "f0_target": tgt_f0s, + "mask": mask, + "dur_mask": dur_mask, + "f0_mask": f0_mask, + } + + if "filename" in samples[0]: + result["filename"] = [s["filename"] for s in samples] + + # TODO: remove this hack into the inference dataset + if "prefix" in samples[0]: + result["prefix"] = [s["prefix"] for s in samples] + + if "raw_f0" in samples[0]: + raw_f0s = data_utils.collate_tokens( + [s["raw_f0"] for s in samples], + pad_idx=self.pads.f0, + eos_idx=self.pads.f0, + left_pad=False, + ) + result["raw_f0"] = raw_f0s + return result diff --git a/fairseq/data/colorize_dataset.py b/fairseq/data/colorize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6d2713791b1e80a6f5b982a4bf4ba93f6f561e --- /dev/null +++ b/fairseq/data/colorize_dataset.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset + + +class ColorizeDataset(BaseWrapperDataset): + """Adds 'colors' property to net input that is obtained from the provided color getter for use by models""" + + def __init__(self, dataset, color_getter): + super().__init__(dataset) + self.color_getter = color_getter + + def collater(self, samples): + base_collate = super().collater(samples) + if len(base_collate) > 0: + base_collate["net_input"]["colors"] = torch.tensor( + list(self.color_getter(self.dataset, s["id"]) for s in samples), + dtype=torch.long, + ) + return base_collate diff --git a/fairseq/data/concat_dataset.py b/fairseq/data/concat_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4078bb159fa44b2d1062b9a971fe7f1abd1c2 --- /dev/null +++ b/fairseq/data/concat_dataset.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect + +import numpy as np +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +class ConcatDataset(FairseqDataset): + @staticmethod + def cumsum(sequence, sample_ratios): + r, s = [], 0 + for e, ratio in zip(sequence, sample_ratios): + curr_len = int(ratio * len(e)) + r.append(curr_len + s) + s += curr_len + return r + + def __init__(self, datasets, sample_ratios=1): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, "datasets should not be an empty iterable" + self.datasets = list(datasets) + if isinstance(sample_ratios, int): + sample_ratios = [sample_ratios] * len(self.datasets) + self.sample_ratios = sample_ratios + self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios) + self.real_sizes = [len(d) for d in self.datasets] + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx][sample_idx] + + def _get_dataset_and_sample_index(self, idx: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + sample_idx = sample_idx % self.real_sizes[dataset_idx] + return dataset_idx, sample_idx + + def collater(self, samples, **extra_args): + # For now only supports datasets with same underlying collater implementations + if hasattr(self.datasets[0], "collater"): + return self.datasets[0].collater(samples, **extra_args) + else: + return default_collate(samples, **extra_args) + + def size(self, idx: int): + """ + Return an example's size as a float or tuple. + """ + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return self.datasets[dataset_idx].size(sample_idx) + + def num_tokens(self, index: int): + return np.max(self.size(index)) + + def attr(self, attr: str, index: int): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, index) + return getattr(self.datasets[dataset_idx], attr, None) + + @property + def sizes(self): + _dataset_sizes = [] + for ds, sr in zip(self.datasets, self.sample_ratios): + if isinstance(ds.sizes, np.ndarray): + _dataset_sizes.append(np.tile(ds.sizes, sr)) + else: + # Only support underlying dataset with single size array. + assert isinstance(ds.sizes, list) + _dataset_sizes.append(np.tile(ds.sizes[0], sr)) + return np.concatenate(_dataset_sizes) + + @property + def supports_prefetch(self): + return all(d.supports_prefetch for d in self.datasets) + + def ordered_indices(self): + """ + Returns indices sorted by length. So less padding is needed. + """ + if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1: + # special handling for concatenating lang_pair_datasets + indices = np.arange(len(self)) + sizes = self.sizes + tgt_sizes = ( + sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + ) + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(src_sizes[indices], kind="mergesort")] + else: + return np.argsort(self.sizes) + + def prefetch(self, indices): + frm = 0 + for to, ds in zip(self.cumulative_sizes, self.datasets): + real_size = len(ds) + if getattr(ds, "supports_prefetch", False): + ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to]) + frm = to + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/data/concat_sentences_dataset.py b/fairseq/data/concat_sentences_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..625a29370e90f9d1d7274024afb902ed83a22325 --- /dev/null +++ b/fairseq/data/concat_sentences_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class ConcatSentencesDataset(FairseqDataset): + def __init__(self, *datasets): + super().__init__() + self.datasets = datasets + assert all( + len(ds) == len(datasets[0]) for ds in datasets + ), "datasets must have the same length" + + def __getitem__(self, index): + return torch.cat([ds[index] for ds in self.datasets]) + + def __len__(self): + return len(self.datasets[0]) + + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def sizes(self): + return sum(ds.sizes for ds in self.datasets) + + def num_tokens(self, index): + return sum(ds.num_tokens(index) for ds in self.datasets) + + def size(self, index): + return sum(ds.size(index) for ds in self.datasets) + + def ordered_indices(self): + return self.datasets[0].ordered_indices() + + @property + def supports_prefetch(self): + return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets) + + def prefetch(self, indices): + for ds in self.datasets: + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.datasets: + if hasattr(ds, "set_epoch"): + ds.set_epoch(epoch) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a19cc3c1827387a6fba571dfdbbbddfcce38eeb --- /dev/null +++ b/fairseq/data/data_utils.py @@ -0,0 +1,1144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable +import contextlib +import itertools +import logging +import re +import warnings +from typing import Optional, Tuple + +import math +import numpy as np +import torch + +from fairseq.file_io import PathManager +from fairseq import utils +import os + +logger = logging.getLogger(__name__) + + +def infer_language_pair(path): + """Infer language pair from filename: .-.(...).idx""" + src, dst = None, None + for filename in PathManager.ls(path): + parts = filename.split(".") + if len(parts) >= 3 and len(parts[1].split("-")) == 2: + return parts[1].split("-") + return src, dst + + +def collate_tokens( + values, + pad_idx, + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + pad_to_length=None, + pad_to_multiple=1, + pad_to_bsz=None, +): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + batch_size = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz) + res = values[0].new(batch_size, size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if move_eos_to_beginning: + if eos_idx is None: + # if no eos_idx is specified, then use the last token in src + dst[0] = src[-1] + else: + dst[0] = eos_idx + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) + return res + + +def load_indexed_dataset( + path, dictionary=None, dataset_impl=None, combine=False, default="cached" +): + """A helper function for loading indexed datasets. + + Args: + path (str): path to indexed dataset (e.g., 'data-bin/train') + dictionary (~fairseq.data.Dictionary): data dictionary + dataset_impl (str, optional): which dataset implementation to use. If + not provided, it will be inferred automatically. For legacy indexed + data we use the 'cached' implementation by default. + combine (bool, optional): automatically load and combine multiple + datasets. For example, if *path* is 'data-bin/train', then we will + combine 'data-bin/train', 'data-bin/train1', ... and return a + single ConcatDataset instance. + """ + import fairseq.data.indexed_dataset as indexed_dataset + from fairseq.data.concat_dataset import ConcatDataset + + datasets = [] + for k in itertools.count(): + path_k = path + (str(k) if k > 0 else "") + try: + path_k = indexed_dataset.get_indexed_dataset_to_local(path_k) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"path_k: {e} not found") + else: + raise e + + dataset_impl_k = dataset_impl + if dataset_impl_k is None: + dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k) + dataset = indexed_dataset.make_dataset( + path_k, + impl=dataset_impl_k or default, + fix_lua_indexing=True, + dictionary=dictionary, + ) + if dataset is None: + break + logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k)) + datasets.append(dataset) + if not combine: + break + if len(datasets) == 0: + return None + elif len(datasets) == 1: + return datasets[0] + else: + return ConcatDataset(datasets) + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + if len(addl_seeds) > 0: + seed = int(hash((seed, *addl_seeds)) % 1e6) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) + + +def collect_filtered(function, iterable, filtered): + """ + Similar to :func:`filter` but collects filtered elements in ``filtered``. + + Args: + function (callable): function that returns ``False`` for elements that + should be filtered + iterable (iterable): iterable to filter + filtered (list): list to store filtered elements + """ + for el in iterable: + if function(el): + yield el + else: + filtered.append(el) + + +def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False): + def compare_leq(a, b): + return a <= b if not isinstance(a, tuple) else max(a) <= b + + def check_size(idx): + if isinstance(max_positions, float) or isinstance(max_positions, int): + return size_fn(idx) <= max_positions + elif isinstance(max_positions, dict): + idx_size = size_fn(idx) + assert isinstance(idx_size, dict) + intersect_keys = set(max_positions.keys()) & set(idx_size.keys()) + return all( + all( + a is None or b is None or a <= b + for a, b in zip(idx_size[key], max_positions[key]) + ) + for key in intersect_keys + ) + else: + # For MultiCorpusSampledDataset, will generalize it later + if not isinstance(size_fn(idx), Iterable): + return all(size_fn(idx) <= b for b in max_positions) + return all( + a is None or b is None or a <= b + for a, b in zip(size_fn(idx), max_positions) + ) + + ignored = [] + itr = collect_filtered(check_size, indices, ignored) + indices = np.fromiter(itr, dtype=np.int64, count=-1) + return indices, ignored + + +def filter_by_size(indices, dataset, max_positions, raise_exception=False): + """ + [deprecated] Filter indices based on their size. + Use `FairseqDataset::filter_indices_by_size` instead. + + Args: + indices (List[int]): ordered list of dataset indices + dataset (FairseqDataset): fairseq dataset instance + max_positions (tuple): filter elements larger than this size. + Comparisons are done component-wise. + raise_exception (bool, optional): if ``True``, raise an exception if + any elements are filtered (default: False). + """ + warnings.warn( + "data_utils.filter_by_size is deprecated. " + "Use `FairseqDataset::filter_indices_by_size` instead.", + stacklevel=2, + ) + if isinstance(max_positions, float) or isinstance(max_positions, int): + if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray): + ignored = indices[dataset.sizes[indices] > max_positions].tolist() + indices = indices[dataset.sizes[indices] <= max_positions] + elif ( + hasattr(dataset, "sizes") + and isinstance(dataset.sizes, list) + and len(dataset.sizes) == 1 + ): + ignored = indices[dataset.sizes[0][indices] > max_positions].tolist() + indices = indices[dataset.sizes[0][indices] <= max_positions] + else: + indices, ignored = _filter_by_size_dynamic( + indices, dataset.size, max_positions + ) + else: + indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions) + + if len(ignored) > 0 and raise_exception: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + if len(ignored) > 0: + logger.warning( + ( + "{} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + +def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if max_sizes is None: + return indices, [] + if type(max_sizes) in (int, float): + max_src_size, max_tgt_size = max_sizes, max_sizes + else: + max_src_size, max_tgt_size = max_sizes + if tgt_sizes is None: + ignored = indices[src_sizes[indices] > max_src_size] + else: + ignored = indices[ + (src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size) + ] + if len(ignored) > 0: + if tgt_sizes is None: + indices = indices[src_sizes[indices] <= max_src_size] + else: + indices = indices[ + (src_sizes[indices] <= max_src_size) + & (tgt_sizes[indices] <= max_tgt_size) + ] + return indices, ignored.tolist() + + +def batch_by_size( + indices, + num_tokens_fn, + num_tokens_vec=None, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + fixed_shapes=None, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + num_tokens_vec (List[int], optional): precomputed vector of the number + of tokens for each index in indices (to enable faster batch generation) + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be less than N or a multiple of N (default: 1). + fixed_shapes (List[Tuple[int, int]], optional): if given, batches will + only be created with the given shapes. *max_sentences* and + *required_batch_size_multiple* will be ignored (default: None). + """ + try: + from fairseq.data.data_utils_fast import ( + batch_by_size_fn, + batch_by_size_vec, + batch_fixed_shapes_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: " + "`python setup.py build_ext --inplace`" + ) + except ValueError: + raise ValueError( + "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`." + ) + + # added int() to avoid TypeError: an integer is required + max_tokens = int(max_tokens) if max_tokens is not None else -1 + max_sentences = max_sentences if max_sentences is not None else -1 + bsz_mult = required_batch_size_multiple + + if not isinstance(indices, np.ndarray): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray): + num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1) + + if fixed_shapes is None: + if num_tokens_vec is None: + b = batch_by_size_fn( + indices, + num_tokens_fn, + max_tokens, + max_sentences, + bsz_mult, + ) + else: + b = batch_by_size_vec( + indices, + num_tokens_vec, + max_tokens, + max_sentences, + bsz_mult, + ) + + if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0: + b = b[:-1] + + return b + + else: + fixed_shapes = np.array(fixed_shapes, dtype=np.int64) + sort_order = np.lexsort( + [ + fixed_shapes[:, 1].argsort(), # length + fixed_shapes[:, 0].argsort(), # bsz + ] + ) + fixed_shapes_sorted = fixed_shapes[sort_order] + return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted) + + +def post_process(sentence: str, symbol: str): + if symbol == "sentencepiece": + sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() + elif symbol == "wordpiece": + sentence = sentence.replace(" ", "").replace("_", " ").strip() + elif symbol == "letter": + sentence = sentence.replace(" ", "").replace("|", " ").strip() + elif symbol == "silence": + import re + + sentence = sentence.replace("", "") + sentence = re.sub(" +", " ", sentence).strip() + elif symbol == "_EOW": + sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() + elif symbol in {"subword_nmt", "@@ ", "@@"}: + if symbol == "subword_nmt": + symbol = "@@ " + sentence = (sentence + " ").replace(symbol, "").rstrip() + elif symbol == "none": + pass + elif symbol is not None: + raise NotImplementedError(f"Unknown post_process option: {symbol}") + return sentence + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def compute_block_mask_2d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + assert mask_length > 1 + + B, L = shape + + d = int(L**0.5) + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(d / mask_length) + inp_len = sz * sz + + inp = torch.zeros((B, 1, sz, sz)) + w = torch.ones((1, 1, mask_length, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > d: + mask = mask[..., :d, :d] + else: + mask = torch.zeros((B, d, d)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length**2) + * (1 + mask_dropout) + ), + ), + ) + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], [], []) + + offset = mask_length // 2 + for i in range(mask_length): + for j in range(mask_length): + k1 = i - offset + k2 = j - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + inds[2].append(centers[2] + k2) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=d - 1) + i2 = torch.cat(inds[2]).clamp_(min=0, max=d - 1) + + mask[(i0, i1, i2)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.zeros((1, 1, 3, 3)) + w[..., 0, 1] = 1 + w[..., 2, 1] = 1 + w[..., 1, 0] = 1 + w[..., 1, 2] = 1 + + all_nbs = get_nbs(B, mask, w) + + mask = mask.reshape(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.view(1, d, d), w).flatten() + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def compute_block_mask_1d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + B, L = shape + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(L / mask_length) + + inp = torch.zeros((B, 1, sz)) + w = torch.ones((1, 1, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > L: + mask = mask[..., :L] + + else: + mask = torch.zeros((B, L)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length) + * (1 + mask_dropout) + ), + ), + ) + + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], []) + + offset = mask_length // 2 + for i in range(mask_length): + k1 = i - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1) + + mask[(i0, i1)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.ones((1, 1, 3)) + w[..., 1] = 0 + all_nbs = get_nbs(B, mask, w) + + mask = mask.view(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0) + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def get_mem_usage(): + try: + import psutil + + mb = 1024 * 1024 + return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb" + except ImportError: + return "N/A" + + +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_padding_mask(lens): + bsz, max_lens = lens.size(0), torch.max(lens).item() + mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) + mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) + return mask + + +# lens: torch.LongTensor +# returns: torch.BoolTensor +def lengths_to_mask(lens): + return ~lengths_to_padding_mask(lens) + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation="lower", + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes + + +def _find_extra_valid_paths(dataset_path: str) -> set: + paths = utils.split_paths(dataset_path) + all_valid_paths = set() + for sub_dir in paths: + contents = PathManager.ls(sub_dir) + valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None] + all_valid_paths |= {os.path.basename(p) for p in valid_paths} + # Remove .bin, .idx etc + roots = {os.path.splitext(p)[0] for p in all_valid_paths} + return roots + + +def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None: + """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored.""" + if ( + train_cfg.dataset.ignore_unused_valid_subsets + or train_cfg.dataset.combine_valid_subsets + or train_cfg.dataset.disable_validation + or not hasattr(train_cfg.task, "data") + ): + return + other_paths = _find_extra_valid_paths(train_cfg.task.data) + specified_subsets = train_cfg.dataset.valid_subset.split(",") + ignored_paths = [p for p in other_paths if p not in specified_subsets] + if ignored_paths: + advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them." + msg = f"Valid paths {ignored_paths} will be ignored. {advice}" + raise ValueError(msg) + + +def compute_mask_indices_for_one( + sz, + mask_prob: float, + mask_length: int, + seed=None, + epoch=None, + index=None, + min_masks=0, +): + """ + set seed, epoch, index for deterministic masking + """ + seed = int(hash((seed, epoch, index)) % 1e6) if seed else None + rng = np.random.default_rng(seed) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = rng.choice(sz, num_mask, replace=False) + mask_idc = np.concatenate([mask_idc + i for i in range(mask_length)]) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print(f"Assigning mask indexes {mask_idc} to mask {mask} failed!") + raise + + return mask + + +def compute_mask_indices_v2( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + min_masks: int = 0, + require_same_masks: bool = True, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, +) -> np.ndarray: + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + else: + sz = all_sz + index = indices[i].item() if indices is not None else None + mask_for_one = compute_mask_indices_for_one( + sz, mask_prob, mask_length, seed, epoch, index, min_masks + ) + mask[i, :sz] = mask_for_one + + if require_same_masks: + index_sum = indices.sum().item() if indices is not None else None + seed = int(hash((seed, epoch, index_sum)) % 1e6) if seed else None + rng = np.random.default_rng(seed) + + num_mask = mask.sum(-1).min() + for i in range(bsz): + extra = mask[i].sum() - num_mask + if extra > 0: + to_unmask = rng.choice(np.nonzero(mask[i])[0], extra, replace=False) + mask[i, to_unmask] = False + + return mask + + +# TODO: a copy of the original compute_mask_indices +def compute_mask_indices_v3( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len and require_same_masks: + mask_idc = rng.choice(mask_idc, min_len, replace=False) + if mask_dropout > 0: + num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) + mask_idc = rng.choice(mask_idc, len(mask_idc) - num_holes, replace=False) + + mask[i, mask_idc] = True + + return mask diff --git a/fairseq/data/data_utils_fast.cpp b/fairseq/data/data_utils_fast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b6e9c7bafac31bf5efc24020a2535c5f2fa60231 --- /dev/null +++ b/fairseq/data/data_utils_fast.cpp @@ -0,0 +1,32427 @@ +/* Generated by Cython 3.0.8 */ + +/* BEGIN: Cython Metadata +{ + "distutils": { + "depends": [ + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/arrayobject.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/arrayscalars.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ndarrayobject.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ndarraytypes.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ufuncobject.h" + ], + "extra_compile_args": [ + "-std=c++11", + "-O3", + "-DTORCH_API_INCLUDE_EXTENSION_H", + "-DPYBIND11_COMPILER_TYPE=\"_gcc\"", + "-DPYBIND11_STDLIB=\"_libstdcpp\"", + "-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"", + "-DTORCH_EXTENSION_NAME=data_utils_fast", + "-D_GLIBCXX_USE_CXX11_ABI=0" + ], + "include_dirs": [ + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include" + ], + "language": "c++", + "name": "fairseq.data.data_utils_fast", + "sources": [ + "fairseq/data/data_utils_fast.pyx" + ] + }, + "module_name": "fairseq.data.data_utils_fast" +} +END: Cython Metadata */ + +#ifndef PY_SSIZE_T_CLEAN +#define PY_SSIZE_T_CLEAN +#endif /* PY_SSIZE_T_CLEAN */ +#if defined(CYTHON_LIMITED_API) && 0 + #ifndef Py_LIMITED_API + #if CYTHON_LIMITED_API+0 > 0x03030000 + #define Py_LIMITED_API CYTHON_LIMITED_API + #else + #define Py_LIMITED_API 0x03030000 + #endif + #endif +#endif + +#include "Python.h" +#ifndef Py_PYTHON_H + #error Python headers needed to compile C extensions, please install development version of Python. +#elif PY_VERSION_HEX < 0x02070000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) + #error Cython requires Python 2.7+ or Python 3.3+. +#else +#if defined(CYTHON_LIMITED_API) && CYTHON_LIMITED_API +#define __PYX_EXTRA_ABI_MODULE_NAME "limited" +#else +#define __PYX_EXTRA_ABI_MODULE_NAME "" +#endif +#define CYTHON_ABI "3_0_8" __PYX_EXTRA_ABI_MODULE_NAME +#define __PYX_ABI_MODULE_NAME "_cython_" CYTHON_ABI +#define __PYX_TYPE_MODULE_PREFIX __PYX_ABI_MODULE_NAME "." +#define CYTHON_HEX_VERSION 0x030008F0 +#define CYTHON_FUTURE_DIVISION 1 +#include +#ifndef offsetof + #define offsetof(type, member) ( (size_t) & ((type*)0) -> member ) +#endif +#if !defined(_WIN32) && !defined(WIN32) && !defined(MS_WINDOWS) + #ifndef __stdcall + #define __stdcall + #endif + #ifndef __cdecl + #define __cdecl + #endif + #ifndef __fastcall + #define __fastcall + #endif +#endif +#ifndef DL_IMPORT + #define DL_IMPORT(t) t +#endif +#ifndef DL_EXPORT + #define DL_EXPORT(t) t +#endif +#define __PYX_COMMA , +#ifndef HAVE_LONG_LONG + #define HAVE_LONG_LONG +#endif +#ifndef PY_LONG_LONG + #define PY_LONG_LONG LONG_LONG +#endif +#ifndef Py_HUGE_VAL + #define Py_HUGE_VAL HUGE_VAL +#endif +#define __PYX_LIMITED_VERSION_HEX PY_VERSION_HEX +#if defined(GRAALVM_PYTHON) + /* For very preliminary testing purposes. Most variables are set the same as PyPy. + The existence of this section does not imply that anything works or is even tested */ + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 1 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(PYPY_VERSION) + #define CYTHON_COMPILING_IN_PYPY 1 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #if PY_VERSION_HEX < 0x03090000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1 && PYPY_VERSION_NUM >= 0x07030C00) + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(CYTHON_LIMITED_API) + #ifdef Py_LIMITED_API + #undef __PYX_LIMITED_VERSION_HEX + #define __PYX_LIMITED_VERSION_HEX Py_LIMITED_API + #endif + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 1 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_CLINE_IN_TRACEBACK + #define CYTHON_CLINE_IN_TRACEBACK 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 1 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #endif + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 1 + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(Py_GIL_DISABLED) || defined(Py_NOGIL) + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 1 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #ifndef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 +#else + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 1 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #ifndef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 1 + #endif + #if PY_MAJOR_VERSION < 3 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #ifndef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 1 + #endif + #ifndef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 1 + #endif + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #elif !defined(CYTHON_USE_UNICODE_WRITER) + #define CYTHON_USE_UNICODE_WRITER 1 + #endif + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #ifndef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 1 + #endif + #ifndef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL (PY_MAJOR_VERSION < 3 || PY_VERSION_HEX >= 0x03060000 && PY_VERSION_HEX < 0x030C00A6) + #endif + #ifndef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL (PY_VERSION_HEX >= 0x030700A1) + #endif + #ifndef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 1 + #endif + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #endif + #if PY_VERSION_HEX < 0x030400a1 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #elif !defined(CYTHON_USE_TP_FINALIZE) + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #if PY_VERSION_HEX < 0x030600B1 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #elif !defined(CYTHON_USE_DICT_VERSIONS) + #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX < 0x030C00A5) + #endif + #if PY_VERSION_HEX < 0x030700A3 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #elif !defined(CYTHON_USE_EXC_INFO_STACK) + #define CYTHON_USE_EXC_INFO_STACK 1 + #endif + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 + #endif +#endif +#if !defined(CYTHON_FAST_PYCCALL) +#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) +#endif +#if !defined(CYTHON_VECTORCALL) +#define CYTHON_VECTORCALL (CYTHON_FAST_PYCCALL && PY_VERSION_HEX >= 0x030800B1) +#endif +#define CYTHON_BACKPORT_VECTORCALL (CYTHON_METH_FASTCALL && PY_VERSION_HEX < 0x030800B1) +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_MAJOR_VERSION < 3 + #include "longintrepr.h" + #endif + #undef SHIFT + #undef BASE + #undef MASK + #ifdef SIZEOF_VOID_P + enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) }; + #endif +#endif +#ifndef __has_attribute + #define __has_attribute(x) 0 +#endif +#ifndef __has_cpp_attribute + #define __has_cpp_attribute(x) 0 +#endif +#ifndef CYTHON_RESTRICT + #if defined(__GNUC__) + #define CYTHON_RESTRICT __restrict__ + #elif defined(_MSC_VER) && _MSC_VER >= 1400 + #define CYTHON_RESTRICT __restrict + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define CYTHON_RESTRICT restrict + #else + #define CYTHON_RESTRICT + #endif +#endif +#ifndef CYTHON_UNUSED + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(maybe_unused) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(maybe_unused) + #define CYTHON_UNUSED [[maybe_unused]] + #endif + #endif + #endif +#endif +#ifndef CYTHON_UNUSED +# if defined(__GNUC__) +# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_UNUSED_VAR +# if defined(__cplusplus) + template void CYTHON_UNUSED_VAR( const T& ) { } +# else +# define CYTHON_UNUSED_VAR(x) (void)(x) +# endif +#endif +#ifndef CYTHON_MAYBE_UNUSED_VAR + #define CYTHON_MAYBE_UNUSED_VAR(x) CYTHON_UNUSED_VAR(x) +#endif +#ifndef CYTHON_NCP_UNUSED +# if CYTHON_COMPILING_IN_CPYTHON +# define CYTHON_NCP_UNUSED +# else +# define CYTHON_NCP_UNUSED CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_USE_CPP_STD_MOVE + #if defined(__cplusplus) && (\ + __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1600)) + #define CYTHON_USE_CPP_STD_MOVE 1 + #else + #define CYTHON_USE_CPP_STD_MOVE 0 + #endif +#endif +#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None) +#ifdef _MSC_VER + #ifndef _MSC_STDINT_H_ + #if _MSC_VER < 1300 + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; + #else + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + #endif + #endif + #if _MSC_VER < 1300 + #ifdef _WIN64 + typedef unsigned long long __pyx_uintptr_t; + #else + typedef unsigned int __pyx_uintptr_t; + #endif + #else + #ifdef _WIN64 + typedef unsigned __int64 __pyx_uintptr_t; + #else + typedef unsigned __int32 __pyx_uintptr_t; + #endif + #endif +#else + #include + typedef uintptr_t __pyx_uintptr_t; +#endif +#ifndef CYTHON_FALLTHROUGH + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(fallthrough) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(fallthrough) + #define CYTHON_FALLTHROUGH [[fallthrough]] + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_cpp_attribute(clang::fallthrough) + #define CYTHON_FALLTHROUGH [[clang::fallthrough]] + #elif __has_cpp_attribute(gnu::fallthrough) + #define CYTHON_FALLTHROUGH [[gnu::fallthrough]] + #endif + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_attribute(fallthrough) + #define CYTHON_FALLTHROUGH __attribute__((fallthrough)) + #else + #define CYTHON_FALLTHROUGH + #endif + #endif + #if defined(__clang__) && defined(__apple_build_version__) + #if __apple_build_version__ < 7000000 + #undef CYTHON_FALLTHROUGH + #define CYTHON_FALLTHROUGH + #endif + #endif +#endif +#ifdef __cplusplus + template + struct __PYX_IS_UNSIGNED_IMPL {static const bool value = T(0) < T(-1);}; + #define __PYX_IS_UNSIGNED(type) (__PYX_IS_UNSIGNED_IMPL::value) +#else + #define __PYX_IS_UNSIGNED(type) (((type)-1) > 0) +#endif +#if CYTHON_COMPILING_IN_PYPY == 1 + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x030A0000) +#else + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000) +#endif +#define __PYX_REINTERPRET_FUNCION(func_pointer, other_pointer) ((func_pointer)(void(*)(void))(other_pointer)) + +#ifndef __cplusplus + #error "Cython files generated with the C++ option must be compiled with a C++ compiler." +#endif +#ifndef CYTHON_INLINE + #if defined(__clang__) + #define CYTHON_INLINE __inline__ __attribute__ ((__unused__)) + #else + #define CYTHON_INLINE inline + #endif +#endif +template +void __Pyx_call_destructor(T& x) { + x.~T(); +} +template +class __Pyx_FakeReference { + public: + __Pyx_FakeReference() : ptr(NULL) { } + __Pyx_FakeReference(const T& ref) : ptr(const_cast(&ref)) { } + T *operator->() { return ptr; } + T *operator&() { return ptr; } + operator T&() { return *ptr; } + template bool operator ==(const U& other) const { return *ptr == other; } + template bool operator !=(const U& other) const { return *ptr != other; } + template bool operator==(const __Pyx_FakeReference& other) const { return *ptr == *other.ptr; } + template bool operator!=(const __Pyx_FakeReference& other) const { return *ptr != *other.ptr; } + private: + T *ptr; +}; + +#define __PYX_BUILD_PY_SSIZE_T "n" +#define CYTHON_FORMAT_SSIZE_T "z" +#if PY_MAJOR_VERSION < 3 + #define __Pyx_BUILTIN_MODULE_NAME "__builtin__" + #define __Pyx_DefaultClassType PyClass_Type + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_BUILTIN_MODULE_NAME "builtins" + #define __Pyx_DefaultClassType PyType_Type +#if CYTHON_COMPILING_IN_LIMITED_API + static CYTHON_INLINE PyObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyObject *exception_table = NULL; + PyObject *types_module=NULL, *code_type=NULL, *result=NULL; + #if __PYX_LIMITED_VERSION_HEX < 0x030B0000 + PyObject *version_info; + PyObject *py_minor_version = NULL; + #endif + long minor_version = 0; + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + #if __PYX_LIMITED_VERSION_HEX >= 0x030B0000 + minor_version = 11; + #else + if (!(version_info = PySys_GetObject("version_info"))) goto end; + if (!(py_minor_version = PySequence_GetItem(version_info, 1))) goto end; + minor_version = PyLong_AsLong(py_minor_version); + Py_DECREF(py_minor_version); + if (minor_version == -1 && PyErr_Occurred()) goto end; + #endif + if (!(types_module = PyImport_ImportModule("types"))) goto end; + if (!(code_type = PyObject_GetAttrString(types_module, "CodeType"))) goto end; + if (minor_version <= 7) { + (void)p; + result = PyObject_CallFunction(code_type, "iiiiiOOOOOOiOO", a, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else if (minor_version <= 10) { + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else { + if (!(exception_table = PyBytes_FromStringAndSize(NULL, 0))) goto end; + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, name, fline, lnos, exception_table, fv, cell); + } + end: + Py_XDECREF(code_type); + Py_XDECREF(exception_table); + Py_XDECREF(types_module); + if (type) { + PyErr_Restore(type, value, traceback); + } + return result; + } + #ifndef CO_OPTIMIZED + #define CO_OPTIMIZED 0x0001 + #endif + #ifndef CO_NEWLOCALS + #define CO_NEWLOCALS 0x0002 + #endif + #ifndef CO_VARARGS + #define CO_VARARGS 0x0004 + #endif + #ifndef CO_VARKEYWORDS + #define CO_VARKEYWORDS 0x0008 + #endif + #ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x0200 + #endif + #ifndef CO_GENERATOR + #define CO_GENERATOR 0x0020 + #endif + #ifndef CO_COROUTINE + #define CO_COROUTINE 0x0080 + #endif +#elif PY_VERSION_HEX >= 0x030B0000 + static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyCodeObject *result; + PyObject *empty_bytes = PyBytes_FromStringAndSize("", 0); + if (!empty_bytes) return NULL; + result = + #if PY_VERSION_HEX >= 0x030C0000 + PyUnstable_Code_NewWithPosOnlyArgs + #else + PyCode_NewWithPosOnlyArgs + #endif + (a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, name, fline, lnos, empty_bytes); + Py_DECREF(empty_bytes); + return result; + } +#elif PY_VERSION_HEX >= 0x030800B2 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_NewWithPosOnlyArgs(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#endif +#endif +#if PY_VERSION_HEX >= 0x030900A4 || defined(Py_IS_TYPE) + #define __Pyx_IS_TYPE(ob, type) Py_IS_TYPE(ob, type) +#else + #define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is) + #define __Pyx_Py_Is(x, y) Py_Is(x, y) +#else + #define __Pyx_Py_Is(x, y) ((x) == (y)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone) + #define __Pyx_Py_IsNone(ob) Py_IsNone(ob) +#else + #define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue) + #define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob) +#else + #define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse) + #define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob) +#else + #define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False) +#endif +#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj)) +#if PY_VERSION_HEX >= 0x030900F0 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_GC_IsFinalized(o) PyObject_GC_IsFinalized(o) +#else + #define __Pyx_PyObject_GC_IsFinalized(o) _PyGC_FINALIZED(o) +#endif +#ifndef CO_COROUTINE + #define CO_COROUTINE 0x80 +#endif +#ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x200 +#endif +#ifndef Py_TPFLAGS_CHECKTYPES + #define Py_TPFLAGS_CHECKTYPES 0 +#endif +#ifndef Py_TPFLAGS_HAVE_INDEX + #define Py_TPFLAGS_HAVE_INDEX 0 +#endif +#ifndef Py_TPFLAGS_HAVE_NEWBUFFER + #define Py_TPFLAGS_HAVE_NEWBUFFER 0 +#endif +#ifndef Py_TPFLAGS_HAVE_FINALIZE + #define Py_TPFLAGS_HAVE_FINALIZE 0 +#endif +#ifndef Py_TPFLAGS_SEQUENCE + #define Py_TPFLAGS_SEQUENCE 0 +#endif +#ifndef Py_TPFLAGS_MAPPING + #define Py_TPFLAGS_MAPPING 0 +#endif +#ifndef METH_STACKLESS + #define METH_STACKLESS 0 +#endif +#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL) + #ifndef METH_FASTCALL + #define METH_FASTCALL 0x80 + #endif + typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs); + typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args, + Py_ssize_t nargs, PyObject *kwnames); +#else + #define __Pyx_PyCFunctionFast _PyCFunctionFast + #define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords +#endif +#if CYTHON_METH_FASTCALL + #define __Pyx_METH_FASTCALL METH_FASTCALL + #define __Pyx_PyCFunction_FastCall __Pyx_PyCFunctionFast + #define __Pyx_PyCFunction_FastCallWithKeywords __Pyx_PyCFunctionFastWithKeywords +#else + #define __Pyx_METH_FASTCALL METH_VARARGS + #define __Pyx_PyCFunction_FastCall PyCFunction + #define __Pyx_PyCFunction_FastCallWithKeywords PyCFunctionWithKeywords +#endif +#if CYTHON_VECTORCALL + #define __pyx_vectorcallfunc vectorcallfunc + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET PY_VECTORCALL_ARGUMENTS_OFFSET + #define __Pyx_PyVectorcall_NARGS(n) PyVectorcall_NARGS((size_t)(n)) +#elif CYTHON_BACKPORT_VECTORCALL + typedef PyObject *(*__pyx_vectorcallfunc)(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames); + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET ((size_t)1 << (8 * sizeof(size_t) - 1)) + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(((size_t)(n)) & ~__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET)) +#else + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET 0 + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(n)) +#endif +#if PY_MAJOR_VERSION >= 0x030900B1 +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_CheckExact(func) +#else +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_Check(func) +#endif +#define __Pyx_CyOrPyCFunction_Check(func) PyCFunction_Check(func) +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) (((PyCFunctionObject*)(func))->m_ml->ml_meth) +#elif !CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) PyCFunction_GET_FUNCTION(func) +#endif +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FLAGS(func) (((PyCFunctionObject*)(func))->m_ml->ml_flags) +static CYTHON_INLINE PyObject* __Pyx_CyOrPyCFunction_GET_SELF(PyObject *func) { + return (__Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_STATIC) ? NULL : ((PyCFunctionObject*)func)->m_self; +} +#endif +static CYTHON_INLINE int __Pyx__IsSameCFunction(PyObject *func, void *cfunc) { +#if CYTHON_COMPILING_IN_LIMITED_API + return PyCFunction_Check(func) && PyCFunction_GetFunction(func) == (PyCFunction) cfunc; +#else + return PyCFunction_Check(func) && PyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +#endif +} +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCFunction(func, cfunc) +#if __PYX_LIMITED_VERSION_HEX < 0x030900B1 + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) ((void)m, PyType_FromSpecWithBases(s, b)) + typedef PyObject *(*__Pyx_PyCMethod)(PyObject *, PyTypeObject *, PyObject *const *, size_t, PyObject *); +#else + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) PyType_FromModuleAndSpec(m, s, b) + #define __Pyx_PyCMethod PyCMethod +#endif +#ifndef METH_METHOD + #define METH_METHOD 0x200 +#endif +#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc) + #define PyObject_Malloc(s) PyMem_Malloc(s) + #define PyObject_Free(p) PyMem_Free(p) + #define PyObject_Realloc(p) PyMem_Realloc(p) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) +#else + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyThreadState_Current PyThreadState_Get() +#elif !CYTHON_FAST_THREAD_STATE + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#elif PY_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyThreadState_Current PyThreadState_GetUnchecked() +#elif PY_VERSION_HEX >= 0x03060000 + #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet() +#elif PY_VERSION_HEX >= 0x03000000 + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#else + #define __Pyx_PyThreadState_Current _PyThreadState_Current +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE void *__Pyx_PyModule_GetState(PyObject *op) +{ + void *result; + result = PyModule_GetState(op); + if (!result) + Py_FatalError("Couldn't find the module state"); + return result; +} +#endif +#define __Pyx_PyObject_GetSlot(obj, name, func_ctype) __Pyx_PyType_GetSlot(Py_TYPE(obj), name, func_ctype) +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((func_ctype) PyType_GetSlot((type), Py_##name)) +#else + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((type)->name) +#endif +#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT) +#include "pythread.h" +#define Py_tss_NEEDS_INIT 0 +typedef int Py_tss_t; +static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) { + *key = PyThread_create_key(); + return 0; +} +static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) { + Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t)); + *key = Py_tss_NEEDS_INIT; + return key; +} +static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) { + PyObject_Free(key); +} +static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) { + return *key != Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) { + PyThread_delete_key(*key); + *key = Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) { + return PyThread_set_key_value(*key, value); +} +static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) { + return PyThread_get_key_value(*key); +} +#endif +#if PY_MAJOR_VERSION < 3 + #if CYTHON_COMPILING_IN_PYPY + #if PYPY_VERSION_NUM < 0x07030600 + #if defined(__cplusplus) && __cplusplus >= 201402L + [[deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")]] + #elif defined(__GNUC__) || defined(__clang__) + __attribute__ ((__deprecated__("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6"))) + #elif defined(_MSC_VER) + __declspec(deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")) + #endif + static CYTHON_INLINE int PyGILState_Check(void) { + return 0; + } + #else // PYPY_VERSION_NUM < 0x07030600 + #endif // PYPY_VERSION_NUM < 0x07030600 + #else + static CYTHON_INLINE int PyGILState_Check(void) { + PyThreadState * tstate = _PyThreadState_Current; + return tstate && (tstate == PyGILState_GetThisThreadState()); + } + #endif +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030d0000 || defined(_PyDict_NewPresized) +#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n)) +#else +#define __Pyx_PyDict_NewPresized(n) PyDict_New() +#endif +#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION + #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y) +#else + #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX > 0x030600B4 && PY_VERSION_HEX < 0x030d0000 && CYTHON_USE_UNICODE_INTERNALS +#define __Pyx_PyDict_GetItemStrWithError(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash) +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStr(PyObject *dict, PyObject *name) { + PyObject *res = __Pyx_PyDict_GetItemStrWithError(dict, name); + if (res == NULL) PyErr_Clear(); + return res; +} +#elif PY_MAJOR_VERSION >= 3 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07020000) +#define __Pyx_PyDict_GetItemStrWithError PyDict_GetItemWithError +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#else +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStrWithError(PyObject *dict, PyObject *name) { +#if CYTHON_COMPILING_IN_PYPY + return PyDict_GetItem(dict, name); +#else + PyDictEntry *ep; + PyDictObject *mp = (PyDictObject*) dict; + long hash = ((PyStringObject *) name)->ob_shash; + assert(hash != -1); + ep = (mp->ma_lookup)(mp, name, hash); + if (ep == NULL) { + return NULL; + } + return ep->me_value; +#endif +} +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#endif +#if CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyType_GetFlags(tp) (((PyTypeObject *)tp)->tp_flags) + #define __Pyx_PyType_HasFeature(type, feature) ((__Pyx_PyType_GetFlags(type) & (feature)) != 0) + #define __Pyx_PyObject_GetIterNextFunc(obj) (Py_TYPE(obj)->tp_iternext) +#else + #define __Pyx_PyType_GetFlags(tp) (PyType_GetFlags((PyTypeObject *)tp)) + #define __Pyx_PyType_HasFeature(type, feature) PyType_HasFeature(type, feature) + #define __Pyx_PyObject_GetIterNextFunc(obj) PyIter_Next +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyObject_GenericSetAttr((PyObject*)tp, k, v) +#else + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyDict_SetItem(tp->tp_dict, k, v) +#endif +#if CYTHON_USE_TYPE_SPECS && PY_VERSION_HEX >= 0x03080000 +#define __Pyx_PyHeapTypeObject_GC_Del(obj) {\ + PyTypeObject *type = Py_TYPE((PyObject*)obj);\ + assert(__Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE));\ + PyObject_GC_Del(obj);\ + Py_DECREF(type);\ +} +#else +#define __Pyx_PyHeapTypeObject_GC_Del(obj) PyObject_GC_Del(obj) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define CYTHON_PEP393_ENABLED 1 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GetLength(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_ReadChar(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((void)u, 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((void)u, (0)) + #define __Pyx_PyUnicode_DATA(u) ((void*)u) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)k, PyUnicode_ReadChar((PyObject*)(d), i)) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GetLength(u)) +#elif PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND) + #define CYTHON_PEP393_ENABLED 1 + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_READY(op) (0) + #else + #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\ + 0 : _PyUnicode_Ready((PyObject *)(op))) + #endif + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u) + #define __Pyx_PyUnicode_KIND(u) ((int)PyUnicode_KIND(u)) + #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u) + #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, (Py_UCS4) ch) + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u)) + #else + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length)) + #else + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u))) + #endif + #endif +#else + #define CYTHON_PEP393_ENABLED 0 + #define PyUnicode_1BYTE_KIND 1 + #define PyUnicode_2BYTE_KIND 2 + #define PyUnicode_4BYTE_KIND 4 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i])) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535U : 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((int)sizeof(Py_UNICODE)) + #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u)) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i])) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = (Py_UNICODE) ch) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b) +#else + #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\ + PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #if !defined(PyUnicode_DecodeUnicodeEscape) + #define PyUnicode_DecodeUnicodeEscape(s, size, errors) PyUnicode_Decode(s, size, "unicode_escape", errors) + #endif + #if !defined(PyUnicode_Contains) || (PY_MAJOR_VERSION == 2 && PYPY_VERSION_NUM < 0x07030500) + #undef PyUnicode_Contains + #define PyUnicode_Contains(u, s) PySequence_Contains(u, s) + #endif + #if !defined(PyByteArray_Check) + #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type) + #endif + #if !defined(PyObject_Format) + #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt) + #endif +#endif +#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b)) +#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b) +#else + #define __Pyx_PyString_Format(a, b) PyString_Format(a, b) +#endif +#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII) + #define PyObject_ASCII(o) PyObject_Repr(o) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBaseString_Type PyUnicode_Type + #define PyStringObject PyUnicodeObject + #define PyString_Type PyUnicode_Type + #define PyString_Check PyUnicode_Check + #define PyString_CheckExact PyUnicode_CheckExact +#ifndef PyObject_Unicode + #define PyObject_Unicode PyObject_Str +#endif +#endif +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj) + #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj) +#else + #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj)) + #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj)) +#endif +#if CYTHON_COMPILING_IN_CPYTHON + #define __Pyx_PySequence_ListKeepNew(obj)\ + (likely(PyList_CheckExact(obj) && Py_REFCNT(obj) == 1) ? __Pyx_NewRef(obj) : PySequence_List(obj)) +#else + #define __Pyx_PySequence_ListKeepNew(obj) PySequence_List(obj) +#endif +#ifndef PySet_CheckExact + #define PySet_CheckExact(obj) __Pyx_IS_TYPE(obj, &PySet_Type) +#endif +#if PY_VERSION_HEX >= 0x030900A4 + #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size) +#else + #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size) +#endif +#if CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_PySequence_ITEM(o, i) PySequence_ITEM(o, i) + #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) (PyTuple_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyList_SET_ITEM(o, i, v) (PyList_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_GET_SIZE(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_GET_SIZE(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_GET_SIZE(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_GET_SIZE(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_GET_SIZE(o) +#else + #define __Pyx_PySequence_ITEM(o, i) PySequence_GetItem(o, i) + #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) PyTuple_SetItem(o, i, v) + #define __Pyx_PyList_SET_ITEM(o, i, v) PyList_SetItem(o, i, v) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_Size(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_Size(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_Size(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_Size(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_Size(o) +#endif +#if PY_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyImport_AddModuleRef(name) PyImport_AddModuleRef(name) +#else + static CYTHON_INLINE PyObject *__Pyx_PyImport_AddModuleRef(const char *name) { + PyObject *module = PyImport_AddModule(name); + Py_XINCREF(module); + return module; + } +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyIntObject PyLongObject + #define PyInt_Type PyLong_Type + #define PyInt_Check(op) PyLong_Check(op) + #define PyInt_CheckExact(op) PyLong_CheckExact(op) + #define __Pyx_Py3Int_Check(op) PyLong_Check(op) + #define __Pyx_Py3Int_CheckExact(op) PyLong_CheckExact(op) + #define PyInt_FromString PyLong_FromString + #define PyInt_FromUnicode PyLong_FromUnicode + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyInt_FromSsize_t PyLong_FromSsize_t + #define PyInt_AsLong PyLong_AsLong + #define PyInt_AS_LONG PyLong_AS_LONG + #define PyInt_AsSsize_t PyLong_AsSsize_t + #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask + #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask + #define PyNumber_Int PyNumber_Long +#else + #define __Pyx_Py3Int_Check(op) (PyLong_Check(op) || PyInt_Check(op)) + #define __Pyx_Py3Int_CheckExact(op) (PyLong_CheckExact(op) || PyInt_CheckExact(op)) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBoolObject PyLongObject +#endif +#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY + #ifndef PyUnicode_InternFromString + #define PyUnicode_InternFromString(s) PyUnicode_FromString(s) + #endif +#endif +#if PY_VERSION_HEX < 0x030200A4 + typedef long Py_hash_t; + #define __Pyx_PyInt_FromHash_t PyInt_FromLong + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t +#else + #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t +#endif +#if CYTHON_USE_ASYNC_SLOTS + #if PY_VERSION_HEX >= 0x030500B1 + #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods + #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async) + #else + #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved)) + #endif +#else + #define __Pyx_PyType_AsAsync(obj) NULL +#endif +#ifndef __Pyx_PyAsyncMethodsStruct + typedef struct { + unaryfunc am_await; + unaryfunc am_aiter; + unaryfunc am_anext; + } __Pyx_PyAsyncMethodsStruct; +#endif + +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + #if !defined(_USE_MATH_DEFINES) + #define _USE_MATH_DEFINES + #endif +#endif +#include +#ifdef NAN +#define __PYX_NAN() ((float) NAN) +#else +static CYTHON_INLINE float __PYX_NAN() { + float value; + memset(&value, 0xFF, sizeof(value)); + return value; +} +#endif +#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL) +#define __Pyx_truncl trunc +#else +#define __Pyx_truncl truncl +#endif + +#define __PYX_MARK_ERR_POS(f_index, lineno) \ + { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; } +#define __PYX_ERR(f_index, lineno, Ln_error) \ + { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; } + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. + #endif +#else + #define __PYX_EXTERN_C extern "C++" +#endif + +#define __PYX_HAVE__fairseq__data__data_utils_fast +#define __PYX_HAVE_API__fairseq__data__data_utils_fast +/* Early includes */ +#include +#include + + /* Using NumPy API declarations from "numpy/__init__.cython-30.pxd" */ + +#include "numpy/arrayobject.h" +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/arrayscalars.h" +#include "numpy/ufuncobject.h" +#include +#include "pythread.h" +#include +#ifdef _OPENMP +#include +#endif /* _OPENMP */ + +#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS) +#define CYTHON_WITHOUT_ASSERTIONS +#endif + +typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding; + const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry; + +#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8) +#define __PYX_DEFAULT_STRING_ENCODING "" +#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString +#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#define __Pyx_uchar_cast(c) ((unsigned char)c) +#define __Pyx_long_cast(x) ((long)x) +#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\ + (sizeof(type) < sizeof(Py_ssize_t)) ||\ + (sizeof(type) > sizeof(Py_ssize_t) &&\ + likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX) &&\ + (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\ + v == (type)PY_SSIZE_T_MIN))) ||\ + (sizeof(type) == sizeof(Py_ssize_t) &&\ + (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX))) ) +static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) { + return (size_t) i < (size_t) limit; +} +#if defined (__cplusplus) && __cplusplus >= 201103L + #include + #define __Pyx_sst_abs(value) std::abs(value) +#elif SIZEOF_INT >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) abs(value) +#elif SIZEOF_LONG >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) labs(value) +#elif defined (_MSC_VER) + #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value)) +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define __Pyx_sst_abs(value) llabs(value) +#elif defined (__GNUC__) + #define __Pyx_sst_abs(value) __builtin_llabs(value) +#else + #define __Pyx_sst_abs(value) ((value<0) ? -value : value) +#endif +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s); +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*); +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length); +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char*); +#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l) +#define __Pyx_PyBytes_FromString PyBytes_FromString +#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*); +#if PY_MAJOR_VERSION < 3 + #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#else + #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize +#endif +#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyObject_AsWritableString(s) ((char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableSString(s) ((signed char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s) +#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s) +#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s) +#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s) +#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s) +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const wchar_t *u) +{ + const wchar_t *u_end = u; + while (*u_end++) ; + return (size_t)(u_end - u - 1); +} +#else +static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const Py_UNICODE *u) +{ + const Py_UNICODE *u_end = u; + while (*u_end++) ; + return (size_t)(u_end - u - 1); +} +#endif +#define __Pyx_PyUnicode_FromOrdinal(o) PyUnicode_FromOrdinal((int)o) +#define __Pyx_PyUnicode_FromUnicode(u) PyUnicode_FromUnicode(u, __Pyx_Py_UNICODE_strlen(u)) +#define __Pyx_PyUnicode_FromUnicodeAndLength PyUnicode_FromUnicode +#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode +#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj) +#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None) +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b); +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*); +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*); +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x); +#define __Pyx_PySequence_Tuple(obj)\ + (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj)) +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*); +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t); +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*); +#if CYTHON_ASSUME_SAFE_MACROS +#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x)) +#else +#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x) +#endif +#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x)) +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x)) +#else +#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x)) +#endif +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_VERSION_HEX >= 0x030C00A7 + #ifndef _PyLong_SIGN_MASK + #define _PyLong_SIGN_MASK 3 + #endif + #ifndef _PyLong_NON_SIZE_BITS + #define _PyLong_NON_SIZE_BITS 3 + #endif + #define __Pyx_PyLong_Sign(x) (((PyLongObject*)x)->long_value.lv_tag & _PyLong_SIGN_MASK) + #define __Pyx_PyLong_IsNeg(x) ((__Pyx_PyLong_Sign(x) & 2) != 0) + #define __Pyx_PyLong_IsNonNeg(x) (!__Pyx_PyLong_IsNeg(x)) + #define __Pyx_PyLong_IsZero(x) (__Pyx_PyLong_Sign(x) & 1) + #define __Pyx_PyLong_IsPos(x) (__Pyx_PyLong_Sign(x) == 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) (__Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) ((Py_ssize_t) (((PyLongObject*)x)->long_value.lv_tag >> _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_SignedDigitCount(x)\ + ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * __Pyx_PyLong_DigitCount(x)) + #if defined(PyUnstable_Long_IsCompact) && defined(PyUnstable_Long_CompactValue) + #define __Pyx_PyLong_IsCompact(x) PyUnstable_Long_IsCompact((PyLongObject*) x) + #define __Pyx_PyLong_CompactValue(x) PyUnstable_Long_CompactValue((PyLongObject*) x) + #else + #define __Pyx_PyLong_IsCompact(x) (((PyLongObject*)x)->long_value.lv_tag < (2 << _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_CompactValue(x) ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * (Py_ssize_t) __Pyx_PyLong_Digits(x)[0]) + #endif + typedef Py_ssize_t __Pyx_compact_pylong; + typedef size_t __Pyx_compact_upylong; + #else + #define __Pyx_PyLong_IsNeg(x) (Py_SIZE(x) < 0) + #define __Pyx_PyLong_IsNonNeg(x) (Py_SIZE(x) >= 0) + #define __Pyx_PyLong_IsZero(x) (Py_SIZE(x) == 0) + #define __Pyx_PyLong_IsPos(x) (Py_SIZE(x) > 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) ((Py_SIZE(x) == 0) ? 0 : __Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) __Pyx_sst_abs(Py_SIZE(x)) + #define __Pyx_PyLong_SignedDigitCount(x) Py_SIZE(x) + #define __Pyx_PyLong_IsCompact(x) (Py_SIZE(x) == 0 || Py_SIZE(x) == 1 || Py_SIZE(x) == -1) + #define __Pyx_PyLong_CompactValue(x)\ + ((Py_SIZE(x) == 0) ? (sdigit) 0 : ((Py_SIZE(x) < 0) ? -(sdigit)__Pyx_PyLong_Digits(x)[0] : (sdigit)__Pyx_PyLong_Digits(x)[0])) + typedef sdigit __Pyx_compact_pylong; + typedef digit __Pyx_compact_upylong; + #endif + #if PY_VERSION_HEX >= 0x030C00A5 + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->long_value.ob_digit) + #else + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->ob_digit) + #endif +#endif +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII +#include +static int __Pyx_sys_getdefaultencoding_not_ascii; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + PyObject* ascii_chars_u = NULL; + PyObject* ascii_chars_b = NULL; + const char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + if (strcmp(default_encoding_c, "ascii") == 0) { + __Pyx_sys_getdefaultencoding_not_ascii = 0; + } else { + char ascii_chars[128]; + int c; + for (c = 0; c < 128; c++) { + ascii_chars[c] = (char) c; + } + __Pyx_sys_getdefaultencoding_not_ascii = 1; + ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL); + if (!ascii_chars_u) goto bad; + ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL); + if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) { + PyErr_Format( + PyExc_ValueError, + "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.", + default_encoding_c); + goto bad; + } + Py_DECREF(ascii_chars_u); + Py_DECREF(ascii_chars_b); + } + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + Py_XDECREF(ascii_chars_u); + Py_XDECREF(ascii_chars_b); + return -1; +} +#endif +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3 +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL) +#else +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL) +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#include +static char* __PYX_DEFAULT_STRING_ENCODING; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1); + if (!__PYX_DEFAULT_STRING_ENCODING) goto bad; + strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c); + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + return -1; +} +#endif +#endif + + +/* Test for GCC > 2.95 */ +#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))) + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) +#else /* !__GNUC__ or GCC < 2.95 */ + #define likely(x) (x) + #define unlikely(x) (x) +#endif /* __GNUC__ */ +static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; } + +#if !CYTHON_USE_MODULE_STATE +static PyObject *__pyx_m = NULL; +#endif +static int __pyx_lineno; +static int __pyx_clineno = 0; +static const char * __pyx_cfilenm = __FILE__; +static const char *__pyx_filename; + +/* Header.proto */ +#if !defined(CYTHON_CCOMPLEX) + #if defined(__cplusplus) + #define CYTHON_CCOMPLEX 1 + #elif (defined(_Complex_I) && !defined(_MSC_VER)) || ((defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) && !defined(__STDC_NO_COMPLEX__) && !defined(_MSC_VER)) + #define CYTHON_CCOMPLEX 1 + #else + #define CYTHON_CCOMPLEX 0 + #endif +#endif +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #include + #else + #include + #endif +#endif +#if CYTHON_CCOMPLEX && !defined(__cplusplus) && defined(__sun__) && defined(__GNUC__) + #undef _Complex_I + #define _Complex_I 1.0fj +#endif + +/* #### Code section: filename_table ### */ + +static const char *__pyx_f[] = { + "fairseq/data/data_utils_fast.pyx", + "", + "__init__.cython-30.pxd", + "type.pxd", +}; +/* #### Code section: utility_code_proto_before_types ### */ +/* ForceInitThreads.proto */ +#ifndef __PYX_FORCE_INIT_THREADS + #define __PYX_FORCE_INIT_THREADS 0 +#endif + +/* NoFastGil.proto */ +#define __Pyx_PyGILState_Ensure PyGILState_Ensure +#define __Pyx_PyGILState_Release PyGILState_Release +#define __Pyx_FastGIL_Remember() +#define __Pyx_FastGIL_Forget() +#define __Pyx_FastGilFuncInit() + +/* BufferFormatStructs.proto */ +struct __Pyx_StructField_; +#define __PYX_BUF_FLAGS_PACKED_STRUCT (1 << 0) +typedef struct { + const char* name; + struct __Pyx_StructField_* fields; + size_t size; + size_t arraysize[8]; + int ndim; + char typegroup; + char is_unsigned; + int flags; +} __Pyx_TypeInfo; +typedef struct __Pyx_StructField_ { + __Pyx_TypeInfo* type; + const char* name; + size_t offset; +} __Pyx_StructField; +typedef struct { + __Pyx_StructField* field; + size_t parent_offset; +} __Pyx_BufFmt_StackElem; +typedef struct { + __Pyx_StructField root; + __Pyx_BufFmt_StackElem* head; + size_t fmt_offset; + size_t new_count, enc_count; + size_t struct_alignment; + int is_complex; + char enc_type; + char new_packmode; + char enc_packmode; + char is_valid_array; +} __Pyx_BufFmt_Context; + +/* Atomics.proto */ +#include +#ifndef CYTHON_ATOMICS + #define CYTHON_ATOMICS 1 +#endif +#define __PYX_CYTHON_ATOMICS_ENABLED() CYTHON_ATOMICS +#define __pyx_atomic_int_type int +#define __pyx_nonatomic_int_type int +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__)) + #include +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ + (defined(_MSC_VER) && _MSC_VER >= 1700))) + #include +#endif +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type atomic_int + #define __pyx_atomic_incr_aligned(value) atomic_fetch_add_explicit(value, 1, memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) atomic_fetch_sub_explicit(value, 1, memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C atomics" + #endif +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ +\ + (defined(_MSC_VER) && _MSC_VER >= 1700)) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type std::atomic_int + #define __pyx_atomic_incr_aligned(value) std::atomic_fetch_add_explicit(value, 1, std::memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) std::atomic_fetch_sub_explicit(value, 1, std::memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C++ atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C++ atomics" + #endif +#elif CYTHON_ATOMICS && (__GNUC__ >= 5 || (__GNUC__ == 4 &&\ + (__GNUC_MINOR__ > 1 ||\ + (__GNUC_MINOR__ == 1 && __GNUC_PATCHLEVEL__ >= 2)))) + #define __pyx_atomic_incr_aligned(value) __sync_fetch_and_add(value, 1) + #define __pyx_atomic_decr_aligned(value) __sync_fetch_and_sub(value, 1) + #ifdef __PYX_DEBUG_ATOMICS + #warning "Using GNU atomics" + #endif +#elif CYTHON_ATOMICS && defined(_MSC_VER) + #include + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type long + #undef __pyx_nonatomic_int_type + #define __pyx_nonatomic_int_type long + #pragma intrinsic (_InterlockedExchangeAdd) + #define __pyx_atomic_incr_aligned(value) _InterlockedExchangeAdd(value, 1) + #define __pyx_atomic_decr_aligned(value) _InterlockedExchangeAdd(value, -1) + #ifdef __PYX_DEBUG_ATOMICS + #pragma message ("Using MSVC atomics") + #endif +#else + #undef CYTHON_ATOMICS + #define CYTHON_ATOMICS 0 + #ifdef __PYX_DEBUG_ATOMICS + #warning "Not using atomics" + #endif +#endif +#if CYTHON_ATOMICS + #define __pyx_add_acquisition_count(memview)\ + __pyx_atomic_incr_aligned(__pyx_get_slice_count_pointer(memview)) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_atomic_decr_aligned(__pyx_get_slice_count_pointer(memview)) +#else + #define __pyx_add_acquisition_count(memview)\ + __pyx_add_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_sub_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) +#endif + +/* MemviewSliceStruct.proto */ +struct __pyx_memoryview_obj; +typedef struct { + struct __pyx_memoryview_obj *memview; + char *data; + Py_ssize_t shape[8]; + Py_ssize_t strides[8]; + Py_ssize_t suboffsets[8]; +} __Pyx_memviewslice; +#define __Pyx_MemoryView_Len(m) (m.shape[0]) + +/* #### Code section: numeric_typedefs ### */ + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":730 + * # in Cython to enable them only on the right systems. + * + * ctypedef npy_int8 int8_t # <<<<<<<<<<<<<< + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + */ +typedef npy_int8 __pyx_t_5numpy_int8_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":731 + * + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t # <<<<<<<<<<<<<< + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t + */ +typedef npy_int16 __pyx_t_5numpy_int16_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":732 + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t # <<<<<<<<<<<<<< + * ctypedef npy_int64 int64_t + * #ctypedef npy_int96 int96_t + */ +typedef npy_int32 __pyx_t_5numpy_int32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":733 + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t # <<<<<<<<<<<<<< + * #ctypedef npy_int96 int96_t + * #ctypedef npy_int128 int128_t + */ +typedef npy_int64 __pyx_t_5numpy_int64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":737 + * #ctypedef npy_int128 int128_t + * + * ctypedef npy_uint8 uint8_t # <<<<<<<<<<<<<< + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + */ +typedef npy_uint8 __pyx_t_5numpy_uint8_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":738 + * + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t # <<<<<<<<<<<<<< + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t + */ +typedef npy_uint16 __pyx_t_5numpy_uint16_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":739 + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t # <<<<<<<<<<<<<< + * ctypedef npy_uint64 uint64_t + * #ctypedef npy_uint96 uint96_t + */ +typedef npy_uint32 __pyx_t_5numpy_uint32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":740 + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t # <<<<<<<<<<<<<< + * #ctypedef npy_uint96 uint96_t + * #ctypedef npy_uint128 uint128_t + */ +typedef npy_uint64 __pyx_t_5numpy_uint64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":744 + * #ctypedef npy_uint128 uint128_t + * + * ctypedef npy_float32 float32_t # <<<<<<<<<<<<<< + * ctypedef npy_float64 float64_t + * #ctypedef npy_float80 float80_t + */ +typedef npy_float32 __pyx_t_5numpy_float32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":745 + * + * ctypedef npy_float32 float32_t + * ctypedef npy_float64 float64_t # <<<<<<<<<<<<<< + * #ctypedef npy_float80 float80_t + * #ctypedef npy_float128 float128_t + */ +typedef npy_float64 __pyx_t_5numpy_float64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":754 + * # The int types are mapped a bit surprising -- + * # numpy.int corresponds to 'l' and numpy.long to 'q' + * ctypedef npy_long int_t # <<<<<<<<<<<<<< + * ctypedef npy_longlong longlong_t + * + */ +typedef npy_long __pyx_t_5numpy_int_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":755 + * # numpy.int corresponds to 'l' and numpy.long to 'q' + * ctypedef npy_long int_t + * ctypedef npy_longlong longlong_t # <<<<<<<<<<<<<< + * + * ctypedef npy_ulong uint_t + */ +typedef npy_longlong __pyx_t_5numpy_longlong_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":757 + * ctypedef npy_longlong longlong_t + * + * ctypedef npy_ulong uint_t # <<<<<<<<<<<<<< + * ctypedef npy_ulonglong ulonglong_t + * + */ +typedef npy_ulong __pyx_t_5numpy_uint_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":758 + * + * ctypedef npy_ulong uint_t + * ctypedef npy_ulonglong ulonglong_t # <<<<<<<<<<<<<< + * + * ctypedef npy_intp intp_t + */ +typedef npy_ulonglong __pyx_t_5numpy_ulonglong_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":760 + * ctypedef npy_ulonglong ulonglong_t + * + * ctypedef npy_intp intp_t # <<<<<<<<<<<<<< + * ctypedef npy_uintp uintp_t + * + */ +typedef npy_intp __pyx_t_5numpy_intp_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":761 + * + * ctypedef npy_intp intp_t + * ctypedef npy_uintp uintp_t # <<<<<<<<<<<<<< + * + * ctypedef npy_double float_t + */ +typedef npy_uintp __pyx_t_5numpy_uintp_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":763 + * ctypedef npy_uintp uintp_t + * + * ctypedef npy_double float_t # <<<<<<<<<<<<<< + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t + */ +typedef npy_double __pyx_t_5numpy_float_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":764 + * + * ctypedef npy_double float_t + * ctypedef npy_double double_t # <<<<<<<<<<<<<< + * ctypedef npy_longdouble longdouble_t + * + */ +typedef npy_double __pyx_t_5numpy_double_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":765 + * ctypedef npy_double float_t + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t # <<<<<<<<<<<<<< + * + * ctypedef npy_cfloat cfloat_t + */ +typedef npy_longdouble __pyx_t_5numpy_longdouble_t; + +/* "fairseq/data/data_utils_fast.pyx":15 + * from libcpp cimport bool as bool_t + * + * ctypedef int64_t DTYPE_t # <<<<<<<<<<<<<< + * + * @cython.cdivision(True) + */ +typedef int64_t __pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t; +/* #### Code section: complex_type_declarations ### */ +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< float > __pyx_t_float_complex; + #else + typedef float _Complex __pyx_t_float_complex; + #endif +#else + typedef struct { float real, imag; } __pyx_t_float_complex; +#endif +static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float, float); + +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< double > __pyx_t_double_complex; + #else + typedef double _Complex __pyx_t_double_complex; + #endif +#else + typedef struct { double real, imag; } __pyx_t_double_complex; +#endif +static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double, double); + +/* #### Code section: type_declarations ### */ + +/*--- Type declarations ---*/ +struct __pyx_array_obj; +struct __pyx_MemviewEnum_obj; +struct __pyx_memoryview_obj; +struct __pyx_memoryviewslice_obj; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":767 + * ctypedef npy_longdouble longdouble_t + * + * ctypedef npy_cfloat cfloat_t # <<<<<<<<<<<<<< + * ctypedef npy_cdouble cdouble_t + * ctypedef npy_clongdouble clongdouble_t + */ +typedef npy_cfloat __pyx_t_5numpy_cfloat_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":768 + * + * ctypedef npy_cfloat cfloat_t + * ctypedef npy_cdouble cdouble_t # <<<<<<<<<<<<<< + * ctypedef npy_clongdouble clongdouble_t + * + */ +typedef npy_cdouble __pyx_t_5numpy_cdouble_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":769 + * ctypedef npy_cfloat cfloat_t + * ctypedef npy_cdouble cdouble_t + * ctypedef npy_clongdouble clongdouble_t # <<<<<<<<<<<<<< + * + * ctypedef npy_cdouble complex_t + */ +typedef npy_clongdouble __pyx_t_5numpy_clongdouble_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":771 + * ctypedef npy_clongdouble clongdouble_t + * + * ctypedef npy_cdouble complex_t # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew1(a): + */ +typedef npy_cdouble __pyx_t_5numpy_complex_t; + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ +struct __pyx_array_obj { + PyObject_HEAD + struct __pyx_vtabstruct_array *__pyx_vtab; + char *data; + Py_ssize_t len; + char *format; + int ndim; + Py_ssize_t *_shape; + Py_ssize_t *_strides; + Py_ssize_t itemsize; + PyObject *mode; + PyObject *_format; + void (*callback_free_data)(void *); + int free_data; + int dtype_is_object; +}; + + +/* "View.MemoryView":302 + * + * @cname('__pyx_MemviewEnum') + * cdef class Enum(object): # <<<<<<<<<<<<<< + * cdef object name + * def __init__(self, name): + */ +struct __pyx_MemviewEnum_obj { + PyObject_HEAD + PyObject *name; +}; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ +struct __pyx_memoryview_obj { + PyObject_HEAD + struct __pyx_vtabstruct_memoryview *__pyx_vtab; + PyObject *obj; + PyObject *_size; + PyObject *_array_interface; + PyThread_type_lock lock; + __pyx_atomic_int_type acquisition_count; + Py_buffer view; + int flags; + int dtype_is_object; + __Pyx_TypeInfo *typeinfo; +}; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ +struct __pyx_memoryviewslice_obj { + struct __pyx_memoryview_obj __pyx_base; + __Pyx_memviewslice from_slice; + PyObject *from_object; + PyObject *(*to_object_func)(char *); + int (*to_dtype_func)(char *, PyObject *); +}; + + + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ + +struct __pyx_vtabstruct_array { + PyObject *(*get_memview)(struct __pyx_array_obj *); +}; +static struct __pyx_vtabstruct_array *__pyx_vtabptr_array; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ + +struct __pyx_vtabstruct_memoryview { + char *(*get_item_pointer)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*is_slice)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_slice_assignment)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*setitem_slice_assign_scalar)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_indexed)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*convert_item_to_object)(struct __pyx_memoryview_obj *, char *); + PyObject *(*assign_item_from_object)(struct __pyx_memoryview_obj *, char *, PyObject *); + PyObject *(*_get_base)(struct __pyx_memoryview_obj *); +}; +static struct __pyx_vtabstruct_memoryview *__pyx_vtabptr_memoryview; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ + +struct __pyx_vtabstruct__memoryviewslice { + struct __pyx_vtabstruct_memoryview __pyx_base; +}; +static struct __pyx_vtabstruct__memoryviewslice *__pyx_vtabptr__memoryviewslice; +/* #### Code section: utility_code_proto ### */ + +/* --- Runtime support code (head) --- */ +/* Refnanny.proto */ +#ifndef CYTHON_REFNANNY + #define CYTHON_REFNANNY 0 +#endif +#if CYTHON_REFNANNY + typedef struct { + void (*INCREF)(void*, PyObject*, Py_ssize_t); + void (*DECREF)(void*, PyObject*, Py_ssize_t); + void (*GOTREF)(void*, PyObject*, Py_ssize_t); + void (*GIVEREF)(void*, PyObject*, Py_ssize_t); + void* (*SetupContext)(const char*, Py_ssize_t, const char*); + void (*FinishContext)(void**); + } __Pyx_RefNannyAPIStruct; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname); + #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL; +#ifdef WITH_THREAD + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + if (acquire_gil) {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + PyGILState_Release(__pyx_gilstate_save);\ + } else {\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + } + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } +#else + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__)) + #define __Pyx_RefNannyFinishContextNogil() __Pyx_RefNannyFinishContext() +#endif + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } + #define __Pyx_RefNannyFinishContext()\ + __Pyx_RefNanny->FinishContext(&__pyx_refnanny) + #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_XINCREF(r) do { if((r) == NULL); else {__Pyx_INCREF(r); }} while(0) + #define __Pyx_XDECREF(r) do { if((r) == NULL); else {__Pyx_DECREF(r); }} while(0) + #define __Pyx_XGOTREF(r) do { if((r) == NULL); else {__Pyx_GOTREF(r); }} while(0) + #define __Pyx_XGIVEREF(r) do { if((r) == NULL); else {__Pyx_GIVEREF(r);}} while(0) +#else + #define __Pyx_RefNannyDeclarations + #define __Pyx_RefNannySetupContext(name, acquire_gil) + #define __Pyx_RefNannyFinishContextNogil() + #define __Pyx_RefNannyFinishContext() + #define __Pyx_INCREF(r) Py_INCREF(r) + #define __Pyx_DECREF(r) Py_DECREF(r) + #define __Pyx_GOTREF(r) + #define __Pyx_GIVEREF(r) + #define __Pyx_XINCREF(r) Py_XINCREF(r) + #define __Pyx_XDECREF(r) Py_XDECREF(r) + #define __Pyx_XGOTREF(r) + #define __Pyx_XGIVEREF(r) +#endif +#define __Pyx_Py_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; Py_XDECREF(tmp);\ + } while (0) +#define __Pyx_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_XDECREF(tmp);\ + } while (0) +#define __Pyx_DECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_DECREF(tmp);\ + } while (0) +#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0) +#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0) + +/* PyErrExceptionMatches.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err) +static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err); +#else +#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err) +#endif + +/* PyThreadStateGet.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate; +#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current; +#if PY_VERSION_HEX >= 0x030C00A6 +#define __Pyx_PyErr_Occurred() (__pyx_tstate->current_exception != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->current_exception ? (PyObject*) Py_TYPE(__pyx_tstate->current_exception) : (PyObject*) NULL) +#else +#define __Pyx_PyErr_Occurred() (__pyx_tstate->curexc_type != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->curexc_type) +#endif +#else +#define __Pyx_PyThreadState_declare +#define __Pyx_PyThreadState_assign +#define __Pyx_PyErr_Occurred() (PyErr_Occurred() != NULL) +#define __Pyx_PyErr_CurrentExceptionType() PyErr_Occurred() +#endif + +/* PyErrFetchRestore.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL) +#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A6 +#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL)) +#else +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#endif +#else +#define __Pyx_PyErr_Clear() PyErr_Clear() +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb) +#endif + +/* PyObjectGetAttrStr.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n) +#endif + +/* PyObjectGetAttrStrNoError.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name); + +/* GetBuiltinName.proto */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name); + +/* TupleAndListFromArray.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n); +static CYTHON_INLINE PyObject* __Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n); +#endif + +/* IncludeStringH.proto */ +#include + +/* BytesEquals.proto */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals); + +/* UnicodeEquals.proto */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals); + +/* fastcall.proto */ +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_VARARGS(args, i) PySequence_GetItem(args, i) +#elif CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GET_ITEM(args, i) +#else + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GetItem(args, i) +#endif +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_NewRef_VARARGS(arg) __Pyx_NewRef(arg) + #define __Pyx_Arg_XDECREF_VARARGS(arg) Py_XDECREF(arg) +#else + #define __Pyx_Arg_NewRef_VARARGS(arg) arg + #define __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#define __Pyx_NumKwargs_VARARGS(kwds) PyDict_Size(kwds) +#define __Pyx_KwValues_VARARGS(args, nargs) NULL +#define __Pyx_GetKwValue_VARARGS(kw, kwvalues, s) __Pyx_PyDict_GetItemStrWithError(kw, s) +#define __Pyx_KwargsAsDict_VARARGS(kw, kwvalues) PyDict_Copy(kw) +#if CYTHON_METH_FASTCALL + #define __Pyx_Arg_FASTCALL(args, i) args[i] + #define __Pyx_NumKwargs_FASTCALL(kwds) PyTuple_GET_SIZE(kwds) + #define __Pyx_KwValues_FASTCALL(args, nargs) ((args) + (nargs)) + static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues); + #else + #define __Pyx_KwargsAsDict_FASTCALL(kw, kwvalues) _PyStack_AsDict(kwvalues, kw) + #endif + #define __Pyx_Arg_NewRef_FASTCALL(arg) arg /* no-op, __Pyx_Arg_FASTCALL is direct and this needs + to have the same reference counting */ + #define __Pyx_Arg_XDECREF_FASTCALL(arg) +#else + #define __Pyx_Arg_FASTCALL __Pyx_Arg_VARARGS + #define __Pyx_NumKwargs_FASTCALL __Pyx_NumKwargs_VARARGS + #define __Pyx_KwValues_FASTCALL __Pyx_KwValues_VARARGS + #define __Pyx_GetKwValue_FASTCALL __Pyx_GetKwValue_VARARGS + #define __Pyx_KwargsAsDict_FASTCALL __Pyx_KwargsAsDict_VARARGS + #define __Pyx_Arg_NewRef_FASTCALL(arg) __Pyx_Arg_NewRef_VARARGS(arg) + #define __Pyx_Arg_XDECREF_FASTCALL(arg) __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_VARARGS(args, start), stop - start) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_FASTCALL(args, start), stop - start) +#else +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) PyTuple_GetSlice(args, start, stop) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) PyTuple_GetSlice(args, start, stop) +#endif + +/* RaiseArgTupleInvalid.proto */ +static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact, + Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found); + +/* RaiseDoubleKeywords.proto */ +static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name); + +/* ParseKeywords.proto */ +static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args, + const char* function_name); + +/* ArgTypeTest.proto */ +#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\ + ((likely(__Pyx_IS_TYPE(obj, type) | (none_allowed && (obj == Py_None)))) ? 1 :\ + __Pyx__ArgTypeTest(obj, type, name, exact)) +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact); + +/* RaiseException.proto */ +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); + +/* PyFunctionFastCall.proto */ +#if CYTHON_FAST_PYCALL +#if !CYTHON_VECTORCALL +#define __Pyx_PyFunction_FastCall(func, args, nargs)\ + __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL) +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs); +#endif +#define __Pyx_BUILD_ASSERT_EXPR(cond)\ + (sizeof(char [1 - 2*!(cond)]) - 1) +#ifndef Py_MEMBER_SIZE +#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member) +#endif +#if !CYTHON_VECTORCALL +#if PY_VERSION_HEX >= 0x03080000 + #include "frameobject.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif + #define __Pxy_PyFrame_Initialize_Offsets() + #define __Pyx_PyFrame_GetLocalsplus(frame) ((frame)->f_localsplus) +#else + static size_t __pyx_pyframe_localsplus_offset = 0; + #include "frameobject.h" + #define __Pxy_PyFrame_Initialize_Offsets()\ + ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\ + (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus))) + #define __Pyx_PyFrame_GetLocalsplus(frame)\ + (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset)) +#endif +#endif +#endif + +/* PyObjectCall.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw); +#else +#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw) +#endif + +/* PyObjectCallMethO.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg); +#endif + +/* PyObjectFastCall.proto */ +#define __Pyx_PyObject_FastCall(func, args, nargs) __Pyx_PyObject_FastCallDict(func, args, (size_t)(nargs), NULL) +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs); + +/* RaiseUnexpectedTypeError.proto */ +static int __Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj); + +/* GCCDiagnostics.proto */ +#if !defined(__INTEL_COMPILER) && defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) +#define __Pyx_HAS_GCC_DIAGNOSTIC +#endif + +/* BuildPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char); + +/* JoinPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char); + +/* StrEquals.proto */ +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyString_Equals __Pyx_PyUnicode_Equals +#else +#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals +#endif + +/* PyObjectFormatSimple.proto */ +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#elif PY_MAJOR_VERSION < 3 + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyString_CheckExact(s)) ? PyUnicode_FromEncodedObject(s, NULL, "strict") :\ + PyObject_Format(s, f)) +#elif CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyLong_CheckExact(s)) ? PyLong_Type.tp_repr(s) :\ + likely(PyFloat_CheckExact(s)) ? PyFloat_Type.tp_repr(s) :\ + PyObject_Format(s, f)) +#else + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#endif + +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *); /*proto*/ +/* GetAttr.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *); + +/* GetItemInt.proto */ +#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\ + __Pyx_GetItemInt_Generic(o, to_py_func(i)))) +#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j); +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, + int is_list, int wraparound, int boundscheck); + +/* PyObjectCallOneArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg); + +/* ObjectGetItem.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key); +#else +#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key) +#endif + +/* KeywordStringCheck.proto */ +static int __Pyx_CheckKeywordStrings(PyObject *kw, const char* function_name, int kw_allowed); + +/* DivInt[Py_ssize_t].proto */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t, Py_ssize_t); + +/* UnaryNegOverflows.proto */ +#define __Pyx_UNARY_NEG_WOULD_OVERFLOW(x)\ + (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x))) + +/* GetAttr3.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); + +/* PyDictVersioning.proto */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1) +#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\ + (version_var) = __PYX_GET_DICT_VERSION(dict);\ + (cache_var) = (value); +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\ + (VAR) = __pyx_dict_cached_value;\ + } else {\ + (VAR) = __pyx_dict_cached_value = (LOOKUP);\ + __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\ + }\ +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj); +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj); +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version); +#else +#define __PYX_GET_DICT_VERSION(dict) (0) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var) +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP); +#endif + +/* GetModuleGlobalName.proto */ +#if CYTHON_USE_DICT_VERSIONS +#define __Pyx_GetModuleGlobalName(var, name) do {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\ + (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\ + __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +#define __Pyx_GetModuleGlobalNameUncached(var, name) do {\ + PY_UINT64_T __pyx_dict_version;\ + PyObject *__pyx_dict_cached_value;\ + (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value); +#else +#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name) +#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name) +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name); +#endif + +/* AssertionsEnabled.proto */ +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag) + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (1) +#elif CYTHON_COMPILING_IN_LIMITED_API || (CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030C0000) + static int __pyx_assertions_enabled_flag; + #define __pyx_assertions_enabled() (__pyx_assertions_enabled_flag) + static int __Pyx_init_assertions_enabled(void) { + PyObject *builtins, *debug, *debug_str; + int flag; + builtins = PyEval_GetBuiltins(); + if (!builtins) goto bad; + debug_str = PyUnicode_FromStringAndSize("__debug__", 9); + if (!debug_str) goto bad; + debug = PyObject_GetItem(builtins, debug_str); + Py_DECREF(debug_str); + if (!debug) goto bad; + flag = PyObject_IsTrue(debug); + Py_DECREF(debug); + if (flag == -1) goto bad; + __pyx_assertions_enabled_flag = flag; + return 0; + bad: + __pyx_assertions_enabled_flag = 1; + return -1; + } +#else + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (!Py_OptimizeFlag) +#endif + +/* RaiseTooManyValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected); + +/* RaiseNeedMoreValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index); + +/* RaiseNoneIterError.proto */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void); + +/* ExtTypeTest.proto */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); + +/* GetTopmostException.proto */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate); +#endif + +/* SaveResetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +#else +#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb) +#define __Pyx_ExceptionReset(type, value, tb) PyErr_SetExcInfo(type, value, tb) +#endif + +/* GetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb) +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* SwapException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSwap(type, value, tb) __Pyx__ExceptionSwap(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* Import.proto */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level); + +/* ImportDottedModule.proto */ +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple); +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple); +#endif + +/* FastTypeChecks.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) __Pyx_IsAnySubtype2(Py_TYPE(obj), (PyTypeObject *)type1, (PyTypeObject *)type2) +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2); +#else +#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) (PyObject_TypeCheck(obj, (PyTypeObject *)type1) || PyObject_TypeCheck(obj, (PyTypeObject *)type2)) +#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type) +#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2)) +#endif +#define __Pyx_PyErr_ExceptionMatches2(err1, err2) __Pyx_PyErr_GivenExceptionMatches2(__Pyx_PyErr_CurrentExceptionType(), err1, err2) +#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception) + +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +/* ListCompAppend.proto */ +#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS +static CYTHON_INLINE int __Pyx_ListComp_Append(PyObject* list, PyObject* x) { + PyListObject* L = (PyListObject*) list; + Py_ssize_t len = Py_SIZE(list); + if (likely(L->allocated > len)) { + Py_INCREF(x); + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + L->ob_item[len] = x; + #else + PyList_SET_ITEM(list, len, x); + #endif + __Pyx_SET_SIZE(list, len + 1); + return 0; + } + return PyList_Append(list, x); +} +#else +#define __Pyx_ListComp_Append(L,x) PyList_Append(L,x) +#endif + +/* PySequenceMultiply.proto */ +#define __Pyx_PySequence_Multiply_Left(mul, seq) __Pyx_PySequence_Multiply(seq, mul) +static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul); + +/* SetItemInt.proto */ +#define __Pyx_SetItemInt(o, i, v, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_SetItemInt_Fast(o, (Py_ssize_t)i, v, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list assignment index out of range"), -1) :\ + __Pyx_SetItemInt_Generic(o, to_py_func(i), v))) +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v); +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, + int is_list, int wraparound, int boundscheck); + +/* RaiseUnboundLocalError.proto */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname); + +/* DivInt[long].proto */ +static CYTHON_INLINE long __Pyx_div_long(long, long); + +/* PySequenceContains.proto */ +static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) { + int result = PySequence_Contains(seq, item); + return unlikely(result < 0) ? result : (result == (eq == Py_EQ)); +} + +/* ImportFrom.proto */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name); + +/* HasAttr.proto */ +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 +#define __Pyx_HasAttr(o, n) PyObject_HasAttrWithError(o, n) +#else +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *); +#endif + +/* IsLittleEndian.proto */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void); + +/* BufferFormatCheck.proto */ +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts); +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type); + +/* BufferGetAndValidate.proto */ +#define __Pyx_GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack)\ + ((obj == Py_None || obj == NULL) ?\ + (__Pyx_ZeroBuffer(buf), 0) :\ + __Pyx__GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack)) +static int __Pyx__GetBufferAndValidate(Py_buffer* buf, PyObject* obj, + __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack); +static void __Pyx_ZeroBuffer(Py_buffer* buf); +static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info); +static Py_ssize_t __Pyx_minusones[] = { -1, -1, -1, -1, -1, -1, -1, -1 }; +static Py_ssize_t __Pyx_zeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + +#define __Pyx_BufPtrStrided1d(type, buf, i0, s0) (type)((char*)buf + i0 * s0) +/* BufferIndexError.proto */ +static void __Pyx_RaiseBufferIndexError(int axis); + +/* ListAppend.proto */ +#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS +static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) { + PyListObject* L = (PyListObject*) list; + Py_ssize_t len = Py_SIZE(list); + if (likely(L->allocated > len) & likely(len > (L->allocated >> 1))) { + Py_INCREF(x); + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + L->ob_item[len] = x; + #else + PyList_SET_ITEM(list, len, x); + #endif + __Pyx_SET_SIZE(list, len + 1); + return 0; + } + return PyList_Append(list, x); +} +#else +#define __Pyx_PyList_Append(L,x) PyList_Append(L,x) +#endif + +/* PyIntCompare.proto */ +static CYTHON_INLINE int __Pyx_PyInt_BoolEqObjC(PyObject *op1, PyObject *op2, long intval, long inplace); + +/* PyObject_GenericGetAttrNoDict.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr +#endif + +/* PyObject_GenericGetAttr.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr +#endif + +/* IncludeStructmemberH.proto */ +#include + +/* FixUpExtensionType.proto */ +#if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type); +#endif + +/* PyObjectCallNoArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func); + +/* PyObjectGetMethod.proto */ +static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method); + +/* PyObjectCallMethod0.proto */ +static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name); + +/* ValidateBasesTuple.proto */ +#if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases); +#endif + +/* PyType_Ready.proto */ +CYTHON_UNUSED static int __Pyx_PyType_Ready(PyTypeObject *t); + +/* SetVTable.proto */ +static int __Pyx_SetVtable(PyTypeObject* typeptr , void* vtable); + +/* GetVTable.proto */ +static void* __Pyx_GetVtable(PyTypeObject *type); + +/* MergeVTables.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type); +#endif + +/* SetupReduce.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce(PyObject* type_obj); +#endif + +/* TypeImport.proto */ +#ifndef __PYX_HAVE_RT_ImportType_proto_3_0_8 +#define __PYX_HAVE_RT_ImportType_proto_3_0_8 +#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L +#include +#endif +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || __cplusplus >= 201103L +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_8(s) alignof(s) +#else +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_8(s) sizeof(void*) +#endif +enum __Pyx_ImportType_CheckSize_3_0_8 { + __Pyx_ImportType_CheckSize_Error_3_0_8 = 0, + __Pyx_ImportType_CheckSize_Warn_3_0_8 = 1, + __Pyx_ImportType_CheckSize_Ignore_3_0_8 = 2 +}; +static PyTypeObject *__Pyx_ImportType_3_0_8(PyObject* module, const char *module_name, const char *class_name, size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_8 check_size); +#endif + +/* FetchSharedCythonModule.proto */ +static PyObject *__Pyx_FetchSharedCythonABIModule(void); + +/* FetchCommonType.proto */ +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type); +#else +static PyTypeObject* __Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases); +#endif + +/* PyMethodNew.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + PyObject *typesModule=NULL, *methodType=NULL, *result=NULL; + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + typesModule = PyImport_ImportModule("types"); + if (!typesModule) return NULL; + methodType = PyObject_GetAttrString(typesModule, "MethodType"); + Py_DECREF(typesModule); + if (!methodType) return NULL; + result = PyObject_CallFunctionObjArgs(methodType, func, self, NULL); + Py_DECREF(methodType); + return result; +} +#elif PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + return PyMethod_New(func, self); +} +#else + #define __Pyx_PyMethod_New PyMethod_New +#endif + +/* PyVectorcallFastCallDict.proto */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw); +#endif + +/* CythonFunctionShared.proto */ +#define __Pyx_CyFunction_USED +#define __Pyx_CYFUNCTION_STATICMETHOD 0x01 +#define __Pyx_CYFUNCTION_CLASSMETHOD 0x02 +#define __Pyx_CYFUNCTION_CCLASS 0x04 +#define __Pyx_CYFUNCTION_COROUTINE 0x08 +#define __Pyx_CyFunction_GetClosure(f)\ + (((__pyx_CyFunctionObject *) (f))->func_closure) +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_CyFunction_GetClassObj(f)\ + (((__pyx_CyFunctionObject *) (f))->func_classobj) +#else + #define __Pyx_CyFunction_GetClassObj(f)\ + ((PyObject*) ((PyCMethodObject *) (f))->mm_class) +#endif +#define __Pyx_CyFunction_SetClassObj(f, classobj)\ + __Pyx__CyFunction_SetClassObj((__pyx_CyFunctionObject *) (f), (classobj)) +#define __Pyx_CyFunction_Defaults(type, f)\ + ((type *)(((__pyx_CyFunctionObject *) (f))->defaults)) +#define __Pyx_CyFunction_SetDefaultsGetter(f, g)\ + ((__pyx_CyFunctionObject *) (f))->defaults_getter = (g) +typedef struct { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject_HEAD + PyObject *func; +#elif PY_VERSION_HEX < 0x030900B1 + PyCFunctionObject func; +#else + PyCMethodObject func; +#endif +#if CYTHON_BACKPORT_VECTORCALL + __pyx_vectorcallfunc func_vectorcall; +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_weakreflist; +#endif + PyObject *func_dict; + PyObject *func_name; + PyObject *func_qualname; + PyObject *func_doc; + PyObject *func_globals; + PyObject *func_code; + PyObject *func_closure; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_classobj; +#endif + void *defaults; + int defaults_pyobjects; + size_t defaults_size; + int flags; + PyObject *defaults_tuple; + PyObject *defaults_kwdict; + PyObject *(*defaults_getter)(PyObject *); + PyObject *func_annotations; + PyObject *func_is_coroutine; +} __pyx_CyFunctionObject; +#undef __Pyx_CyOrPyCFunction_Check +#define __Pyx_CyFunction_Check(obj) __Pyx_TypeCheck(obj, __pyx_CyFunctionType) +#define __Pyx_CyOrPyCFunction_Check(obj) __Pyx_TypeCheck2(obj, __pyx_CyFunctionType, &PyCFunction_Type) +#define __Pyx_CyFunction_CheckExact(obj) __Pyx_IS_TYPE(obj, __pyx_CyFunctionType) +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc); +#undef __Pyx_IsSameCFunction +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCyOrCFunction(func, cfunc) +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject* op, PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj); +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m, + size_t size, + int pyobjects); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m, + PyObject *tuple); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *m, + PyObject *dict); +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m, + PyObject *dict); +static int __pyx_CyFunction_init(PyObject *module); +#if CYTHON_METH_FASTCALL +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +#if CYTHON_BACKPORT_VECTORCALL +#define __Pyx_CyFunction_func_vectorcall(f) (((__pyx_CyFunctionObject*)f)->func_vectorcall) +#else +#define __Pyx_CyFunction_func_vectorcall(f) (((PyCFunctionObject*)f)->vectorcall) +#endif +#endif + +/* CythonFunction.proto */ +static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); + +/* CLineInTraceback.proto */ +#ifdef CYTHON_CLINE_IN_TRACEBACK +#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0) +#else +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line); +#endif + +/* CodeObjectCache.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +typedef struct { + PyCodeObject* code_object; + int code_line; +} __Pyx_CodeObjectCacheEntry; +struct __Pyx_CodeObjectCache { + int count; + int max_count; + __Pyx_CodeObjectCacheEntry* entries; +}; +static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line); +static PyCodeObject *__pyx_find_code_object(int code_line); +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object); +#endif + +/* AddTraceback.proto */ +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename); + +#if PY_MAJOR_VERSION < 3 + static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags); + static void __Pyx_ReleaseBuffer(Py_buffer *view); +#else + #define __Pyx_GetBuffer PyObject_GetBuffer + #define __Pyx_ReleaseBuffer PyBuffer_Release +#endif + + +/* BufferStructDeclare.proto */ +typedef struct { + Py_ssize_t shape, strides, suboffsets; +} __Pyx_Buf_DimInfo; +typedef struct { + size_t refcount; + Py_buffer pybuffer; +} __Pyx_Buffer; +typedef struct { + __Pyx_Buffer *rcbuffer; + char *data; + __Pyx_Buf_DimInfo diminfo[8]; +} __Pyx_LocalBuf_ND; + +/* MemviewSliceIsContig.proto */ +static int __pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim); + +/* OverlappingSlices.proto */ +static int __pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize); + +/* TypeInfoCompare.proto */ +static int __pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b); + +/* MemviewSliceValidateAndInit.proto */ +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int32_t(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int64_t(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(PyObject *, int writable_flag); + +/* MemviewDtypeToObject.proto */ +static CYTHON_INLINE PyObject *__pyx_memview_get_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(const char *itemp); +static CYTHON_INLINE int __pyx_memview_set_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(const char *itemp, PyObject *obj); + +/* RealImag.proto */ +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #define __Pyx_CREAL(z) ((z).real()) + #define __Pyx_CIMAG(z) ((z).imag()) + #else + #define __Pyx_CREAL(z) (__real__(z)) + #define __Pyx_CIMAG(z) (__imag__(z)) + #endif +#else + #define __Pyx_CREAL(z) ((z).real) + #define __Pyx_CIMAG(z) ((z).imag) +#endif +#if defined(__cplusplus) && CYTHON_CCOMPLEX\ + && (defined(_WIN32) || defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5 || __GNUC__ == 4 && __GNUC_MINOR__ >= 4 )) || __cplusplus >= 201103) + #define __Pyx_SET_CREAL(z,x) ((z).real(x)) + #define __Pyx_SET_CIMAG(z,y) ((z).imag(y)) +#else + #define __Pyx_SET_CREAL(z,x) __Pyx_CREAL(z) = (x) + #define __Pyx_SET_CIMAG(z,y) __Pyx_CIMAG(z) = (y) +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_float(a, b) ((a)==(b)) + #define __Pyx_c_sum_float(a, b) ((a)+(b)) + #define __Pyx_c_diff_float(a, b) ((a)-(b)) + #define __Pyx_c_prod_float(a, b) ((a)*(b)) + #define __Pyx_c_quot_float(a, b) ((a)/(b)) + #define __Pyx_c_neg_float(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_float(z) ((z)==(float)0) + #define __Pyx_c_conj_float(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_float(z) (::std::abs(z)) + #define __Pyx_c_pow_float(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_float(z) ((z)==0) + #define __Pyx_c_conj_float(z) (conjf(z)) + #if 1 + #define __Pyx_c_abs_float(z) (cabsf(z)) + #define __Pyx_c_pow_float(a, b) (cpowf(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex); + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex, __pyx_t_float_complex); + #endif +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_double(a, b) ((a)==(b)) + #define __Pyx_c_sum_double(a, b) ((a)+(b)) + #define __Pyx_c_diff_double(a, b) ((a)-(b)) + #define __Pyx_c_prod_double(a, b) ((a)*(b)) + #define __Pyx_c_quot_double(a, b) ((a)/(b)) + #define __Pyx_c_neg_double(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_double(z) ((z)==(double)0) + #define __Pyx_c_conj_double(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (::std::abs(z)) + #define __Pyx_c_pow_double(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_double(z) ((z)==0) + #define __Pyx_c_conj_double(z) (conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (cabs(z)) + #define __Pyx_c_pow_double(a, b) (cpow(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex); + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex, __pyx_t_double_complex); + #endif +#endif + +/* MemviewSliceCopyTemplate.proto */ +static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object); + +/* MemviewSliceInit.proto */ +#define __Pyx_BUF_MAX_NDIMS %(BUF_MAX_NDIMS)d +#define __Pyx_MEMVIEW_DIRECT 1 +#define __Pyx_MEMVIEW_PTR 2 +#define __Pyx_MEMVIEW_FULL 4 +#define __Pyx_MEMVIEW_CONTIG 8 +#define __Pyx_MEMVIEW_STRIDED 16 +#define __Pyx_MEMVIEW_FOLLOW 32 +#define __Pyx_IS_C_CONTIG 1 +#define __Pyx_IS_F_CONTIG 2 +static int __Pyx_init_memviewslice( + struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference); +static CYTHON_INLINE int __pyx_add_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +static CYTHON_INLINE int __pyx_sub_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +#define __pyx_get_slice_count_pointer(memview) (&memview->acquisition_count) +#define __PYX_INC_MEMVIEW(slice, have_gil) __Pyx_INC_MEMVIEW(slice, have_gil, __LINE__) +#define __PYX_XCLEAR_MEMVIEW(slice, have_gil) __Pyx_XCLEAR_MEMVIEW(slice, have_gil, __LINE__) +static CYTHON_INLINE void __Pyx_INC_MEMVIEW(__Pyx_memviewslice *, int, int); +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *, int, int); + +/* CIntFromPy.proto */ +static CYTHON_INLINE int64_t __Pyx_PyInt_As_int64_t(PyObject *); + +/* CIntFromPy.proto */ +static CYTHON_INLINE int32_t __Pyx_PyInt_As_int32_t(PyObject *); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int64_t(int64_t value); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int32_t(int32_t value); + +/* None.proto */ +#include + +/* CIntFromPy.proto */ +static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *); + +/* CIntFromPy.proto */ +static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value); + +/* CIntFromPy.proto */ +static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *); + +/* FormatTypeName.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +typedef PyObject *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%U" +static __Pyx_TypeName __Pyx_PyType_GetName(PyTypeObject* tp); +#define __Pyx_DECREF_TypeName(obj) Py_XDECREF(obj) +#else +typedef const char *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%.200s" +#define __Pyx_PyType_GetName(tp) ((tp)->tp_name) +#define __Pyx_DECREF_TypeName(obj) +#endif + +/* CheckBinaryVersion.proto */ +static unsigned long __Pyx_get_runtime_version(void); +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer); + +/* InitStrings.proto */ +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); + +/* #### Code section: module_declarations ### */ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self); /* proto*/ +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto*/ +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self); /* proto*/ +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self); /* proto*/ + +/* Module declarations from "cython.view" */ + +/* Module declarations from "cython.dataclasses" */ + +/* Module declarations from "cython" */ + +/* Module declarations from "libc.string" */ + +/* Module declarations from "libc.stdio" */ + +/* Module declarations from "__builtin__" */ + +/* Module declarations from "cpython.type" */ + +/* Module declarations from "cpython" */ + +/* Module declarations from "cpython.object" */ + +/* Module declarations from "cpython.ref" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "libc.stdint" */ + +/* Module declarations from "libcpp" */ + +/* Module declarations from "fairseq.data.data_utils_fast" */ +static PyObject *__pyx_collections_abc_Sequence = 0; +static PyObject *generic = 0; +static PyObject *strided = 0; +static PyObject *indirect = 0; +static PyObject *contiguous = 0; +static PyObject *indirect_contiguous = 0; +static int __pyx_memoryview_thread_locks_used; +static PyThread_type_lock __pyx_memoryview_thread_locks[8]; +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_vec(PyArrayObject *, PyArrayObject *, int64_t, int64_t, int32_t, int __pyx_skip_dispatch); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_fn(PyArrayObject *, PyObject *, int64_t, int64_t, int32_t, int __pyx_skip_dispatch); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast__find_valid_shape(__Pyx_memviewslice, int64_t, int64_t); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_fixed_shapes_fast(PyArrayObject *, PyObject *, PyArrayObject *, int __pyx_skip_dispatch); /*proto*/ +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *); /*proto*/ +static struct __pyx_array_obj *__pyx_array_new(PyObject *, Py_ssize_t, char *, char *, char *); /*proto*/ +static PyObject *__pyx_memoryview_new(PyObject *, int, int, __Pyx_TypeInfo *); /*proto*/ +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *); /*proto*/ +static PyObject *_unellipsify(PyObject *, int); /*proto*/ +static int assert_direct_dimensions(Py_ssize_t *, int); /*proto*/ +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *, PyObject *); /*proto*/ +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int, int); /*proto*/ +static char *__pyx_pybuffer_index(Py_buffer *, char *, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memslice_transpose(__Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice, int, PyObject *(*)(char *), int (*)(char *, PyObject *), int); /*proto*/ +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static Py_ssize_t abs_py_ssize_t(Py_ssize_t); /*proto*/ +static char __pyx_get_best_slice_order(__Pyx_memviewslice *, int); /*proto*/ +static void _copy_strided_to_strided(char *, Py_ssize_t *, char *, Py_ssize_t *, Py_ssize_t *, Py_ssize_t *, int, size_t); /*proto*/ +static void copy_strided_to_strided(__Pyx_memviewslice *, __Pyx_memviewslice *, int, size_t); /*proto*/ +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *, int); /*proto*/ +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *, Py_ssize_t *, Py_ssize_t, int, char); /*proto*/ +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *, __Pyx_memviewslice *, char, int); /*proto*/ +static int __pyx_memoryview_err_extents(int, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memoryview_err_dim(PyObject *, PyObject *, int); /*proto*/ +static int __pyx_memoryview_err(PyObject *, PyObject *); /*proto*/ +static int __pyx_memoryview_err_no_memory(void); /*proto*/ +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice, __Pyx_memviewslice, int, int, int); /*proto*/ +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *, int, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *, int, size_t, void *, int); /*proto*/ +static void __pyx_memoryview__slice_assign_scalar(char *, Py_ssize_t *, Py_ssize_t *, int, size_t, void *); /*proto*/ +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *, PyObject *); /*proto*/ +/* #### Code section: typeinfo ### */ +static __Pyx_TypeInfo __Pyx_TypeInfo_nn_int64_t = { "int64_t", NULL, sizeof(int64_t), { 0 }, 0, __PYX_IS_UNSIGNED(int64_t) ? 'U' : 'I', __PYX_IS_UNSIGNED(int64_t), 0 }; +static __Pyx_TypeInfo __Pyx_TypeInfo_nn_int32_t = { "int32_t", NULL, sizeof(int32_t), { 0 }, 0, __PYX_IS_UNSIGNED(int32_t) ? 'U' : 'I', __PYX_IS_UNSIGNED(int32_t), 0 }; +static __Pyx_TypeInfo __Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t = { "DTYPE_t", NULL, sizeof(__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t), { 0 }, 0, __PYX_IS_UNSIGNED(__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t) ? 'U' : 'I', __PYX_IS_UNSIGNED(__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t), 0 }; +/* #### Code section: before_global_var ### */ +#define __Pyx_MODULE_NAME "fairseq.data.data_utils_fast" +extern int __pyx_module_is_main_fairseq__data__data_utils_fast; +int __pyx_module_is_main_fairseq__data__data_utils_fast = 0; + +/* Implementation of "fairseq.data.data_utils_fast" */ +/* #### Code section: global_var ### */ +static PyObject *__pyx_builtin_AssertionError; +static PyObject *__pyx_builtin_range; +static PyObject *__pyx_builtin___import__; +static PyObject *__pyx_builtin_ValueError; +static PyObject *__pyx_builtin_MemoryError; +static PyObject *__pyx_builtin_enumerate; +static PyObject *__pyx_builtin_TypeError; +static PyObject *__pyx_builtin_Ellipsis; +static PyObject *__pyx_builtin_id; +static PyObject *__pyx_builtin_IndexError; +static PyObject *__pyx_builtin_ImportError; +/* #### Code section: string_decls ### */ +static const char __pyx_k_[] = ": "; +static const char __pyx_k_O[] = "O"; +static const char __pyx_k_c[] = "c"; +static const char __pyx_k__2[] = "."; +static const char __pyx_k__3[] = "*"; +static const char __pyx_k__6[] = "'"; +static const char __pyx_k__7[] = ")"; +static const char __pyx_k_gc[] = "gc"; +static const char __pyx_k_id[] = "id"; +static const char __pyx_k_np[] = "np"; +static const char __pyx_k__28[] = "?"; +static const char __pyx_k_abc[] = "abc"; +static const char __pyx_k_and[] = " and "; +static const char __pyx_k_got[] = " (got "; +static const char __pyx_k_max[] = "max"; +static const char __pyx_k_new[] = "__new__"; +static const char __pyx_k_obj[] = "obj"; +static const char __pyx_k_sys[] = "sys"; +static const char __pyx_k_base[] = "base"; +static const char __pyx_k_dict[] = "__dict__"; +static const char __pyx_k_main[] = "__main__"; +static const char __pyx_k_mode[] = "mode"; +static const char __pyx_k_name[] = "name"; +static const char __pyx_k_ndim[] = "ndim"; +static const char __pyx_k_pack[] = "pack"; +static const char __pyx_k_size[] = "size"; +static const char __pyx_k_spec[] = "__spec__"; +static const char __pyx_k_step[] = "step"; +static const char __pyx_k_stop[] = "stop"; +static const char __pyx_k_test[] = "__test__"; +static const char __pyx_k_ASCII[] = "ASCII"; +static const char __pyx_k_class[] = "__class__"; +static const char __pyx_k_count[] = "count"; +static const char __pyx_k_dtype[] = "dtype"; +static const char __pyx_k_error[] = "error"; +static const char __pyx_k_flags[] = "flags"; +static const char __pyx_k_index[] = "index"; +static const char __pyx_k_int32[] = "int32"; +static const char __pyx_k_int64[] = "int64"; +static const char __pyx_k_numpy[] = "numpy"; +static const char __pyx_k_range[] = "range"; +static const char __pyx_k_shape[] = "shape"; +static const char __pyx_k_split[] = "split"; +static const char __pyx_k_start[] = "start"; +static const char __pyx_k_zeros[] = "zeros"; +static const char __pyx_k_enable[] = "enable"; +static const char __pyx_k_encode[] = "encode"; +static const char __pyx_k_format[] = "format"; +static const char __pyx_k_import[] = "__import__"; +static const char __pyx_k_name_2[] = "__name__"; +static const char __pyx_k_pickle[] = "pickle"; +static const char __pyx_k_reduce[] = "__reduce__"; +static const char __pyx_k_struct[] = "struct"; +static const char __pyx_k_unpack[] = "unpack"; +static const char __pyx_k_update[] = "update"; +static const char __pyx_k_disable[] = "disable"; +static const char __pyx_k_fortran[] = "fortran"; +static const char __pyx_k_indices[] = "indices"; +static const char __pyx_k_memview[] = "memview"; +static const char __pyx_k_Ellipsis[] = "Ellipsis"; +static const char __pyx_k_Sequence[] = "Sequence"; +static const char __pyx_k_bsz_mult[] = "bsz_mult"; +static const char __pyx_k_getstate[] = "__getstate__"; +static const char __pyx_k_itemsize[] = "itemsize"; +static const char __pyx_k_pyx_type[] = "__pyx_type"; +static const char __pyx_k_register[] = "register"; +static const char __pyx_k_setstate[] = "__setstate__"; +static const char __pyx_k_TypeError[] = "TypeError"; +static const char __pyx_k_enumerate[] = "enumerate"; +static const char __pyx_k_isenabled[] = "isenabled"; +static const char __pyx_k_pyx_state[] = "__pyx_state"; +static const char __pyx_k_reduce_ex[] = "__reduce_ex__"; +static const char __pyx_k_IndexError[] = "IndexError"; +static const char __pyx_k_ValueError[] = "ValueError"; +static const char __pyx_k_max_tokens[] = "max_tokens"; +static const char __pyx_k_pyx_result[] = "__pyx_result"; +static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__"; +static const char __pyx_k_ImportError[] = "ImportError"; +static const char __pyx_k_MemoryError[] = "MemoryError"; +static const char __pyx_k_PickleError[] = "PickleError"; +static const char __pyx_k_collections[] = "collections"; +static const char __pyx_k_initializing[] = "_initializing"; +static const char __pyx_k_is_coroutine[] = "_is_coroutine"; +static const char __pyx_k_pyx_checksum[] = "__pyx_checksum"; +static const char __pyx_k_stringsource[] = ""; +static const char __pyx_k_version_info[] = "version_info"; +static const char __pyx_k_class_getitem[] = "__class_getitem__"; +static const char __pyx_k_max_sentences[] = "max_sentences"; +static const char __pyx_k_num_tokens_fn[] = "num_tokens_fn"; +static const char __pyx_k_reduce_cython[] = "__reduce_cython__"; +static const char __pyx_k_AssertionError[] = "AssertionError"; +static const char __pyx_k_num_tokens_vec[] = "num_tokens_vec"; +static const char __pyx_k_View_MemoryView[] = "View.MemoryView"; +static const char __pyx_k_allocate_buffer[] = "allocate_buffer"; +static const char __pyx_k_collections_abc[] = "collections.abc"; +static const char __pyx_k_dtype_is_object[] = "dtype_is_object"; +static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError"; +static const char __pyx_k_setstate_cython[] = "__setstate_cython__"; +static const char __pyx_k_batch_by_size_fn[] = "batch_by_size_fn"; +static const char __pyx_k_batch_by_size_vec[] = "batch_by_size_vec"; +static const char __pyx_k_pyx_unpickle_Enum[] = "__pyx_unpickle_Enum"; +static const char __pyx_k_asyncio_coroutines[] = "asyncio.coroutines"; +static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback"; +static const char __pyx_k_strided_and_direct[] = ""; +static const char __pyx_k_fixed_shapes_sorted[] = "fixed_shapes_sorted"; +static const char __pyx_k_strided_and_indirect[] = ""; +static const char __pyx_k_Invalid_shape_in_axis[] = "Invalid shape in axis "; +static const char __pyx_k_contiguous_and_direct[] = ""; +static const char __pyx_k_Cannot_index_with_type[] = "Cannot index with type '"; +static const char __pyx_k_MemoryView_of_r_object[] = ""; +static const char __pyx_k_MemoryView_of_r_at_0x_x[] = ""; +static const char __pyx_k_batch_fixed_shapes_fast[] = "batch_fixed_shapes_fast"; +static const char __pyx_k_contiguous_and_indirect[] = ""; +static const char __pyx_k_Dimension_d_is_not_direct[] = "Dimension %d is not direct"; +static const char __pyx_k_Index_out_of_bounds_axis_d[] = "Index out of bounds (axis %d)"; +static const char __pyx_k_Step_may_not_be_zero_axis_d[] = "Step may not be zero (axis %d)"; +static const char __pyx_k_itemsize_0_for_cython_array[] = "itemsize <= 0 for cython.array"; +static const char __pyx_k_fairseq_data_data_utils_fast[] = "fairseq.data.data_utils_fast"; +static const char __pyx_k_unable_to_allocate_array_data[] = "unable to allocate array data."; +static const char __pyx_k_strided_and_direct_or_indirect[] = ""; +static const char __pyx_k_numpy_core_multiarray_failed_to[] = "numpy.core.multiarray failed to import"; +static const char __pyx_k_All_dimensions_preceding_dimensi[] = "All dimensions preceding dimension %d must be indexed and not sliced"; +static const char __pyx_k_Buffer_view_does_not_expose_stri[] = "Buffer view does not expose strides"; +static const char __pyx_k_Can_only_create_a_buffer_that_is[] = "Can only create a buffer that is contiguous in memory."; +static const char __pyx_k_Cannot_assign_to_read_only_memor[] = "Cannot assign to read-only memoryview"; +static const char __pyx_k_Cannot_create_writable_memory_vi[] = "Cannot create writable memory view from read-only memoryview"; +static const char __pyx_k_Cannot_transpose_memoryview_with[] = "Cannot transpose memoryview with indirect dimensions"; +static const char __pyx_k_Empty_shape_tuple_for_cython_arr[] = "Empty shape tuple for cython.array"; +static const char __pyx_k_Incompatible_checksums_0x_x_vs_0[] = "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))"; +static const char __pyx_k_Indirect_dimensions_not_supporte[] = "Indirect dimensions not supported"; +static const char __pyx_k_Invalid_mode_expected_c_or_fortr[] = "Invalid mode, expected 'c' or 'fortran', got "; +static const char __pyx_k_Out_of_bounds_on_buffer_access_a[] = "Out of bounds on buffer access (axis "; +static const char __pyx_k_Sentences_lengths_should_not_exc[] = "Sentences lengths should not exceed max_tokens="; +static const char __pyx_k_Unable_to_convert_item_to_object[] = "Unable to convert item to object"; +static const char __pyx_k_fairseq_data_data_utils_fast_pyx[] = "fairseq/data/data_utils_fast.pyx"; +static const char __pyx_k_got_differing_extents_in_dimensi[] = "got differing extents in dimension "; +static const char __pyx_k_no_default___reduce___due_to_non[] = "no default __reduce__ due to non-trivial __cinit__"; +static const char __pyx_k_numpy_core_umath_failed_to_impor[] = "numpy.core.umath failed to import"; +static const char __pyx_k_unable_to_allocate_shape_and_str[] = "unable to allocate shape and strides."; +/* #### Code section: decls ### */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /* proto */ +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name); /* proto */ +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object); /* proto */ +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_batch_by_size_vec(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyArrayObject *__pyx_v_num_tokens_vec, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_2batch_by_size_fn(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_4batch_fixed_shapes_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, PyArrayObject *__pyx_v_fixed_shapes_sorted); /* proto */ +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +/* #### Code section: late_includes ### */ +/* #### Code section: module_state ### */ +typedef struct { + PyObject *__pyx_d; + PyObject *__pyx_b; + PyObject *__pyx_cython_runtime; + PyObject *__pyx_empty_tuple; + PyObject *__pyx_empty_bytes; + PyObject *__pyx_empty_unicode; + #ifdef __Pyx_CyFunction_USED + PyTypeObject *__pyx_CyFunctionType; + #endif + #ifdef __Pyx_FusedFunction_USED + PyTypeObject *__pyx_FusedFunctionType; + #endif + #ifdef __Pyx_Generator_USED + PyTypeObject *__pyx_GeneratorType; + #endif + #ifdef __Pyx_IterableCoroutine_USED + PyTypeObject *__pyx_IterableCoroutineType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineAwaitType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineType; + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_7cpython_4type_type; + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_5numpy_dtype; + PyTypeObject *__pyx_ptype_5numpy_flatiter; + PyTypeObject *__pyx_ptype_5numpy_broadcast; + PyTypeObject *__pyx_ptype_5numpy_ndarray; + PyTypeObject *__pyx_ptype_5numpy_generic; + PyTypeObject *__pyx_ptype_5numpy_number; + PyTypeObject *__pyx_ptype_5numpy_integer; + PyTypeObject *__pyx_ptype_5numpy_signedinteger; + PyTypeObject *__pyx_ptype_5numpy_unsignedinteger; + PyTypeObject *__pyx_ptype_5numpy_inexact; + PyTypeObject *__pyx_ptype_5numpy_floating; + PyTypeObject *__pyx_ptype_5numpy_complexfloating; + PyTypeObject *__pyx_ptype_5numpy_flexible; + PyTypeObject *__pyx_ptype_5numpy_character; + PyTypeObject *__pyx_ptype_5numpy_ufunc; + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + PyObject *__pyx_type___pyx_array; + PyObject *__pyx_type___pyx_MemviewEnum; + PyObject *__pyx_type___pyx_memoryview; + PyObject *__pyx_type___pyx_memoryviewslice; + #endif + PyTypeObject *__pyx_array_type; + PyTypeObject *__pyx_MemviewEnum_type; + PyTypeObject *__pyx_memoryview_type; + PyTypeObject *__pyx_memoryviewslice_type; + PyObject *__pyx_kp_u_; + PyObject *__pyx_n_s_ASCII; + PyObject *__pyx_kp_s_All_dimensions_preceding_dimensi; + PyObject *__pyx_n_s_AssertionError; + PyObject *__pyx_kp_s_Buffer_view_does_not_expose_stri; + PyObject *__pyx_kp_s_Can_only_create_a_buffer_that_is; + PyObject *__pyx_kp_s_Cannot_assign_to_read_only_memor; + PyObject *__pyx_kp_s_Cannot_create_writable_memory_vi; + PyObject *__pyx_kp_u_Cannot_index_with_type; + PyObject *__pyx_kp_s_Cannot_transpose_memoryview_with; + PyObject *__pyx_kp_s_Dimension_d_is_not_direct; + PyObject *__pyx_n_s_Ellipsis; + PyObject *__pyx_kp_s_Empty_shape_tuple_for_cython_arr; + PyObject *__pyx_n_s_ImportError; + PyObject *__pyx_kp_s_Incompatible_checksums_0x_x_vs_0; + PyObject *__pyx_n_s_IndexError; + PyObject *__pyx_kp_s_Index_out_of_bounds_axis_d; + PyObject *__pyx_kp_s_Indirect_dimensions_not_supporte; + PyObject *__pyx_kp_u_Invalid_mode_expected_c_or_fortr; + PyObject *__pyx_kp_u_Invalid_shape_in_axis; + PyObject *__pyx_n_s_MemoryError; + PyObject *__pyx_kp_s_MemoryView_of_r_at_0x_x; + PyObject *__pyx_kp_s_MemoryView_of_r_object; + PyObject *__pyx_n_b_O; + PyObject *__pyx_kp_u_Out_of_bounds_on_buffer_access_a; + PyObject *__pyx_n_s_PickleError; + PyObject *__pyx_kp_u_Sentences_lengths_should_not_exc; + PyObject *__pyx_n_s_Sequence; + PyObject *__pyx_kp_s_Step_may_not_be_zero_axis_d; + PyObject *__pyx_n_s_TypeError; + PyObject *__pyx_kp_s_Unable_to_convert_item_to_object; + PyObject *__pyx_n_s_ValueError; + PyObject *__pyx_n_s_View_MemoryView; + PyObject *__pyx_kp_u__2; + PyObject *__pyx_n_s__28; + PyObject *__pyx_n_s__3; + PyObject *__pyx_kp_u__6; + PyObject *__pyx_kp_u__7; + PyObject *__pyx_n_s_abc; + PyObject *__pyx_n_s_allocate_buffer; + PyObject *__pyx_kp_u_and; + PyObject *__pyx_n_s_asyncio_coroutines; + PyObject *__pyx_n_s_base; + PyObject *__pyx_n_s_batch_by_size_fn; + PyObject *__pyx_n_s_batch_by_size_vec; + PyObject *__pyx_n_s_batch_fixed_shapes_fast; + PyObject *__pyx_n_s_bsz_mult; + PyObject *__pyx_n_s_c; + PyObject *__pyx_n_u_c; + PyObject *__pyx_n_s_class; + PyObject *__pyx_n_s_class_getitem; + PyObject *__pyx_n_s_cline_in_traceback; + PyObject *__pyx_n_s_collections; + PyObject *__pyx_kp_s_collections_abc; + PyObject *__pyx_kp_s_contiguous_and_direct; + PyObject *__pyx_kp_s_contiguous_and_indirect; + PyObject *__pyx_n_s_count; + PyObject *__pyx_n_s_dict; + PyObject *__pyx_kp_u_disable; + PyObject *__pyx_n_s_dtype; + PyObject *__pyx_n_s_dtype_is_object; + PyObject *__pyx_kp_u_enable; + PyObject *__pyx_n_s_encode; + PyObject *__pyx_n_s_enumerate; + PyObject *__pyx_n_s_error; + PyObject *__pyx_n_s_fairseq_data_data_utils_fast; + PyObject *__pyx_kp_s_fairseq_data_data_utils_fast_pyx; + PyObject *__pyx_n_s_fixed_shapes_sorted; + PyObject *__pyx_n_s_flags; + PyObject *__pyx_n_s_format; + PyObject *__pyx_n_s_fortran; + PyObject *__pyx_n_u_fortran; + PyObject *__pyx_kp_u_gc; + PyObject *__pyx_n_s_getstate; + PyObject *__pyx_kp_u_got; + PyObject *__pyx_kp_u_got_differing_extents_in_dimensi; + PyObject *__pyx_n_s_id; + PyObject *__pyx_n_s_import; + PyObject *__pyx_n_s_index; + PyObject *__pyx_n_s_indices; + PyObject *__pyx_n_s_initializing; + PyObject *__pyx_n_s_int32; + PyObject *__pyx_n_s_int64; + PyObject *__pyx_n_s_is_coroutine; + PyObject *__pyx_kp_u_isenabled; + PyObject *__pyx_n_s_itemsize; + PyObject *__pyx_kp_s_itemsize_0_for_cython_array; + PyObject *__pyx_n_s_main; + PyObject *__pyx_n_s_max; + PyObject *__pyx_n_s_max_sentences; + PyObject *__pyx_n_s_max_tokens; + PyObject *__pyx_n_s_memview; + PyObject *__pyx_n_s_mode; + PyObject *__pyx_n_s_name; + PyObject *__pyx_n_s_name_2; + PyObject *__pyx_n_s_ndim; + PyObject *__pyx_n_s_new; + PyObject *__pyx_kp_s_no_default___reduce___due_to_non; + PyObject *__pyx_n_s_np; + PyObject *__pyx_n_s_num_tokens_fn; + PyObject *__pyx_n_s_num_tokens_vec; + PyObject *__pyx_n_s_numpy; + PyObject *__pyx_kp_u_numpy_core_multiarray_failed_to; + PyObject *__pyx_kp_u_numpy_core_umath_failed_to_impor; + PyObject *__pyx_n_s_obj; + PyObject *__pyx_n_s_pack; + PyObject *__pyx_n_s_pickle; + PyObject *__pyx_n_s_pyx_PickleError; + PyObject *__pyx_n_s_pyx_checksum; + PyObject *__pyx_n_s_pyx_result; + PyObject *__pyx_n_s_pyx_state; + PyObject *__pyx_n_s_pyx_type; + PyObject *__pyx_n_s_pyx_unpickle_Enum; + PyObject *__pyx_n_s_pyx_vtable; + PyObject *__pyx_n_s_range; + PyObject *__pyx_n_s_reduce; + PyObject *__pyx_n_s_reduce_cython; + PyObject *__pyx_n_s_reduce_ex; + PyObject *__pyx_n_s_register; + PyObject *__pyx_n_s_setstate; + PyObject *__pyx_n_s_setstate_cython; + PyObject *__pyx_n_s_shape; + PyObject *__pyx_n_s_size; + PyObject *__pyx_n_s_spec; + PyObject *__pyx_n_s_split; + PyObject *__pyx_n_s_start; + PyObject *__pyx_n_s_step; + PyObject *__pyx_n_s_stop; + PyObject *__pyx_kp_s_strided_and_direct; + PyObject *__pyx_kp_s_strided_and_direct_or_indirect; + PyObject *__pyx_kp_s_strided_and_indirect; + PyObject *__pyx_kp_s_stringsource; + PyObject *__pyx_n_s_struct; + PyObject *__pyx_n_s_sys; + PyObject *__pyx_n_s_test; + PyObject *__pyx_kp_s_unable_to_allocate_array_data; + PyObject *__pyx_kp_s_unable_to_allocate_shape_and_str; + PyObject *__pyx_n_s_unpack; + PyObject *__pyx_n_s_update; + PyObject *__pyx_n_s_version_info; + PyObject *__pyx_n_s_zeros; + PyObject *__pyx_int_0; + PyObject *__pyx_int_1; + PyObject *__pyx_int_3; + PyObject *__pyx_int_112105877; + PyObject *__pyx_int_136983863; + PyObject *__pyx_int_184977713; + PyObject *__pyx_int_neg_1; + PyObject *__pyx_slice__5; + PyObject *__pyx_tuple__4; + PyObject *__pyx_tuple__8; + PyObject *__pyx_tuple__9; + PyObject *__pyx_tuple__10; + PyObject *__pyx_tuple__11; + PyObject *__pyx_tuple__12; + PyObject *__pyx_tuple__13; + PyObject *__pyx_tuple__14; + PyObject *__pyx_tuple__15; + PyObject *__pyx_tuple__16; + PyObject *__pyx_tuple__17; + PyObject *__pyx_tuple__18; + PyObject *__pyx_tuple__19; + PyObject *__pyx_tuple__20; + PyObject *__pyx_tuple__22; + PyObject *__pyx_tuple__24; + PyObject *__pyx_tuple__26; + PyObject *__pyx_codeobj__21; + PyObject *__pyx_codeobj__23; + PyObject *__pyx_codeobj__25; + PyObject *__pyx_codeobj__27; +} __pyx_mstate; + +#if CYTHON_USE_MODULE_STATE +#ifdef __cplusplus +namespace { + extern struct PyModuleDef __pyx_moduledef; +} /* anonymous namespace */ +#else +static struct PyModuleDef __pyx_moduledef; +#endif + +#define __pyx_mstate(o) ((__pyx_mstate *)__Pyx_PyModule_GetState(o)) + +#define __pyx_mstate_global (__pyx_mstate(PyState_FindModule(&__pyx_moduledef))) + +#define __pyx_m (PyState_FindModule(&__pyx_moduledef)) +#else +static __pyx_mstate __pyx_mstate_global_static = +#ifdef __cplusplus + {}; +#else + {0}; +#endif +static __pyx_mstate *__pyx_mstate_global = &__pyx_mstate_global_static; +#endif +/* #### Code section: module_state_clear ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_clear(PyObject *m) { + __pyx_mstate *clear_module_state = __pyx_mstate(m); + if (!clear_module_state) return 0; + Py_CLEAR(clear_module_state->__pyx_d); + Py_CLEAR(clear_module_state->__pyx_b); + Py_CLEAR(clear_module_state->__pyx_cython_runtime); + Py_CLEAR(clear_module_state->__pyx_empty_tuple); + Py_CLEAR(clear_module_state->__pyx_empty_bytes); + Py_CLEAR(clear_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_CLEAR(clear_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_CLEAR(clear_module_state->__pyx_FusedFunctionType); + #endif + Py_CLEAR(clear_module_state->__pyx_ptype_7cpython_4type_type); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_dtype); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flatiter); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_broadcast); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ndarray); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_generic); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_number); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_integer); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_signedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_inexact); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_floating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_complexfloating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flexible); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_character); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ufunc); + Py_CLEAR(clear_module_state->__pyx_array_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_array); + Py_CLEAR(clear_module_state->__pyx_MemviewEnum_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_MemviewEnum); + Py_CLEAR(clear_module_state->__pyx_memoryview_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryview); + Py_CLEAR(clear_module_state->__pyx_memoryviewslice_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryviewslice); + Py_CLEAR(clear_module_state->__pyx_kp_u_); + Py_CLEAR(clear_module_state->__pyx_n_s_ASCII); + Py_CLEAR(clear_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_AssertionError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_CLEAR(clear_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_CLEAR(clear_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_CLEAR(clear_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_CLEAR(clear_module_state->__pyx_n_s_Ellipsis); + Py_CLEAR(clear_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_CLEAR(clear_module_state->__pyx_n_s_ImportError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_CLEAR(clear_module_state->__pyx_n_s_IndexError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_CLEAR(clear_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_CLEAR(clear_module_state->__pyx_n_s_MemoryError); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_CLEAR(clear_module_state->__pyx_n_b_O); + Py_CLEAR(clear_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_CLEAR(clear_module_state->__pyx_n_s_PickleError); + Py_CLEAR(clear_module_state->__pyx_kp_u_Sentences_lengths_should_not_exc); + Py_CLEAR(clear_module_state->__pyx_n_s_Sequence); + Py_CLEAR(clear_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_CLEAR(clear_module_state->__pyx_n_s_TypeError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_CLEAR(clear_module_state->__pyx_n_s_ValueError); + Py_CLEAR(clear_module_state->__pyx_n_s_View_MemoryView); + Py_CLEAR(clear_module_state->__pyx_kp_u__2); + Py_CLEAR(clear_module_state->__pyx_n_s__28); + Py_CLEAR(clear_module_state->__pyx_n_s__3); + Py_CLEAR(clear_module_state->__pyx_kp_u__6); + Py_CLEAR(clear_module_state->__pyx_kp_u__7); + Py_CLEAR(clear_module_state->__pyx_n_s_abc); + Py_CLEAR(clear_module_state->__pyx_n_s_allocate_buffer); + Py_CLEAR(clear_module_state->__pyx_kp_u_and); + Py_CLEAR(clear_module_state->__pyx_n_s_asyncio_coroutines); + Py_CLEAR(clear_module_state->__pyx_n_s_base); + Py_CLEAR(clear_module_state->__pyx_n_s_batch_by_size_fn); + Py_CLEAR(clear_module_state->__pyx_n_s_batch_by_size_vec); + Py_CLEAR(clear_module_state->__pyx_n_s_batch_fixed_shapes_fast); + Py_CLEAR(clear_module_state->__pyx_n_s_bsz_mult); + Py_CLEAR(clear_module_state->__pyx_n_s_c); + Py_CLEAR(clear_module_state->__pyx_n_u_c); + Py_CLEAR(clear_module_state->__pyx_n_s_class); + Py_CLEAR(clear_module_state->__pyx_n_s_class_getitem); + Py_CLEAR(clear_module_state->__pyx_n_s_cline_in_traceback); + Py_CLEAR(clear_module_state->__pyx_n_s_collections); + Py_CLEAR(clear_module_state->__pyx_kp_s_collections_abc); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_CLEAR(clear_module_state->__pyx_n_s_count); + Py_CLEAR(clear_module_state->__pyx_n_s_dict); + Py_CLEAR(clear_module_state->__pyx_kp_u_disable); + Py_CLEAR(clear_module_state->__pyx_n_s_dtype); + Py_CLEAR(clear_module_state->__pyx_n_s_dtype_is_object); + Py_CLEAR(clear_module_state->__pyx_kp_u_enable); + Py_CLEAR(clear_module_state->__pyx_n_s_encode); + Py_CLEAR(clear_module_state->__pyx_n_s_enumerate); + Py_CLEAR(clear_module_state->__pyx_n_s_error); + Py_CLEAR(clear_module_state->__pyx_n_s_fairseq_data_data_utils_fast); + Py_CLEAR(clear_module_state->__pyx_kp_s_fairseq_data_data_utils_fast_pyx); + Py_CLEAR(clear_module_state->__pyx_n_s_fixed_shapes_sorted); + Py_CLEAR(clear_module_state->__pyx_n_s_flags); + Py_CLEAR(clear_module_state->__pyx_n_s_format); + Py_CLEAR(clear_module_state->__pyx_n_s_fortran); + Py_CLEAR(clear_module_state->__pyx_n_u_fortran); + Py_CLEAR(clear_module_state->__pyx_kp_u_gc); + Py_CLEAR(clear_module_state->__pyx_n_s_getstate); + Py_CLEAR(clear_module_state->__pyx_kp_u_got); + Py_CLEAR(clear_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_id); + Py_CLEAR(clear_module_state->__pyx_n_s_import); + Py_CLEAR(clear_module_state->__pyx_n_s_index); + Py_CLEAR(clear_module_state->__pyx_n_s_indices); + Py_CLEAR(clear_module_state->__pyx_n_s_initializing); + Py_CLEAR(clear_module_state->__pyx_n_s_int32); + Py_CLEAR(clear_module_state->__pyx_n_s_int64); + Py_CLEAR(clear_module_state->__pyx_n_s_is_coroutine); + Py_CLEAR(clear_module_state->__pyx_kp_u_isenabled); + Py_CLEAR(clear_module_state->__pyx_n_s_itemsize); + Py_CLEAR(clear_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_CLEAR(clear_module_state->__pyx_n_s_main); + Py_CLEAR(clear_module_state->__pyx_n_s_max); + Py_CLEAR(clear_module_state->__pyx_n_s_max_sentences); + Py_CLEAR(clear_module_state->__pyx_n_s_max_tokens); + Py_CLEAR(clear_module_state->__pyx_n_s_memview); + Py_CLEAR(clear_module_state->__pyx_n_s_mode); + Py_CLEAR(clear_module_state->__pyx_n_s_name); + Py_CLEAR(clear_module_state->__pyx_n_s_name_2); + Py_CLEAR(clear_module_state->__pyx_n_s_ndim); + Py_CLEAR(clear_module_state->__pyx_n_s_new); + Py_CLEAR(clear_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_CLEAR(clear_module_state->__pyx_n_s_np); + Py_CLEAR(clear_module_state->__pyx_n_s_num_tokens_fn); + Py_CLEAR(clear_module_state->__pyx_n_s_num_tokens_vec); + Py_CLEAR(clear_module_state->__pyx_n_s_numpy); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy_core_multiarray_failed_to); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy_core_umath_failed_to_impor); + Py_CLEAR(clear_module_state->__pyx_n_s_obj); + Py_CLEAR(clear_module_state->__pyx_n_s_pack); + Py_CLEAR(clear_module_state->__pyx_n_s_pickle); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_PickleError); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_checksum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_result); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_state); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_type); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_vtable); + Py_CLEAR(clear_module_state->__pyx_n_s_range); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_ex); + Py_CLEAR(clear_module_state->__pyx_n_s_register); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_shape); + Py_CLEAR(clear_module_state->__pyx_n_s_size); + Py_CLEAR(clear_module_state->__pyx_n_s_spec); + Py_CLEAR(clear_module_state->__pyx_n_s_split); + Py_CLEAR(clear_module_state->__pyx_n_s_start); + Py_CLEAR(clear_module_state->__pyx_n_s_step); + Py_CLEAR(clear_module_state->__pyx_n_s_stop); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_stringsource); + Py_CLEAR(clear_module_state->__pyx_n_s_struct); + Py_CLEAR(clear_module_state->__pyx_n_s_sys); + Py_CLEAR(clear_module_state->__pyx_n_s_test); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_CLEAR(clear_module_state->__pyx_n_s_unpack); + Py_CLEAR(clear_module_state->__pyx_n_s_update); + Py_CLEAR(clear_module_state->__pyx_n_s_version_info); + Py_CLEAR(clear_module_state->__pyx_n_s_zeros); + Py_CLEAR(clear_module_state->__pyx_int_0); + Py_CLEAR(clear_module_state->__pyx_int_1); + Py_CLEAR(clear_module_state->__pyx_int_3); + Py_CLEAR(clear_module_state->__pyx_int_112105877); + Py_CLEAR(clear_module_state->__pyx_int_136983863); + Py_CLEAR(clear_module_state->__pyx_int_184977713); + Py_CLEAR(clear_module_state->__pyx_int_neg_1); + Py_CLEAR(clear_module_state->__pyx_slice__5); + Py_CLEAR(clear_module_state->__pyx_tuple__4); + Py_CLEAR(clear_module_state->__pyx_tuple__8); + Py_CLEAR(clear_module_state->__pyx_tuple__9); + Py_CLEAR(clear_module_state->__pyx_tuple__10); + Py_CLEAR(clear_module_state->__pyx_tuple__11); + Py_CLEAR(clear_module_state->__pyx_tuple__12); + Py_CLEAR(clear_module_state->__pyx_tuple__13); + Py_CLEAR(clear_module_state->__pyx_tuple__14); + Py_CLEAR(clear_module_state->__pyx_tuple__15); + Py_CLEAR(clear_module_state->__pyx_tuple__16); + Py_CLEAR(clear_module_state->__pyx_tuple__17); + Py_CLEAR(clear_module_state->__pyx_tuple__18); + Py_CLEAR(clear_module_state->__pyx_tuple__19); + Py_CLEAR(clear_module_state->__pyx_tuple__20); + Py_CLEAR(clear_module_state->__pyx_tuple__22); + Py_CLEAR(clear_module_state->__pyx_tuple__24); + Py_CLEAR(clear_module_state->__pyx_tuple__26); + Py_CLEAR(clear_module_state->__pyx_codeobj__21); + Py_CLEAR(clear_module_state->__pyx_codeobj__23); + Py_CLEAR(clear_module_state->__pyx_codeobj__25); + Py_CLEAR(clear_module_state->__pyx_codeobj__27); + return 0; +} +#endif +/* #### Code section: module_state_traverse ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_traverse(PyObject *m, visitproc visit, void *arg) { + __pyx_mstate *traverse_module_state = __pyx_mstate(m); + if (!traverse_module_state) return 0; + Py_VISIT(traverse_module_state->__pyx_d); + Py_VISIT(traverse_module_state->__pyx_b); + Py_VISIT(traverse_module_state->__pyx_cython_runtime); + Py_VISIT(traverse_module_state->__pyx_empty_tuple); + Py_VISIT(traverse_module_state->__pyx_empty_bytes); + Py_VISIT(traverse_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_VISIT(traverse_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_VISIT(traverse_module_state->__pyx_FusedFunctionType); + #endif + Py_VISIT(traverse_module_state->__pyx_ptype_7cpython_4type_type); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_dtype); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flatiter); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_broadcast); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ndarray); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_generic); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_number); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_integer); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_signedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_inexact); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_floating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_complexfloating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flexible); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_character); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ufunc); + Py_VISIT(traverse_module_state->__pyx_array_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_array); + Py_VISIT(traverse_module_state->__pyx_MemviewEnum_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_MemviewEnum); + Py_VISIT(traverse_module_state->__pyx_memoryview_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryview); + Py_VISIT(traverse_module_state->__pyx_memoryviewslice_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryviewslice); + Py_VISIT(traverse_module_state->__pyx_kp_u_); + Py_VISIT(traverse_module_state->__pyx_n_s_ASCII); + Py_VISIT(traverse_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_AssertionError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_VISIT(traverse_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_VISIT(traverse_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_VISIT(traverse_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_VISIT(traverse_module_state->__pyx_n_s_Ellipsis); + Py_VISIT(traverse_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_VISIT(traverse_module_state->__pyx_n_s_ImportError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_VISIT(traverse_module_state->__pyx_n_s_IndexError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_VISIT(traverse_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_VISIT(traverse_module_state->__pyx_n_s_MemoryError); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_VISIT(traverse_module_state->__pyx_n_b_O); + Py_VISIT(traverse_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_VISIT(traverse_module_state->__pyx_n_s_PickleError); + Py_VISIT(traverse_module_state->__pyx_kp_u_Sentences_lengths_should_not_exc); + Py_VISIT(traverse_module_state->__pyx_n_s_Sequence); + Py_VISIT(traverse_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_VISIT(traverse_module_state->__pyx_n_s_TypeError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_VISIT(traverse_module_state->__pyx_n_s_ValueError); + Py_VISIT(traverse_module_state->__pyx_n_s_View_MemoryView); + Py_VISIT(traverse_module_state->__pyx_kp_u__2); + Py_VISIT(traverse_module_state->__pyx_n_s__28); + Py_VISIT(traverse_module_state->__pyx_n_s__3); + Py_VISIT(traverse_module_state->__pyx_kp_u__6); + Py_VISIT(traverse_module_state->__pyx_kp_u__7); + Py_VISIT(traverse_module_state->__pyx_n_s_abc); + Py_VISIT(traverse_module_state->__pyx_n_s_allocate_buffer); + Py_VISIT(traverse_module_state->__pyx_kp_u_and); + Py_VISIT(traverse_module_state->__pyx_n_s_asyncio_coroutines); + Py_VISIT(traverse_module_state->__pyx_n_s_base); + Py_VISIT(traverse_module_state->__pyx_n_s_batch_by_size_fn); + Py_VISIT(traverse_module_state->__pyx_n_s_batch_by_size_vec); + Py_VISIT(traverse_module_state->__pyx_n_s_batch_fixed_shapes_fast); + Py_VISIT(traverse_module_state->__pyx_n_s_bsz_mult); + Py_VISIT(traverse_module_state->__pyx_n_s_c); + Py_VISIT(traverse_module_state->__pyx_n_u_c); + Py_VISIT(traverse_module_state->__pyx_n_s_class); + Py_VISIT(traverse_module_state->__pyx_n_s_class_getitem); + Py_VISIT(traverse_module_state->__pyx_n_s_cline_in_traceback); + Py_VISIT(traverse_module_state->__pyx_n_s_collections); + Py_VISIT(traverse_module_state->__pyx_kp_s_collections_abc); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_VISIT(traverse_module_state->__pyx_n_s_count); + Py_VISIT(traverse_module_state->__pyx_n_s_dict); + Py_VISIT(traverse_module_state->__pyx_kp_u_disable); + Py_VISIT(traverse_module_state->__pyx_n_s_dtype); + Py_VISIT(traverse_module_state->__pyx_n_s_dtype_is_object); + Py_VISIT(traverse_module_state->__pyx_kp_u_enable); + Py_VISIT(traverse_module_state->__pyx_n_s_encode); + Py_VISIT(traverse_module_state->__pyx_n_s_enumerate); + Py_VISIT(traverse_module_state->__pyx_n_s_error); + Py_VISIT(traverse_module_state->__pyx_n_s_fairseq_data_data_utils_fast); + Py_VISIT(traverse_module_state->__pyx_kp_s_fairseq_data_data_utils_fast_pyx); + Py_VISIT(traverse_module_state->__pyx_n_s_fixed_shapes_sorted); + Py_VISIT(traverse_module_state->__pyx_n_s_flags); + Py_VISIT(traverse_module_state->__pyx_n_s_format); + Py_VISIT(traverse_module_state->__pyx_n_s_fortran); + Py_VISIT(traverse_module_state->__pyx_n_u_fortran); + Py_VISIT(traverse_module_state->__pyx_kp_u_gc); + Py_VISIT(traverse_module_state->__pyx_n_s_getstate); + Py_VISIT(traverse_module_state->__pyx_kp_u_got); + Py_VISIT(traverse_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_id); + Py_VISIT(traverse_module_state->__pyx_n_s_import); + Py_VISIT(traverse_module_state->__pyx_n_s_index); + Py_VISIT(traverse_module_state->__pyx_n_s_indices); + Py_VISIT(traverse_module_state->__pyx_n_s_initializing); + Py_VISIT(traverse_module_state->__pyx_n_s_int32); + Py_VISIT(traverse_module_state->__pyx_n_s_int64); + Py_VISIT(traverse_module_state->__pyx_n_s_is_coroutine); + Py_VISIT(traverse_module_state->__pyx_kp_u_isenabled); + Py_VISIT(traverse_module_state->__pyx_n_s_itemsize); + Py_VISIT(traverse_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_VISIT(traverse_module_state->__pyx_n_s_main); + Py_VISIT(traverse_module_state->__pyx_n_s_max); + Py_VISIT(traverse_module_state->__pyx_n_s_max_sentences); + Py_VISIT(traverse_module_state->__pyx_n_s_max_tokens); + Py_VISIT(traverse_module_state->__pyx_n_s_memview); + Py_VISIT(traverse_module_state->__pyx_n_s_mode); + Py_VISIT(traverse_module_state->__pyx_n_s_name); + Py_VISIT(traverse_module_state->__pyx_n_s_name_2); + Py_VISIT(traverse_module_state->__pyx_n_s_ndim); + Py_VISIT(traverse_module_state->__pyx_n_s_new); + Py_VISIT(traverse_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_VISIT(traverse_module_state->__pyx_n_s_np); + Py_VISIT(traverse_module_state->__pyx_n_s_num_tokens_fn); + Py_VISIT(traverse_module_state->__pyx_n_s_num_tokens_vec); + Py_VISIT(traverse_module_state->__pyx_n_s_numpy); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy_core_multiarray_failed_to); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy_core_umath_failed_to_impor); + Py_VISIT(traverse_module_state->__pyx_n_s_obj); + Py_VISIT(traverse_module_state->__pyx_n_s_pack); + Py_VISIT(traverse_module_state->__pyx_n_s_pickle); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_PickleError); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_checksum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_result); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_state); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_type); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_vtable); + Py_VISIT(traverse_module_state->__pyx_n_s_range); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_ex); + Py_VISIT(traverse_module_state->__pyx_n_s_register); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_shape); + Py_VISIT(traverse_module_state->__pyx_n_s_size); + Py_VISIT(traverse_module_state->__pyx_n_s_spec); + Py_VISIT(traverse_module_state->__pyx_n_s_split); + Py_VISIT(traverse_module_state->__pyx_n_s_start); + Py_VISIT(traverse_module_state->__pyx_n_s_step); + Py_VISIT(traverse_module_state->__pyx_n_s_stop); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_stringsource); + Py_VISIT(traverse_module_state->__pyx_n_s_struct); + Py_VISIT(traverse_module_state->__pyx_n_s_sys); + Py_VISIT(traverse_module_state->__pyx_n_s_test); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_VISIT(traverse_module_state->__pyx_n_s_unpack); + Py_VISIT(traverse_module_state->__pyx_n_s_update); + Py_VISIT(traverse_module_state->__pyx_n_s_version_info); + Py_VISIT(traverse_module_state->__pyx_n_s_zeros); + Py_VISIT(traverse_module_state->__pyx_int_0); + Py_VISIT(traverse_module_state->__pyx_int_1); + Py_VISIT(traverse_module_state->__pyx_int_3); + Py_VISIT(traverse_module_state->__pyx_int_112105877); + Py_VISIT(traverse_module_state->__pyx_int_136983863); + Py_VISIT(traverse_module_state->__pyx_int_184977713); + Py_VISIT(traverse_module_state->__pyx_int_neg_1); + Py_VISIT(traverse_module_state->__pyx_slice__5); + Py_VISIT(traverse_module_state->__pyx_tuple__4); + Py_VISIT(traverse_module_state->__pyx_tuple__8); + Py_VISIT(traverse_module_state->__pyx_tuple__9); + Py_VISIT(traverse_module_state->__pyx_tuple__10); + Py_VISIT(traverse_module_state->__pyx_tuple__11); + Py_VISIT(traverse_module_state->__pyx_tuple__12); + Py_VISIT(traverse_module_state->__pyx_tuple__13); + Py_VISIT(traverse_module_state->__pyx_tuple__14); + Py_VISIT(traverse_module_state->__pyx_tuple__15); + Py_VISIT(traverse_module_state->__pyx_tuple__16); + Py_VISIT(traverse_module_state->__pyx_tuple__17); + Py_VISIT(traverse_module_state->__pyx_tuple__18); + Py_VISIT(traverse_module_state->__pyx_tuple__19); + Py_VISIT(traverse_module_state->__pyx_tuple__20); + Py_VISIT(traverse_module_state->__pyx_tuple__22); + Py_VISIT(traverse_module_state->__pyx_tuple__24); + Py_VISIT(traverse_module_state->__pyx_tuple__26); + Py_VISIT(traverse_module_state->__pyx_codeobj__21); + Py_VISIT(traverse_module_state->__pyx_codeobj__23); + Py_VISIT(traverse_module_state->__pyx_codeobj__25); + Py_VISIT(traverse_module_state->__pyx_codeobj__27); + return 0; +} +#endif +/* #### Code section: module_state_defines ### */ +#define __pyx_d __pyx_mstate_global->__pyx_d +#define __pyx_b __pyx_mstate_global->__pyx_b +#define __pyx_cython_runtime __pyx_mstate_global->__pyx_cython_runtime +#define __pyx_empty_tuple __pyx_mstate_global->__pyx_empty_tuple +#define __pyx_empty_bytes __pyx_mstate_global->__pyx_empty_bytes +#define __pyx_empty_unicode __pyx_mstate_global->__pyx_empty_unicode +#ifdef __Pyx_CyFunction_USED +#define __pyx_CyFunctionType __pyx_mstate_global->__pyx_CyFunctionType +#endif +#ifdef __Pyx_FusedFunction_USED +#define __pyx_FusedFunctionType __pyx_mstate_global->__pyx_FusedFunctionType +#endif +#ifdef __Pyx_Generator_USED +#define __pyx_GeneratorType __pyx_mstate_global->__pyx_GeneratorType +#endif +#ifdef __Pyx_IterableCoroutine_USED +#define __pyx_IterableCoroutineType __pyx_mstate_global->__pyx_IterableCoroutineType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineAwaitType __pyx_mstate_global->__pyx_CoroutineAwaitType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineType __pyx_mstate_global->__pyx_CoroutineType +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_7cpython_4type_type __pyx_mstate_global->__pyx_ptype_7cpython_4type_type +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_5numpy_dtype __pyx_mstate_global->__pyx_ptype_5numpy_dtype +#define __pyx_ptype_5numpy_flatiter __pyx_mstate_global->__pyx_ptype_5numpy_flatiter +#define __pyx_ptype_5numpy_broadcast __pyx_mstate_global->__pyx_ptype_5numpy_broadcast +#define __pyx_ptype_5numpy_ndarray __pyx_mstate_global->__pyx_ptype_5numpy_ndarray +#define __pyx_ptype_5numpy_generic __pyx_mstate_global->__pyx_ptype_5numpy_generic +#define __pyx_ptype_5numpy_number __pyx_mstate_global->__pyx_ptype_5numpy_number +#define __pyx_ptype_5numpy_integer __pyx_mstate_global->__pyx_ptype_5numpy_integer +#define __pyx_ptype_5numpy_signedinteger __pyx_mstate_global->__pyx_ptype_5numpy_signedinteger +#define __pyx_ptype_5numpy_unsignedinteger __pyx_mstate_global->__pyx_ptype_5numpy_unsignedinteger +#define __pyx_ptype_5numpy_inexact __pyx_mstate_global->__pyx_ptype_5numpy_inexact +#define __pyx_ptype_5numpy_floating __pyx_mstate_global->__pyx_ptype_5numpy_floating +#define __pyx_ptype_5numpy_complexfloating __pyx_mstate_global->__pyx_ptype_5numpy_complexfloating +#define __pyx_ptype_5numpy_flexible __pyx_mstate_global->__pyx_ptype_5numpy_flexible +#define __pyx_ptype_5numpy_character __pyx_mstate_global->__pyx_ptype_5numpy_character +#define __pyx_ptype_5numpy_ufunc __pyx_mstate_global->__pyx_ptype_5numpy_ufunc +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#define __pyx_type___pyx_array __pyx_mstate_global->__pyx_type___pyx_array +#define __pyx_type___pyx_MemviewEnum __pyx_mstate_global->__pyx_type___pyx_MemviewEnum +#define __pyx_type___pyx_memoryview __pyx_mstate_global->__pyx_type___pyx_memoryview +#define __pyx_type___pyx_memoryviewslice __pyx_mstate_global->__pyx_type___pyx_memoryviewslice +#endif +#define __pyx_array_type __pyx_mstate_global->__pyx_array_type +#define __pyx_MemviewEnum_type __pyx_mstate_global->__pyx_MemviewEnum_type +#define __pyx_memoryview_type __pyx_mstate_global->__pyx_memoryview_type +#define __pyx_memoryviewslice_type __pyx_mstate_global->__pyx_memoryviewslice_type +#define __pyx_kp_u_ __pyx_mstate_global->__pyx_kp_u_ +#define __pyx_n_s_ASCII __pyx_mstate_global->__pyx_n_s_ASCII +#define __pyx_kp_s_All_dimensions_preceding_dimensi __pyx_mstate_global->__pyx_kp_s_All_dimensions_preceding_dimensi +#define __pyx_n_s_AssertionError __pyx_mstate_global->__pyx_n_s_AssertionError +#define __pyx_kp_s_Buffer_view_does_not_expose_stri __pyx_mstate_global->__pyx_kp_s_Buffer_view_does_not_expose_stri +#define __pyx_kp_s_Can_only_create_a_buffer_that_is __pyx_mstate_global->__pyx_kp_s_Can_only_create_a_buffer_that_is +#define __pyx_kp_s_Cannot_assign_to_read_only_memor __pyx_mstate_global->__pyx_kp_s_Cannot_assign_to_read_only_memor +#define __pyx_kp_s_Cannot_create_writable_memory_vi __pyx_mstate_global->__pyx_kp_s_Cannot_create_writable_memory_vi +#define __pyx_kp_u_Cannot_index_with_type __pyx_mstate_global->__pyx_kp_u_Cannot_index_with_type +#define __pyx_kp_s_Cannot_transpose_memoryview_with __pyx_mstate_global->__pyx_kp_s_Cannot_transpose_memoryview_with +#define __pyx_kp_s_Dimension_d_is_not_direct __pyx_mstate_global->__pyx_kp_s_Dimension_d_is_not_direct +#define __pyx_n_s_Ellipsis __pyx_mstate_global->__pyx_n_s_Ellipsis +#define __pyx_kp_s_Empty_shape_tuple_for_cython_arr __pyx_mstate_global->__pyx_kp_s_Empty_shape_tuple_for_cython_arr +#define __pyx_n_s_ImportError __pyx_mstate_global->__pyx_n_s_ImportError +#define __pyx_kp_s_Incompatible_checksums_0x_x_vs_0 __pyx_mstate_global->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0 +#define __pyx_n_s_IndexError __pyx_mstate_global->__pyx_n_s_IndexError +#define __pyx_kp_s_Index_out_of_bounds_axis_d __pyx_mstate_global->__pyx_kp_s_Index_out_of_bounds_axis_d +#define __pyx_kp_s_Indirect_dimensions_not_supporte __pyx_mstate_global->__pyx_kp_s_Indirect_dimensions_not_supporte +#define __pyx_kp_u_Invalid_mode_expected_c_or_fortr __pyx_mstate_global->__pyx_kp_u_Invalid_mode_expected_c_or_fortr +#define __pyx_kp_u_Invalid_shape_in_axis __pyx_mstate_global->__pyx_kp_u_Invalid_shape_in_axis +#define __pyx_n_s_MemoryError __pyx_mstate_global->__pyx_n_s_MemoryError +#define __pyx_kp_s_MemoryView_of_r_at_0x_x __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_at_0x_x +#define __pyx_kp_s_MemoryView_of_r_object __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_object +#define __pyx_n_b_O __pyx_mstate_global->__pyx_n_b_O +#define __pyx_kp_u_Out_of_bounds_on_buffer_access_a __pyx_mstate_global->__pyx_kp_u_Out_of_bounds_on_buffer_access_a +#define __pyx_n_s_PickleError __pyx_mstate_global->__pyx_n_s_PickleError +#define __pyx_kp_u_Sentences_lengths_should_not_exc __pyx_mstate_global->__pyx_kp_u_Sentences_lengths_should_not_exc +#define __pyx_n_s_Sequence __pyx_mstate_global->__pyx_n_s_Sequence +#define __pyx_kp_s_Step_may_not_be_zero_axis_d __pyx_mstate_global->__pyx_kp_s_Step_may_not_be_zero_axis_d +#define __pyx_n_s_TypeError __pyx_mstate_global->__pyx_n_s_TypeError +#define __pyx_kp_s_Unable_to_convert_item_to_object __pyx_mstate_global->__pyx_kp_s_Unable_to_convert_item_to_object +#define __pyx_n_s_ValueError __pyx_mstate_global->__pyx_n_s_ValueError +#define __pyx_n_s_View_MemoryView __pyx_mstate_global->__pyx_n_s_View_MemoryView +#define __pyx_kp_u__2 __pyx_mstate_global->__pyx_kp_u__2 +#define __pyx_n_s__28 __pyx_mstate_global->__pyx_n_s__28 +#define __pyx_n_s__3 __pyx_mstate_global->__pyx_n_s__3 +#define __pyx_kp_u__6 __pyx_mstate_global->__pyx_kp_u__6 +#define __pyx_kp_u__7 __pyx_mstate_global->__pyx_kp_u__7 +#define __pyx_n_s_abc __pyx_mstate_global->__pyx_n_s_abc +#define __pyx_n_s_allocate_buffer __pyx_mstate_global->__pyx_n_s_allocate_buffer +#define __pyx_kp_u_and __pyx_mstate_global->__pyx_kp_u_and +#define __pyx_n_s_asyncio_coroutines __pyx_mstate_global->__pyx_n_s_asyncio_coroutines +#define __pyx_n_s_base __pyx_mstate_global->__pyx_n_s_base +#define __pyx_n_s_batch_by_size_fn __pyx_mstate_global->__pyx_n_s_batch_by_size_fn +#define __pyx_n_s_batch_by_size_vec __pyx_mstate_global->__pyx_n_s_batch_by_size_vec +#define __pyx_n_s_batch_fixed_shapes_fast __pyx_mstate_global->__pyx_n_s_batch_fixed_shapes_fast +#define __pyx_n_s_bsz_mult __pyx_mstate_global->__pyx_n_s_bsz_mult +#define __pyx_n_s_c __pyx_mstate_global->__pyx_n_s_c +#define __pyx_n_u_c __pyx_mstate_global->__pyx_n_u_c +#define __pyx_n_s_class __pyx_mstate_global->__pyx_n_s_class +#define __pyx_n_s_class_getitem __pyx_mstate_global->__pyx_n_s_class_getitem +#define __pyx_n_s_cline_in_traceback __pyx_mstate_global->__pyx_n_s_cline_in_traceback +#define __pyx_n_s_collections __pyx_mstate_global->__pyx_n_s_collections +#define __pyx_kp_s_collections_abc __pyx_mstate_global->__pyx_kp_s_collections_abc +#define __pyx_kp_s_contiguous_and_direct __pyx_mstate_global->__pyx_kp_s_contiguous_and_direct +#define __pyx_kp_s_contiguous_and_indirect __pyx_mstate_global->__pyx_kp_s_contiguous_and_indirect +#define __pyx_n_s_count __pyx_mstate_global->__pyx_n_s_count +#define __pyx_n_s_dict __pyx_mstate_global->__pyx_n_s_dict +#define __pyx_kp_u_disable __pyx_mstate_global->__pyx_kp_u_disable +#define __pyx_n_s_dtype __pyx_mstate_global->__pyx_n_s_dtype +#define __pyx_n_s_dtype_is_object __pyx_mstate_global->__pyx_n_s_dtype_is_object +#define __pyx_kp_u_enable __pyx_mstate_global->__pyx_kp_u_enable +#define __pyx_n_s_encode __pyx_mstate_global->__pyx_n_s_encode +#define __pyx_n_s_enumerate __pyx_mstate_global->__pyx_n_s_enumerate +#define __pyx_n_s_error __pyx_mstate_global->__pyx_n_s_error +#define __pyx_n_s_fairseq_data_data_utils_fast __pyx_mstate_global->__pyx_n_s_fairseq_data_data_utils_fast +#define __pyx_kp_s_fairseq_data_data_utils_fast_pyx __pyx_mstate_global->__pyx_kp_s_fairseq_data_data_utils_fast_pyx +#define __pyx_n_s_fixed_shapes_sorted __pyx_mstate_global->__pyx_n_s_fixed_shapes_sorted +#define __pyx_n_s_flags __pyx_mstate_global->__pyx_n_s_flags +#define __pyx_n_s_format __pyx_mstate_global->__pyx_n_s_format +#define __pyx_n_s_fortran __pyx_mstate_global->__pyx_n_s_fortran +#define __pyx_n_u_fortran __pyx_mstate_global->__pyx_n_u_fortran +#define __pyx_kp_u_gc __pyx_mstate_global->__pyx_kp_u_gc +#define __pyx_n_s_getstate __pyx_mstate_global->__pyx_n_s_getstate +#define __pyx_kp_u_got __pyx_mstate_global->__pyx_kp_u_got +#define __pyx_kp_u_got_differing_extents_in_dimensi __pyx_mstate_global->__pyx_kp_u_got_differing_extents_in_dimensi +#define __pyx_n_s_id __pyx_mstate_global->__pyx_n_s_id +#define __pyx_n_s_import __pyx_mstate_global->__pyx_n_s_import +#define __pyx_n_s_index __pyx_mstate_global->__pyx_n_s_index +#define __pyx_n_s_indices __pyx_mstate_global->__pyx_n_s_indices +#define __pyx_n_s_initializing __pyx_mstate_global->__pyx_n_s_initializing +#define __pyx_n_s_int32 __pyx_mstate_global->__pyx_n_s_int32 +#define __pyx_n_s_int64 __pyx_mstate_global->__pyx_n_s_int64 +#define __pyx_n_s_is_coroutine __pyx_mstate_global->__pyx_n_s_is_coroutine +#define __pyx_kp_u_isenabled __pyx_mstate_global->__pyx_kp_u_isenabled +#define __pyx_n_s_itemsize __pyx_mstate_global->__pyx_n_s_itemsize +#define __pyx_kp_s_itemsize_0_for_cython_array __pyx_mstate_global->__pyx_kp_s_itemsize_0_for_cython_array +#define __pyx_n_s_main __pyx_mstate_global->__pyx_n_s_main +#define __pyx_n_s_max __pyx_mstate_global->__pyx_n_s_max +#define __pyx_n_s_max_sentences __pyx_mstate_global->__pyx_n_s_max_sentences +#define __pyx_n_s_max_tokens __pyx_mstate_global->__pyx_n_s_max_tokens +#define __pyx_n_s_memview __pyx_mstate_global->__pyx_n_s_memview +#define __pyx_n_s_mode __pyx_mstate_global->__pyx_n_s_mode +#define __pyx_n_s_name __pyx_mstate_global->__pyx_n_s_name +#define __pyx_n_s_name_2 __pyx_mstate_global->__pyx_n_s_name_2 +#define __pyx_n_s_ndim __pyx_mstate_global->__pyx_n_s_ndim +#define __pyx_n_s_new __pyx_mstate_global->__pyx_n_s_new +#define __pyx_kp_s_no_default___reduce___due_to_non __pyx_mstate_global->__pyx_kp_s_no_default___reduce___due_to_non +#define __pyx_n_s_np __pyx_mstate_global->__pyx_n_s_np +#define __pyx_n_s_num_tokens_fn __pyx_mstate_global->__pyx_n_s_num_tokens_fn +#define __pyx_n_s_num_tokens_vec __pyx_mstate_global->__pyx_n_s_num_tokens_vec +#define __pyx_n_s_numpy __pyx_mstate_global->__pyx_n_s_numpy +#define __pyx_kp_u_numpy_core_multiarray_failed_to __pyx_mstate_global->__pyx_kp_u_numpy_core_multiarray_failed_to +#define __pyx_kp_u_numpy_core_umath_failed_to_impor __pyx_mstate_global->__pyx_kp_u_numpy_core_umath_failed_to_impor +#define __pyx_n_s_obj __pyx_mstate_global->__pyx_n_s_obj +#define __pyx_n_s_pack __pyx_mstate_global->__pyx_n_s_pack +#define __pyx_n_s_pickle __pyx_mstate_global->__pyx_n_s_pickle +#define __pyx_n_s_pyx_PickleError __pyx_mstate_global->__pyx_n_s_pyx_PickleError +#define __pyx_n_s_pyx_checksum __pyx_mstate_global->__pyx_n_s_pyx_checksum +#define __pyx_n_s_pyx_result __pyx_mstate_global->__pyx_n_s_pyx_result +#define __pyx_n_s_pyx_state __pyx_mstate_global->__pyx_n_s_pyx_state +#define __pyx_n_s_pyx_type __pyx_mstate_global->__pyx_n_s_pyx_type +#define __pyx_n_s_pyx_unpickle_Enum __pyx_mstate_global->__pyx_n_s_pyx_unpickle_Enum +#define __pyx_n_s_pyx_vtable __pyx_mstate_global->__pyx_n_s_pyx_vtable +#define __pyx_n_s_range __pyx_mstate_global->__pyx_n_s_range +#define __pyx_n_s_reduce __pyx_mstate_global->__pyx_n_s_reduce +#define __pyx_n_s_reduce_cython __pyx_mstate_global->__pyx_n_s_reduce_cython +#define __pyx_n_s_reduce_ex __pyx_mstate_global->__pyx_n_s_reduce_ex +#define __pyx_n_s_register __pyx_mstate_global->__pyx_n_s_register +#define __pyx_n_s_setstate __pyx_mstate_global->__pyx_n_s_setstate +#define __pyx_n_s_setstate_cython __pyx_mstate_global->__pyx_n_s_setstate_cython +#define __pyx_n_s_shape __pyx_mstate_global->__pyx_n_s_shape +#define __pyx_n_s_size __pyx_mstate_global->__pyx_n_s_size +#define __pyx_n_s_spec __pyx_mstate_global->__pyx_n_s_spec +#define __pyx_n_s_split __pyx_mstate_global->__pyx_n_s_split +#define __pyx_n_s_start __pyx_mstate_global->__pyx_n_s_start +#define __pyx_n_s_step __pyx_mstate_global->__pyx_n_s_step +#define __pyx_n_s_stop __pyx_mstate_global->__pyx_n_s_stop +#define __pyx_kp_s_strided_and_direct __pyx_mstate_global->__pyx_kp_s_strided_and_direct +#define __pyx_kp_s_strided_and_direct_or_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_direct_or_indirect +#define __pyx_kp_s_strided_and_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_indirect +#define __pyx_kp_s_stringsource __pyx_mstate_global->__pyx_kp_s_stringsource +#define __pyx_n_s_struct __pyx_mstate_global->__pyx_n_s_struct +#define __pyx_n_s_sys __pyx_mstate_global->__pyx_n_s_sys +#define __pyx_n_s_test __pyx_mstate_global->__pyx_n_s_test +#define __pyx_kp_s_unable_to_allocate_array_data __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_array_data +#define __pyx_kp_s_unable_to_allocate_shape_and_str __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_shape_and_str +#define __pyx_n_s_unpack __pyx_mstate_global->__pyx_n_s_unpack +#define __pyx_n_s_update __pyx_mstate_global->__pyx_n_s_update +#define __pyx_n_s_version_info __pyx_mstate_global->__pyx_n_s_version_info +#define __pyx_n_s_zeros __pyx_mstate_global->__pyx_n_s_zeros +#define __pyx_int_0 __pyx_mstate_global->__pyx_int_0 +#define __pyx_int_1 __pyx_mstate_global->__pyx_int_1 +#define __pyx_int_3 __pyx_mstate_global->__pyx_int_3 +#define __pyx_int_112105877 __pyx_mstate_global->__pyx_int_112105877 +#define __pyx_int_136983863 __pyx_mstate_global->__pyx_int_136983863 +#define __pyx_int_184977713 __pyx_mstate_global->__pyx_int_184977713 +#define __pyx_int_neg_1 __pyx_mstate_global->__pyx_int_neg_1 +#define __pyx_slice__5 __pyx_mstate_global->__pyx_slice__5 +#define __pyx_tuple__4 __pyx_mstate_global->__pyx_tuple__4 +#define __pyx_tuple__8 __pyx_mstate_global->__pyx_tuple__8 +#define __pyx_tuple__9 __pyx_mstate_global->__pyx_tuple__9 +#define __pyx_tuple__10 __pyx_mstate_global->__pyx_tuple__10 +#define __pyx_tuple__11 __pyx_mstate_global->__pyx_tuple__11 +#define __pyx_tuple__12 __pyx_mstate_global->__pyx_tuple__12 +#define __pyx_tuple__13 __pyx_mstate_global->__pyx_tuple__13 +#define __pyx_tuple__14 __pyx_mstate_global->__pyx_tuple__14 +#define __pyx_tuple__15 __pyx_mstate_global->__pyx_tuple__15 +#define __pyx_tuple__16 __pyx_mstate_global->__pyx_tuple__16 +#define __pyx_tuple__17 __pyx_mstate_global->__pyx_tuple__17 +#define __pyx_tuple__18 __pyx_mstate_global->__pyx_tuple__18 +#define __pyx_tuple__19 __pyx_mstate_global->__pyx_tuple__19 +#define __pyx_tuple__20 __pyx_mstate_global->__pyx_tuple__20 +#define __pyx_tuple__22 __pyx_mstate_global->__pyx_tuple__22 +#define __pyx_tuple__24 __pyx_mstate_global->__pyx_tuple__24 +#define __pyx_tuple__26 __pyx_mstate_global->__pyx_tuple__26 +#define __pyx_codeobj__21 __pyx_mstate_global->__pyx_codeobj__21 +#define __pyx_codeobj__23 __pyx_mstate_global->__pyx_codeobj__23 +#define __pyx_codeobj__25 __pyx_mstate_global->__pyx_codeobj__25 +#define __pyx_codeobj__27 __pyx_mstate_global->__pyx_codeobj__27 +/* #### Code section: module_code ### */ + +/* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + +/* Python wrapper */ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_shape = 0; + Py_ssize_t __pyx_v_itemsize; + PyObject *__pyx_v_format = 0; + PyObject *__pyx_v_mode = 0; + int __pyx_v_allocate_buffer; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_shape,&__pyx_n_s_itemsize,&__pyx_n_s_format,&__pyx_n_s_mode,&__pyx_n_s_allocate_buffer,0}; + values[3] = __Pyx_Arg_NewRef_VARARGS(((PyObject *)__pyx_n_s_c)); + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_shape)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_itemsize)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 1); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_format)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 2); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_mode); + if (value) { values[3] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_allocate_buffer); + if (value) { values[4] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 131, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_shape = ((PyObject*)values[0]); + __pyx_v_itemsize = __Pyx_PyIndex_AsSsize_t(values[1]); if (unlikely((__pyx_v_itemsize == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_v_format = values[2]; + __pyx_v_mode = values[3]; + if (values[4]) { + __pyx_v_allocate_buffer = __Pyx_PyObject_IsTrue(values[4]); if (unlikely((__pyx_v_allocate_buffer == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 132, __pyx_L3_error) + } else { + + /* "View.MemoryView":132 + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, + * mode="c", bint allocate_buffer=True): # <<<<<<<<<<<<<< + * + * cdef int idx + */ + __pyx_v_allocate_buffer = ((int)1); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, __pyx_nargs); __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_shape), (&PyTuple_Type), 1, "shape", 1))) __PYX_ERR(1, 131, __pyx_L1_error) + if (unlikely(((PyObject *)__pyx_v_format) == Py_None)) { + PyErr_Format(PyExc_TypeError, "Argument '%.200s' must not be None", "format"); __PYX_ERR(1, 131, __pyx_L1_error) + } + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v_shape, __pyx_v_itemsize, __pyx_v_format, __pyx_v_mode, __pyx_v_allocate_buffer); + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = -1; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer) { + int __pyx_v_idx; + Py_ssize_t __pyx_v_dim; + char __pyx_v_order; + int __pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + int __pyx_t_7; + char *__pyx_t_8; + Py_ssize_t __pyx_t_9; + Py_UCS4 __pyx_t_10; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 0); + __Pyx_INCREF(__pyx_v_format); + + /* "View.MemoryView":137 + * cdef Py_ssize_t dim + * + * self.ndim = len(shape) # <<<<<<<<<<<<<< + * self.itemsize = itemsize + * + */ + if (unlikely(__pyx_v_shape == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 137, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_PyTuple_GET_SIZE(__pyx_v_shape); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(1, 137, __pyx_L1_error) + __pyx_v_self->ndim = ((int)__pyx_t_1); + + /* "View.MemoryView":138 + * + * self.ndim = len(shape) + * self.itemsize = itemsize # <<<<<<<<<<<<<< + * + * if not self.ndim: + */ + __pyx_v_self->itemsize = __pyx_v_itemsize; + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + __pyx_t_2 = (!(__pyx_v_self->ndim != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":141 + * + * if not self.ndim: + * raise ValueError, "Empty shape tuple for cython.array" # <<<<<<<<<<<<<< + * + * if itemsize <= 0: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Empty_shape_tuple_for_cython_arr, 0, 0); + __PYX_ERR(1, 141, __pyx_L1_error) + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + } + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + __pyx_t_2 = (__pyx_v_itemsize <= 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":144 + * + * if itemsize <= 0: + * raise ValueError, "itemsize <= 0 for cython.array" # <<<<<<<<<<<<<< + * + * if not isinstance(format, bytes): + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_itemsize_0_for_cython_array, 0, 0); + __PYX_ERR(1, 144, __pyx_L1_error) + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + } + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + __pyx_t_2 = PyBytes_Check(__pyx_v_format); + __pyx_t_3 = (!__pyx_t_2); + if (__pyx_t_3) { + + /* "View.MemoryView":147 + * + * if not isinstance(format, bytes): + * format = format.encode('ASCII') # <<<<<<<<<<<<<< + * self._format = format # keep a reference to the byte string + * self.format = self._format + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_format, __pyx_n_s_encode); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = NULL; + __pyx_t_7 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_6)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_6); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_7 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_6, __pyx_n_s_ASCII}; + __pyx_t_4 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_7, 1+__pyx_t_7); + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __Pyx_DECREF_SET(__pyx_v_format, __pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + } + + /* "View.MemoryView":148 + * if not isinstance(format, bytes): + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string # <<<<<<<<<<<<<< + * self.format = self._format + * + */ + if (!(likely(PyBytes_CheckExact(__pyx_v_format))||((__pyx_v_format) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_v_format))) __PYX_ERR(1, 148, __pyx_L1_error) + __pyx_t_4 = __pyx_v_format; + __Pyx_INCREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __Pyx_GOTREF(__pyx_v_self->_format); + __Pyx_DECREF(__pyx_v_self->_format); + __pyx_v_self->_format = ((PyObject*)__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":149 + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + * self.format = self._format # <<<<<<<<<<<<<< + * + * + */ + if (unlikely(__pyx_v_self->_format == Py_None)) { + PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); + __PYX_ERR(1, 149, __pyx_L1_error) + } + __pyx_t_8 = __Pyx_PyBytes_AsWritableString(__pyx_v_self->_format); if (unlikely((!__pyx_t_8) && PyErr_Occurred())) __PYX_ERR(1, 149, __pyx_L1_error) + __pyx_v_self->format = __pyx_t_8; + + /* "View.MemoryView":152 + * + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) # <<<<<<<<<<<<<< + * self._strides = self._shape + self.ndim + * + */ + __pyx_v_self->_shape = ((Py_ssize_t *)PyObject_Malloc((((sizeof(Py_ssize_t)) * __pyx_v_self->ndim) * 2))); + + /* "View.MemoryView":153 + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) + * self._strides = self._shape + self.ndim # <<<<<<<<<<<<<< + * + * if not self._shape: + */ + __pyx_v_self->_strides = (__pyx_v_self->_shape + __pyx_v_self->ndim); + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + __pyx_t_3 = (!(__pyx_v_self->_shape != 0)); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":156 + * + * if not self._shape: + * raise MemoryError, "unable to allocate shape and strides." # <<<<<<<<<<<<<< + * + * + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_shape_and_str, 0, 0); + __PYX_ERR(1, 156, __pyx_L1_error) + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + } + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + __pyx_t_7 = 0; + __pyx_t_4 = __pyx_v_shape; __Pyx_INCREF(__pyx_t_4); + __pyx_t_1 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_4); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #endif + if (__pyx_t_1 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_4, __pyx_t_1); __Pyx_INCREF(__pyx_t_5); __pyx_t_1++; if (unlikely((0 < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_4, __pyx_t_1); __pyx_t_1++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_t_5); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_9; + __pyx_v_idx = __pyx_t_7; + __pyx_t_7 = (__pyx_t_7 + 1); + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + __pyx_t_3 = (__pyx_v_dim <= 0); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":161 + * for idx, dim in enumerate(shape): + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." # <<<<<<<<<<<<<< + * self._shape[idx] = dim + * + */ + __pyx_t_5 = PyTuple_New(5); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_9 = 0; + __pyx_t_10 = 127; + __Pyx_INCREF(__pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_9 += 22; + __Pyx_GIVEREF(__pyx_kp_u_Invalid_shape_in_axis); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_6 = __Pyx_PyUnicode_From_int(__pyx_v_idx, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_9 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u_); + __pyx_t_9 += 2; + __Pyx_GIVEREF(__pyx_kp_u_); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u_); + __pyx_t_6 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_9 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 3, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u__2); + __pyx_t_9 += 1; + __Pyx_GIVEREF(__pyx_kp_u__2); + PyTuple_SET_ITEM(__pyx_t_5, 4, __pyx_kp_u__2); + __pyx_t_6 = __Pyx_PyUnicode_Join(__pyx_t_5, 5, __pyx_t_9, __pyx_t_10); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 161, __pyx_L1_error) + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + } + + /* "View.MemoryView":162 + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim # <<<<<<<<<<<<<< + * + * cdef char order + */ + (__pyx_v_self->_shape[__pyx_v_idx]) = __pyx_v_dim; + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + } + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_c, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 165, __pyx_L1_error) + if (__pyx_t_3) { + + /* "View.MemoryView":166 + * cdef char order + * if mode == 'c': + * order = b'C' # <<<<<<<<<<<<<< + * self.mode = u'c' + * elif mode == 'fortran': + */ + __pyx_v_order = 'C'; + + /* "View.MemoryView":167 + * if mode == 'c': + * order = b'C' + * self.mode = u'c' # <<<<<<<<<<<<<< + * elif mode == 'fortran': + * order = b'F' + */ + __Pyx_INCREF(__pyx_n_u_c); + __Pyx_GIVEREF(__pyx_n_u_c); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_c; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_fortran, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 168, __pyx_L1_error) + if (likely(__pyx_t_3)) { + + /* "View.MemoryView":169 + * self.mode = u'c' + * elif mode == 'fortran': + * order = b'F' # <<<<<<<<<<<<<< + * self.mode = u'fortran' + * else: + */ + __pyx_v_order = 'F'; + + /* "View.MemoryView":170 + * elif mode == 'fortran': + * order = b'F' + * self.mode = u'fortran' # <<<<<<<<<<<<<< + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + */ + __Pyx_INCREF(__pyx_n_u_fortran); + __Pyx_GIVEREF(__pyx_n_u_fortran); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_fortran; + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":172 + * self.mode = u'fortran' + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" # <<<<<<<<<<<<<< + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_FormatSimple(__pyx_v_mode, __pyx_empty_unicode); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_6 = __Pyx_PyUnicode_Concat(__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_t_4); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 172, __pyx_L1_error) + } + __pyx_L11:; + + /* "View.MemoryView":174 + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) # <<<<<<<<<<<<<< + * + * self.free_data = allocate_buffer + */ + __pyx_v_self->len = __pyx_fill_contig_strides_array(__pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_itemsize, __pyx_v_self->ndim, __pyx_v_order); + + /* "View.MemoryView":176 + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + * + * self.free_data = allocate_buffer # <<<<<<<<<<<<<< + * self.dtype_is_object = format == b'O' + * + */ + __pyx_v_self->free_data = __pyx_v_allocate_buffer; + + /* "View.MemoryView":177 + * + * self.free_data = allocate_buffer + * self.dtype_is_object = format == b'O' # <<<<<<<<<<<<<< + * + * if allocate_buffer: + */ + __pyx_t_6 = PyObject_RichCompare(__pyx_v_format, __pyx_n_b_O, Py_EQ); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 177, __pyx_L1_error) + __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 177, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_v_self->dtype_is_object = __pyx_t_3; + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + if (__pyx_v_allocate_buffer) { + + /* "View.MemoryView":180 + * + * if allocate_buffer: + * _allocate_buffer(self) # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_t_7 = __pyx_array_allocate_buffer(__pyx_v_self); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(1, 180, __pyx_L1_error) + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + } + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_format); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(((struct __pyx_array_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_v_bufmode; + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + char *__pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + Py_ssize_t *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":184 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 # <<<<<<<<<<<<<< + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + */ + __pyx_v_bufmode = -1; + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_t_1 = ((__pyx_v_flags & ((PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS) | PyBUF_ANY_CONTIGUOUS)) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_c, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 186, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":187 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_v_bufmode = (PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + goto __pyx_L4; + } + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_fortran, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 188, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":189 + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + */ + __pyx_v_bufmode = (PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + } + __pyx_L4:; + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + __pyx_t_1 = (!((__pyx_v_flags & __pyx_v_bufmode) != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":191 + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." # <<<<<<<<<<<<<< + * info.buf = self.data + * info.len = self.len + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Can_only_create_a_buffer_that_is, 0, 0); + __PYX_ERR(1, 191, __pyx_L1_error) + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + } + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + } + + /* "View.MemoryView":192 + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data # <<<<<<<<<<<<<< + * info.len = self.len + * + */ + __pyx_t_2 = __pyx_v_self->data; + __pyx_v_info->buf = __pyx_t_2; + + /* "View.MemoryView":193 + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + * info.len = self.len # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + __pyx_t_3 = __pyx_v_self->len; + __pyx_v_info->len = __pyx_t_3; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":196 + * + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim # <<<<<<<<<<<<<< + * info.shape = self._shape + * info.strides = self._strides + */ + __pyx_t_4 = __pyx_v_self->ndim; + __pyx_v_info->ndim = __pyx_t_4; + + /* "View.MemoryView":197 + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim + * info.shape = self._shape # <<<<<<<<<<<<<< + * info.strides = self._strides + * else: + */ + __pyx_t_5 = __pyx_v_self->_shape; + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":198 + * info.ndim = self.ndim + * info.shape = self._shape + * info.strides = self._strides # <<<<<<<<<<<<<< + * else: + * info.ndim = 1 + */ + __pyx_t_5 = __pyx_v_self->_strides; + __pyx_v_info->strides = __pyx_t_5; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + goto __pyx_L6; + } + + /* "View.MemoryView":200 + * info.strides = self._strides + * else: + * info.ndim = 1 # <<<<<<<<<<<<<< + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL + */ + /*else*/ { + __pyx_v_info->ndim = 1; + + /* "View.MemoryView":201 + * else: + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL # <<<<<<<<<<<<<< + * info.strides = NULL + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + __pyx_t_5 = (&__pyx_v_self->len); + } else { + __pyx_t_5 = NULL; + } + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":202 + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL # <<<<<<<<<<<<<< + * + * info.suboffsets = NULL + */ + __pyx_v_info->strides = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":204 + * info.strides = NULL + * + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * info.itemsize = self.itemsize + * info.readonly = 0 + */ + __pyx_v_info->suboffsets = NULL; + + /* "View.MemoryView":205 + * + * info.suboffsets = NULL + * info.itemsize = self.itemsize # <<<<<<<<<<<<<< + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + */ + __pyx_t_3 = __pyx_v_self->itemsize; + __pyx_v_info->itemsize = __pyx_t_3; + + /* "View.MemoryView":206 + * info.suboffsets = NULL + * info.itemsize = self.itemsize + * info.readonly = 0 # <<<<<<<<<<<<<< + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self + */ + __pyx_v_info->readonly = 0; + + /* "View.MemoryView":207 + * info.itemsize = self.itemsize + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + __pyx_t_2 = __pyx_v_self->format; + } else { + __pyx_t_2 = NULL; + } + __pyx_v_info->format = __pyx_t_2; + + /* "View.MemoryView":208 + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self # <<<<<<<<<<<<<< + * + * def __dealloc__(array self): + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + +/* Python wrapper */ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + __pyx_t_1 = (__pyx_v_self->callback_free_data != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":212 + * def __dealloc__(array self): + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) # <<<<<<<<<<<<<< + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + */ + __pyx_v_self->callback_free_data(__pyx_v_self->data); + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + if (__pyx_v_self->free_data) { + } else { + __pyx_t_1 = __pyx_v_self->free_data; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_self->data != NULL); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":215 + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) # <<<<<<<<<<<<<< + * free(self.data) + * PyObject_Free(self._shape) + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_self->data, __pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_self->ndim, 0); + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + } + + /* "View.MemoryView":216 + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) # <<<<<<<<<<<<<< + * PyObject_Free(self._shape) + * + */ + free(__pyx_v_self->data); + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + } + __pyx_L3:; + + /* "View.MemoryView":217 + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + * PyObject_Free(self._shape) # <<<<<<<<<<<<<< + * + * @property + */ + PyObject_Free(__pyx_v_self->_shape); + + /* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + + /* function exit code */ +} + +/* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_5array_7memview___get__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":221 + * @property + * def memview(self): + * return self.get_memview() # <<<<<<<<<<<<<< + * + * @cname('get_memview') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_array *)__pyx_v_self->__pyx_vtab)->get_memview(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 221, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.memview.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_memview", 1); + + /* "View.MemoryView":225 + * @cname('get_memview') + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE # <<<<<<<<<<<<<< + * return memoryview(self, flags, self.dtype_is_object) + * + */ + __pyx_v_flags = ((PyBUF_ANY_CONTIGUOUS | PyBUF_FORMAT) | PyBUF_WRITABLE); + + /* "View.MemoryView":226 + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, ((PyObject *)__pyx_v_self))) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.array.get_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + + /* "View.MemoryView":229 + * + * def __len__(self): + * return self._shape[0] # <<<<<<<<<<<<<< + * + * def __getattr__(self, attr): + */ + __pyx_r = (__pyx_v_self->_shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr); /*proto*/ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getattr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_attr)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getattr__", 1); + + /* "View.MemoryView":232 + * + * def __getattr__(self, attr): + * return getattr(self.memview, attr) # <<<<<<<<<<<<<< + * + * def __getitem__(self, item): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_GetAttr(__pyx_t_1, __pyx_v_attr); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getattr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item); /*proto*/ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":235 + * + * def __getitem__(self, item): + * return self.memview[item] # <<<<<<<<<<<<<< + * + * def __setitem__(self, item, value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetItem(__pyx_t_1, __pyx_v_item); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + +/* Python wrapper */ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 1); + + /* "View.MemoryView":238 + * + * def __setitem__(self, item, value): + * self.memview[item] = value # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (unlikely((PyObject_SetItem(__pyx_t_1, __pyx_v_item, __pyx_v_value) < 0))) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_array___reduce_cython__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_array_2__setstate_cython__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_v_i; + PyObject **__pyx_v_p; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":254 + * cdef PyObject **p + * + * self.free_data = True # <<<<<<<<<<<<<< + * self.data = malloc(self.len) + * if not self.data: + */ + __pyx_v_self->free_data = 1; + + /* "View.MemoryView":255 + * + * self.free_data = True + * self.data = malloc(self.len) # <<<<<<<<<<<<<< + * if not self.data: + * raise MemoryError, "unable to allocate array data." + */ + __pyx_v_self->data = ((char *)malloc(__pyx_v_self->len)); + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + __pyx_t_1 = (!(__pyx_v_self->data != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":257 + * self.data = malloc(self.len) + * if not self.data: + * raise MemoryError, "unable to allocate array data." # <<<<<<<<<<<<<< + * + * if self.dtype_is_object: + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_array_data, 0, 0); + __PYX_ERR(1, 257, __pyx_L1_error) + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":260 + * + * if self.dtype_is_object: + * p = self.data # <<<<<<<<<<<<<< + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + */ + __pyx_v_p = ((PyObject **)__pyx_v_self->data); + + /* "View.MemoryView":261 + * if self.dtype_is_object: + * p = self.data + * for i in range(self.len // self.itemsize): # <<<<<<<<<<<<<< + * p[i] = Py_None + * Py_INCREF(Py_None) + */ + if (unlikely(__pyx_v_self->itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_self->itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_self->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + __pyx_t_2 = __Pyx_div_Py_ssize_t(__pyx_v_self->len, __pyx_v_self->itemsize); + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":262 + * p = self.data + * for i in range(self.len // self.itemsize): + * p[i] = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * return 0 + */ + (__pyx_v_p[__pyx_v_i]) = Py_None; + + /* "View.MemoryView":263 + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * return 0 + * + */ + Py_INCREF(Py_None); + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + } + + /* "View.MemoryView":264 + * p[i] = Py_None + * Py_INCREF(Py_None) + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._allocate_buffer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + +static struct __pyx_array_obj *__pyx_array_new(PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, char *__pyx_v_format, char *__pyx_v_c_mode, char *__pyx_v_buf) { + struct __pyx_array_obj *__pyx_v_result = 0; + PyObject *__pyx_v_mode = 0; + struct __pyx_array_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("array_cwrapper", 1); + + /* "View.MemoryView":270 + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. # <<<<<<<<<<<<<< + * + * if buf is NULL: + */ + __pyx_t_2 = ((__pyx_v_c_mode[0]) == 'f'); + if (__pyx_t_2) { + __Pyx_INCREF(__pyx_n_s_fortran); + __pyx_t_1 = __pyx_n_s_fortran; + } else { + __Pyx_INCREF(__pyx_n_s_c); + __pyx_t_1 = __pyx_n_s_c; + } + __pyx_v_mode = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + __pyx_t_2 = (__pyx_v_buf == NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":273 + * + * if buf is NULL: + * result = array.__new__(array, shape, itemsize, format, mode) # <<<<<<<<<<<<<< + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + */ + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_v_shape)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 3, __pyx_v_mode)) __PYX_ERR(1, 273, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_3 = 0; + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_4, NULL)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":275 + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) # <<<<<<<<<<<<<< + * result.data = buf + * + */ + /*else*/ { + __pyx_t_3 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(4); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_shape)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_v_mode)) __PYX_ERR(1, 275, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_allocate_buffer, Py_False) < 0) __PYX_ERR(1, 275, __pyx_L1_error) + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_1, __pyx_t_4)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":276 + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + * result.data = buf # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->data = __pyx_v_buf; + } + __pyx_L3:; + + /* "View.MemoryView":278 + * result.data = buf + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.array_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_mode); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + +/* Python wrapper */ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_name = 0; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_name,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_name)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 304, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__init__") < 0)) __PYX_ERR(1, 304, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + } + __pyx_v_name = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__init__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 304, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v_name); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name) { + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__", 1); + + /* "View.MemoryView":305 + * cdef object name + * def __init__(self, name): + * self.name = name # <<<<<<<<<<<<<< + * def __repr__(self): + * return self.name + */ + __Pyx_INCREF(__pyx_v_name); + __Pyx_GIVEREF(__pyx_v_name); + __Pyx_GOTREF(__pyx_v_self->name); + __Pyx_DECREF(__pyx_v_self->name); + __pyx_v_self->name = __pyx_v_name; + + /* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + + /* function exit code */ + __pyx_r = 0; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + +/* Python wrapper */ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":307 + * self.name = name + * def __repr__(self): + * return self.name # <<<<<<<<<<<<<< + * + * cdef generic = Enum("") + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->name); + __pyx_r = __pyx_v_self->name; + goto __pyx_L0; + + /* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_MemviewEnum___reduce_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_v_state = 0; + PyObject *__pyx_v__dict = 0; + int __pyx_v_use_setstate; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":5 + * cdef object _dict + * cdef bint use_setstate + * state = (self.name,) # <<<<<<<<<<<<<< + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_self->name); + __Pyx_GIVEREF(__pyx_v_self->name); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_self->name)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_v_state = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "(tree fragment)":6 + * cdef bint use_setstate + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< + * if _dict is not None: + * state += (_dict,) + */ + __pyx_t_1 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v__dict = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + __pyx_t_2 = (__pyx_v__dict != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":8 + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + * state += (_dict,) # <<<<<<<<<<<<<< + * use_setstate = True + * else: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v__dict); + __Pyx_GIVEREF(__pyx_v__dict); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v__dict)) __PYX_ERR(1, 8, __pyx_L1_error); + __pyx_t_3 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_3)); + __pyx_t_3 = 0; + + /* "(tree fragment)":9 + * if _dict is not None: + * state += (_dict,) + * use_setstate = True # <<<<<<<<<<<<<< + * else: + * use_setstate = self.name is not None + */ + __pyx_v_use_setstate = 1; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + goto __pyx_L3; + } + + /* "(tree fragment)":11 + * use_setstate = True + * else: + * use_setstate = self.name is not None # <<<<<<<<<<<<<< + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_self->name != Py_None); + __pyx_v_use_setstate = __pyx_t_2; + } + __pyx_L3:; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + if (__pyx_v_use_setstate) { + + /* "(tree fragment)":13 + * use_setstate = self.name is not None + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state # <<<<<<<<<<<<<< + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, Py_None)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_4 = PyTuple_New(3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_v_state)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + } + + /* "(tree fragment)":15 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_v_state)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_4 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + } + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.Enum.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_state); + __Pyx_XDECREF(__pyx_v__dict); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 16, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 16, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 16, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_MemviewEnum_2__setstate_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":17 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) # <<<<<<<<<<<<<< + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 17, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + +/* Python wrapper */ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_obj = 0; + int __pyx_v_flags; + int __pyx_v_dtype_is_object; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_obj,&__pyx_n_s_flags,&__pyx_n_s_dtype_is_object,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_obj)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_flags)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, 1); __PYX_ERR(1, 349, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_dtype_is_object); + if (value) { values[2] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 349, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_obj = values[0]; + __pyx_v_flags = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_flags == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + if (values[2]) { + __pyx_v_dtype_is_object = __Pyx_PyObject_IsTrue(values[2]); if (unlikely((__pyx_v_dtype_is_object == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } else { + __pyx_v_dtype_is_object = ((int)0); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, __pyx_nargs); __PYX_ERR(1, 349, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_obj, __pyx_v_flags, __pyx_v_dtype_is_object); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + Py_intptr_t __pyx_t_4; + size_t __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 1); + + /* "View.MemoryView":350 + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj # <<<<<<<<<<<<<< + * self.flags = flags + * if type(self) is memoryview or obj is not None: + */ + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + __Pyx_GOTREF(__pyx_v_self->obj); + __Pyx_DECREF(__pyx_v_self->obj); + __pyx_v_self->obj = __pyx_v_obj; + + /* "View.MemoryView":351 + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj + * self.flags = flags # <<<<<<<<<<<<<< + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + */ + __pyx_v_self->flags = __pyx_v_flags; + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + __pyx_t_2 = (((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))) == ((PyObject *)__pyx_memoryview_type)); + if (!__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_obj != Py_None); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":353 + * self.flags = flags + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) # <<<<<<<<<<<<<< + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + */ + __pyx_t_3 = __Pyx_GetBuffer(__pyx_v_obj, (&__pyx_v_self->view), __pyx_v_flags); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 353, __pyx_L1_error) + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_t_1 = (((PyObject *)__pyx_v_self->view.obj) == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":355 + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = Py_None; + + /* "View.MemoryView":356 + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + } + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + __pyx_t_1 = (!__PYX_CYTHON_ATOMICS_ENABLED()); + if (__pyx_t_1) { + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + __pyx_t_1 = (__pyx_memoryview_thread_locks_used < 8); + if (__pyx_t_1) { + + /* "View.MemoryView":361 + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + */ + __pyx_v_self->lock = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + + /* "View.MemoryView":362 + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 # <<<<<<<<<<<<<< + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used + 1); + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":364 + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() # <<<<<<<<<<<<<< + * if self.lock is NULL: + * raise MemoryError + */ + __pyx_v_self->lock = PyThread_allocate_lock(); + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":366 + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + PyErr_NoMemory(); __PYX_ERR(1, 366, __pyx_L1_error) + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + } + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":369 + * + * if flags & PyBUF_FORMAT: + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') # <<<<<<<<<<<<<< + * else: + * self.dtype_is_object = dtype_is_object + */ + __pyx_t_2 = ((__pyx_v_self->view.format[0]) == 'O'); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L12_bool_binop_done; + } + __pyx_t_2 = ((__pyx_v_self->view.format[1]) == '\x00'); + __pyx_t_1 = __pyx_t_2; + __pyx_L12_bool_binop_done:; + __pyx_v_self->dtype_is_object = __pyx_t_1; + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":371 + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + * self.dtype_is_object = dtype_is_object # <<<<<<<<<<<<<< + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + */ + /*else*/ { + __pyx_v_self->dtype_is_object = __pyx_v_dtype_is_object; + } + __pyx_L11:; + + /* "View.MemoryView":373 + * self.dtype_is_object = dtype_is_object + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 # <<<<<<<<<<<<<< + * self.typeinfo = NULL + * + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_4 = ((Py_intptr_t)((void *)(&__pyx_v_self->acquisition_count))); + __pyx_t_5 = (sizeof(__pyx_atomic_int_type)); + if (unlikely(__pyx_t_5 == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 373, __pyx_L1_error) + } + __pyx_t_1 = ((__pyx_t_4 % __pyx_t_5) == 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 373, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 373, __pyx_L1_error) + #endif + + /* "View.MemoryView":374 + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + * self.typeinfo = NULL # <<<<<<<<<<<<<< + * + * def __dealloc__(memoryview self): + */ + __pyx_v_self->typeinfo = NULL; + + /* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + +/* Python wrapper */ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self) { + int __pyx_v_i; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + PyThread_type_lock __pyx_t_5; + PyThread_type_lock __pyx_t_6; + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + __pyx_t_1 = (__pyx_v_self->obj != Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":378 + * def __dealloc__(memoryview self): + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) # <<<<<<<<<<<<<< + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + */ + __Pyx_ReleaseBuffer((&__pyx_v_self->view)); + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + __pyx_t_1 = (((Py_buffer *)(&__pyx_v_self->view))->obj == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":381 + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + * (<__pyx_buffer *> &self.view).obj = NULL # <<<<<<<<<<<<<< + * Py_DECREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = NULL; + + /* "View.MemoryView":382 + * + * (<__pyx_buffer *> &self.view).obj = NULL + * Py_DECREF(Py_None) # <<<<<<<<<<<<<< + * + * cdef int i + */ + Py_DECREF(Py_None); + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + } + __pyx_L3:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + __pyx_t_1 = (__pyx_v_self->lock != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":387 + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): # <<<<<<<<<<<<<< + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + */ + __pyx_t_2 = __pyx_memoryview_thread_locks_used; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + __pyx_t_1 = ((__pyx_memoryview_thread_locks[__pyx_v_i]) == __pyx_v_self->lock); + if (__pyx_t_1) { + + /* "View.MemoryView":389 + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 # <<<<<<<<<<<<<< + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used - 1); + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + __pyx_t_1 = (__pyx_v_i != __pyx_memoryview_thread_locks_used); + if (__pyx_t_1) { + + /* "View.MemoryView":392 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) # <<<<<<<<<<<<<< + * break + * else: + */ + __pyx_t_5 = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + __pyx_t_6 = (__pyx_memoryview_thread_locks[__pyx_v_i]); + + /* "View.MemoryView":391 + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break + */ + (__pyx_memoryview_thread_locks[__pyx_v_i]) = __pyx_t_5; + (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]) = __pyx_t_6; + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + } + + /* "View.MemoryView":393 + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break # <<<<<<<<<<<<<< + * else: + * PyThread_free_lock(self.lock) + */ + goto __pyx_L6_break; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + } + } + /*else*/ { + + /* "View.MemoryView":395 + * break + * else: + * PyThread_free_lock(self.lock) # <<<<<<<<<<<<<< + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + */ + PyThread_free_lock(__pyx_v_self->lock); + } + __pyx_L6_break:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + } + + /* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + + /* function exit code */ +} + +/* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + Py_ssize_t __pyx_v_dim; + char *__pyx_v_itemp; + PyObject *__pyx_v_idx = NULL; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t __pyx_t_3; + PyObject *(*__pyx_t_4)(PyObject *); + PyObject *__pyx_t_5 = NULL; + Py_ssize_t __pyx_t_6; + char *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_item_pointer", 1); + + /* "View.MemoryView":399 + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf # <<<<<<<<<<<<<< + * + * for dim, idx in enumerate(index): + */ + __pyx_v_itemp = ((char *)__pyx_v_self->view.buf); + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + __pyx_t_1 = 0; + if (likely(PyList_CheckExact(__pyx_v_index)) || PyTuple_CheckExact(__pyx_v_index)) { + __pyx_t_2 = __pyx_v_index; __Pyx_INCREF(__pyx_t_2); + __pyx_t_3 = 0; + __pyx_t_4 = NULL; + } else { + __pyx_t_3 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_index); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 401, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_4)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } + } else { + __pyx_t_5 = __pyx_t_4(__pyx_t_2); + if (unlikely(!__pyx_t_5)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 401, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_5); + } + __Pyx_XDECREF_SET(__pyx_v_idx, __pyx_t_5); + __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_1; + __pyx_t_1 = (__pyx_t_1 + 1); + + /* "View.MemoryView":402 + * + * for dim, idx in enumerate(index): + * itemp = pybuffer_index(&self.view, itemp, idx, dim) # <<<<<<<<<<<<<< + * + * return itemp + */ + __pyx_t_6 = __Pyx_PyIndex_AsSsize_t(__pyx_v_idx); if (unlikely((__pyx_t_6 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_t_7 = __pyx_pybuffer_index((&__pyx_v_self->view), __pyx_v_itemp, __pyx_t_6, __pyx_v_dim); if (unlikely(__pyx_t_7 == ((char *)NULL))) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_7; + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":404 + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + * return itemp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_itemp; + goto __pyx_L0; + + /* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.get_item_pointer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_idx); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index); /*proto*/ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_indices = NULL; + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + char *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + __pyx_t_1 = (__pyx_v_index == __pyx_builtin_Ellipsis); + if (__pyx_t_1) { + + /* "View.MemoryView":409 + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: + * return self # <<<<<<<<<<<<<< + * + * have_slices, indices = _unellipsify(index, self.view.ndim) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __pyx_r = ((PyObject *)__pyx_v_self); + goto __pyx_L0; + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + } + + /* "View.MemoryView":411 + * return self + * + * have_slices, indices = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * cdef char *itemp + */ + __pyx_t_2 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (likely(__pyx_t_2 != Py_None)) { + PyObject* sequence = __pyx_t_2; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 411, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + #else + __pyx_t_3 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + #endif + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 411, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_3; + __pyx_t_3 = 0; + __pyx_v_indices = __pyx_t_4; + __pyx_t_4 = 0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 414, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":415 + * cdef char *itemp + * if have_slices: + * return memview_slice(self, indices) # <<<<<<<<<<<<<< + * else: + * itemp = self.get_item_pointer(indices) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((PyObject *)__pyx_memview_slice(__pyx_v_self, __pyx_v_indices)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 415, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + } + + /* "View.MemoryView":417 + * return memview_slice(self, indices) + * else: + * itemp = self.get_item_pointer(indices) # <<<<<<<<<<<<<< + * return self.convert_item_to_object(itemp) + * + */ + /*else*/ { + __pyx_t_5 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_indices); if (unlikely(__pyx_t_5 == ((char *)NULL))) __PYX_ERR(1, 417, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_5; + + /* "View.MemoryView":418 + * else: + * itemp = self.get_item_pointer(indices) + * return self.convert_item_to_object(itemp) # <<<<<<<<<<<<<< + * + * def __setitem__(memoryview self, object index, object value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->convert_item_to_object(__pyx_v_self, __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 418, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_indices); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + +/* Python wrapper */ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_obj = NULL; + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 0); + __Pyx_INCREF(__pyx_v_index); + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + if (unlikely(__pyx_v_self->view.readonly)) { + + /* "View.MemoryView":422 + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" # <<<<<<<<<<<<<< + * + * have_slices, index = _unellipsify(index, self.view.ndim) + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_Cannot_assign_to_read_only_memor, 0, 0); + __PYX_ERR(1, 422, __pyx_L1_error) + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + } + + /* "View.MemoryView":424 + * raise TypeError, "Cannot assign to read-only memoryview" + * + * have_slices, index = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * if have_slices: + */ + __pyx_t_1 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (likely(__pyx_t_1 != Py_None)) { + PyObject* sequence = __pyx_t_1; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 424, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_2 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_2); + __Pyx_INCREF(__pyx_t_3); + #else + __pyx_t_2 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 424, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_2; + __pyx_t_2 = 0; + __Pyx_DECREF_SET(__pyx_v_index, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj: + */ + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(1, 426, __pyx_L1_error) + if (__pyx_t_4) { + + /* "View.MemoryView":427 + * + * if have_slices: + * obj = self.is_slice(value) # <<<<<<<<<<<<<< + * if obj: + * self.setitem_slice_assignment(self[index], obj) + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->is_slice(__pyx_v_self, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 427, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_obj = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_v_obj); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(1, 428, __pyx_L1_error) + if (__pyx_t_4) { + + /* "View.MemoryView":429 + * obj = self.is_slice(value) + * if obj: + * self.setitem_slice_assignment(self[index], obj) # <<<<<<<<<<<<<< + * else: + * self.setitem_slice_assign_scalar(self[index], value) + */ + __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assignment(__pyx_v_self, __pyx_t_1, __pyx_v_obj); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + goto __pyx_L5; + } + + /* "View.MemoryView":431 + * self.setitem_slice_assignment(self[index], obj) + * else: + * self.setitem_slice_assign_scalar(self[index], value) # <<<<<<<<<<<<<< + * else: + * self.setitem_indexed(index, value) + */ + /*else*/ { + __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_memoryview_type))))) __PYX_ERR(1, 431, __pyx_L1_error) + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assign_scalar(__pyx_v_self, ((struct __pyx_memoryview_obj *)__pyx_t_3), __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L5:; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":433 + * self.setitem_slice_assign_scalar(self[index], value) + * else: + * self.setitem_indexed(index, value) # <<<<<<<<<<<<<< + * + * cdef is_slice(self, obj): + */ + /*else*/ { + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_indexed(__pyx_v_self, __pyx_v_index, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 433, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L4:; + + /* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_slice", 0); + __Pyx_INCREF(__pyx_v_obj); + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_obj, __pyx_memoryview_type); + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_4, &__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_5); + /*try:*/ { + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_6 = __Pyx_PyInt_From_int(((__pyx_v_self->flags & (~PyBUF_WRITABLE)) | PyBUF_ANY_CONTIGUOUS)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_6); + + /* "View.MemoryView":439 + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) # <<<<<<<<<<<<<< + * except TypeError: + * return None + */ + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 439, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_8 = PyTuple_New(3); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 0, __pyx_v_obj)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_6); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 1, __pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 2, __pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error); + __pyx_t_6 = 0; + __pyx_t_7 = 0; + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_8, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __Pyx_DECREF_SET(__pyx_v_obj, __pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + goto __pyx_L9_try_end; + __pyx_L4_error:; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":440 + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + * except TypeError: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_9 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_TypeError); + if (__pyx_t_9) { + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_6) < 0) __PYX_ERR(1, 440, __pyx_L6_except_error) + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_6); + + /* "View.MemoryView":441 + * self.dtype_is_object) + * except TypeError: + * return None # <<<<<<<<<<<<<< + * + * return obj + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_except_return; + } + goto __pyx_L6_except_error; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + __pyx_L6_except_error:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L1_error; + __pyx_L7_except_return:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L0; + __pyx_L9_try_end:; + } + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + } + + /* "View.MemoryView":443 + * return None + * + * return obj # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assignment(self, dst, src): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_obj); + __pyx_r = __pyx_v_obj; + goto __pyx_L0; + + /* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src) { + __Pyx_memviewslice __pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_src_slice; + __Pyx_memviewslice __pyx_v_msrc; + __Pyx_memviewslice __pyx_v_mdst; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assignment", 1); + + /* "View.MemoryView":448 + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + */ + if (!(likely(((__pyx_v_src) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_src, __pyx_memoryview_type))))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_src), (&__pyx_v_src_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_v_msrc = (__pyx_t_1[0]); + + /* "View.MemoryView":449 + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] # <<<<<<<<<<<<<< + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + */ + if (!(likely(((__pyx_v_dst) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_dst, __pyx_memoryview_type))))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_dst), (&__pyx_v_dst_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_v_mdst = (__pyx_t_1[0]); + + /* "View.MemoryView":451 + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + */ + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_src, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_dst, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_4 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_5 = __pyx_memoryview_copy_contents(__pyx_v_msrc, __pyx_v_mdst, __pyx_t_3, __pyx_t_4, __pyx_v_self->dtype_is_object); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 451, __pyx_L1_error) + + /* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assignment", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value) { + int __pyx_v_array[0x80]; + void *__pyx_v_tmp; + void *__pyx_v_item; + __Pyx_memviewslice *__pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_tmp_slice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_t_5; + char const *__pyx_t_6; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + PyObject *__pyx_t_9 = NULL; + PyObject *__pyx_t_10 = NULL; + PyObject *__pyx_t_11 = NULL; + PyObject *__pyx_t_12 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assign_scalar", 1); + + /* "View.MemoryView":455 + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + * cdef int array[128] + * cdef void *tmp = NULL # <<<<<<<<<<<<<< + * cdef void *item + * + */ + __pyx_v_tmp = NULL; + + /* "View.MemoryView":460 + * cdef __Pyx_memviewslice *dst_slice + * cdef __Pyx_memviewslice tmp_slice + * dst_slice = get_slice_from_memview(dst, &tmp_slice) # <<<<<<<<<<<<<< + * + * if self.view.itemsize > sizeof(array): + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_dst, (&__pyx_v_tmp_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 460, __pyx_L1_error) + __pyx_v_dst_slice = __pyx_t_1; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + __pyx_t_2 = (((size_t)__pyx_v_self->view.itemsize) > (sizeof(__pyx_v_array))); + if (__pyx_t_2) { + + /* "View.MemoryView":463 + * + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) # <<<<<<<<<<<<<< + * if tmp == NULL: + * raise MemoryError + */ + __pyx_v_tmp = PyMem_Malloc(__pyx_v_self->view.itemsize); + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + __pyx_t_2 = (__pyx_v_tmp == NULL); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":465 + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * item = tmp + * else: + */ + PyErr_NoMemory(); __PYX_ERR(1, 465, __pyx_L1_error) + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + } + + /* "View.MemoryView":466 + * if tmp == NULL: + * raise MemoryError + * item = tmp # <<<<<<<<<<<<<< + * else: + * item = array + */ + __pyx_v_item = __pyx_v_tmp; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":468 + * item = tmp + * else: + * item = array # <<<<<<<<<<<<<< + * + * try: + */ + /*else*/ { + __pyx_v_item = ((void *)__pyx_v_array); + } + __pyx_L3:; + + /* "View.MemoryView":470 + * item = array + * + * try: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * ( item)[0] = value + */ + /*try:*/ { + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":472 + * try: + * if self.dtype_is_object: + * ( item)[0] = value # <<<<<<<<<<<<<< + * else: + * self.assign_item_from_object( item, value) + */ + (((PyObject **)__pyx_v_item)[0]) = ((PyObject *)__pyx_v_value); + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":474 + * ( item)[0] = value + * else: + * self.assign_item_from_object( item, value) # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, ((char *)__pyx_v_item), __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 474, __pyx_L6_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + __pyx_t_2 = (__pyx_v_self->view.suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":479 + * + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) # <<<<<<<<<<<<<< + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + * item, self.dtype_is_object) + */ + __pyx_t_4 = assert_direct_dimensions(__pyx_v_self->view.suboffsets, __pyx_v_self->view.ndim); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 479, __pyx_L6_error) + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + } + + /* "View.MemoryView":480 + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, # <<<<<<<<<<<<<< + * item, self.dtype_is_object) + * finally: + */ + __pyx_memoryview_slice_assign_scalar(__pyx_v_dst_slice, __pyx_v_dst->view.ndim, __pyx_v_self->view.itemsize, __pyx_v_item, __pyx_v_self->dtype_is_object); + } + + /* "View.MemoryView":483 + * item, self.dtype_is_object) + * finally: + * PyMem_Free(tmp) # <<<<<<<<<<<<<< + * + * cdef setitem_indexed(self, index, value): + */ + /*finally:*/ { + /*normal exit:*/{ + PyMem_Free(__pyx_v_tmp); + goto __pyx_L7; + } + __pyx_L6_error:; + /*exception exit:*/{ + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + if (PY_MAJOR_VERSION >= 3) __Pyx_ExceptionSwap(&__pyx_t_10, &__pyx_t_11, &__pyx_t_12); + if ((PY_MAJOR_VERSION < 3) || unlikely(__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9) < 0)) __Pyx_ErrFetch(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_10); + __Pyx_XGOTREF(__pyx_t_11); + __Pyx_XGOTREF(__pyx_t_12); + __pyx_t_4 = __pyx_lineno; __pyx_t_5 = __pyx_clineno; __pyx_t_6 = __pyx_filename; + { + PyMem_Free(__pyx_v_tmp); + } + if (PY_MAJOR_VERSION >= 3) { + __Pyx_XGIVEREF(__pyx_t_10); + __Pyx_XGIVEREF(__pyx_t_11); + __Pyx_XGIVEREF(__pyx_t_12); + __Pyx_ExceptionReset(__pyx_t_10, __pyx_t_11, __pyx_t_12); + } + __Pyx_XGIVEREF(__pyx_t_7); + __Pyx_XGIVEREF(__pyx_t_8); + __Pyx_XGIVEREF(__pyx_t_9); + __Pyx_ErrRestore(__pyx_t_7, __pyx_t_8, __pyx_t_9); + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __pyx_lineno = __pyx_t_4; __pyx_clineno = __pyx_t_5; __pyx_filename = __pyx_t_6; + goto __pyx_L1_error; + } + __pyx_L7:; + } + + /* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assign_scalar", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + char *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_indexed", 1); + + /* "View.MemoryView":486 + * + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) # <<<<<<<<<<<<<< + * self.assign_item_from_object(itemp, value) + * + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_index); if (unlikely(__pyx_t_1 == ((char *)NULL))) __PYX_ERR(1, 486, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_1; + + /* "View.MemoryView":487 + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 487, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_indexed", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_v_struct = NULL; + PyObject *__pyx_v_bytesitem = 0; + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + int __pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":492 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef bytes bytesitem + * + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 492, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":495 + * cdef bytes bytesitem + * + * bytesitem = itemp[:self.view.itemsize] # <<<<<<<<<<<<<< + * try: + * result = struct.unpack(self.view.format, bytesitem) + */ + __pyx_t_1 = __Pyx_PyBytes_FromStringAndSize(__pyx_v_itemp + 0, __pyx_v_self->view.itemsize - 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 495, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_bytesitem = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_2, &__pyx_t_3, &__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + /*try:*/ { + + /* "View.MemoryView":497 + * bytesitem = itemp[:self.view.itemsize] + * try: + * result = struct.unpack(self.view.format, bytesitem) # <<<<<<<<<<<<<< + * except struct.error: + * raise ValueError, "Unable to convert item to object" + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_unpack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_7, __pyx_t_6, __pyx_v_bytesitem}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_8, 2+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __pyx_v_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + } + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + /*else:*/ { + __pyx_t_9 = __Pyx_ssize_strlen(__pyx_v_self->view.format); if (unlikely(__pyx_t_9 == ((Py_ssize_t)-1))) __PYX_ERR(1, 501, __pyx_L5_except_error) + __pyx_t_10 = (__pyx_t_9 == 1); + if (__pyx_t_10) { + + /* "View.MemoryView":502 + * else: + * if len(self.view.format) == 1: + * return result[0] # <<<<<<<<<<<<<< + * return result + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_GetItemInt(__pyx_v_result, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 502, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L6_except_return; + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + } + + /* "View.MemoryView":503 + * if len(self.view.format) == 1: + * return result[0] + * return result # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L6_except_return; + } + __pyx_L3_error:; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":498 + * try: + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: # <<<<<<<<<<<<<< + * raise ValueError, "Unable to convert item to object" + * else: + */ + __Pyx_ErrFetch(&__pyx_t_1, &__pyx_t_5, &__pyx_t_6); + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_error); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_8 = __Pyx_PyErr_GivenExceptionMatches(__pyx_t_1, __pyx_t_7); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_ErrRestore(__pyx_t_1, __pyx_t_5, __pyx_t_6); + __pyx_t_1 = 0; __pyx_t_5 = 0; __pyx_t_6 = 0; + if (__pyx_t_8) { + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_6, &__pyx_t_5, &__pyx_t_1) < 0) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_1); + + /* "View.MemoryView":499 + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + * raise ValueError, "Unable to convert item to object" # <<<<<<<<<<<<<< + * else: + * if len(self.view.format) == 1: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Unable_to_convert_item_to_object, 0, 0); + __PYX_ERR(1, 499, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L1_error; + __pyx_L6_except_return:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L0; + } + + /* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesitem); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_v_struct = NULL; + char __pyx_v_c; + PyObject *__pyx_v_bytesvalue = 0; + Py_ssize_t __pyx_v_i; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + Py_ssize_t __pyx_t_7; + PyObject *__pyx_t_8 = NULL; + char *__pyx_t_9; + char *__pyx_t_10; + char *__pyx_t_11; + char *__pyx_t_12; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":508 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef char c + * cdef bytes bytesvalue + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 508, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_value); + if (__pyx_t_2) { + + /* "View.MemoryView":514 + * + * if isinstance(value, tuple): + * bytesvalue = struct.pack(self.view.format, *value) # <<<<<<<<<<<<<< + * else: + * bytesvalue = struct.pack(self.view.format, value) + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PySequence_Tuple(__pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = PyNumber_Add(__pyx_t_4, __pyx_t_3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_5, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 514, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":516 + * bytesvalue = struct.pack(self.view.format, *value) + * else: + * bytesvalue = struct.pack(self.view.format, value) # <<<<<<<<<<<<<< + * + * for i, c in enumerate(bytesvalue): + */ + /*else*/ { + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_4 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_4, __pyx_t_1, __pyx_v_value}; + __pyx_t_3 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_6, 2+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 516, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = 0; + if (unlikely(__pyx_v_bytesvalue == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' is not iterable"); + __PYX_ERR(1, 518, __pyx_L1_error) + } + __Pyx_INCREF(__pyx_v_bytesvalue); + __pyx_t_8 = __pyx_v_bytesvalue; + __pyx_t_10 = PyBytes_AS_STRING(__pyx_t_8); + __pyx_t_11 = (__pyx_t_10 + PyBytes_GET_SIZE(__pyx_t_8)); + for (__pyx_t_12 = __pyx_t_10; __pyx_t_12 < __pyx_t_11; __pyx_t_12++) { + __pyx_t_9 = __pyx_t_12; + __pyx_v_c = (__pyx_t_9[0]); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_v_i = __pyx_t_7; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = (__pyx_t_7 + 1); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + (__pyx_v_itemp[__pyx_v_i]) = __pyx_v_c; + } + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesvalue); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t *__pyx_t_3; + char *__pyx_t_4; + void *__pyx_t_5; + int __pyx_t_6; + Py_ssize_t __pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + __pyx_t_2 = ((__pyx_v_flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_L4_bool_binop_done:; + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":524 + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + * raise ValueError, "Cannot create writable memory view from read-only memoryview" # <<<<<<<<<<<<<< + * + * if flags & PyBUF_ND: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Cannot_create_writable_memory_vi, 0, 0); + __PYX_ERR(1, 524, __pyx_L1_error) + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + } + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":527 + * + * if flags & PyBUF_ND: + * info.shape = self.view.shape # <<<<<<<<<<<<<< + * else: + * info.shape = NULL + */ + __pyx_t_3 = __pyx_v_self->view.shape; + __pyx_v_info->shape = __pyx_t_3; + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":529 + * info.shape = self.view.shape + * else: + * info.shape = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + /*else*/ { + __pyx_v_info->shape = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":532 + * + * if flags & PyBUF_STRIDES: + * info.strides = self.view.strides # <<<<<<<<<<<<<< + * else: + * info.strides = NULL + */ + __pyx_t_3 = __pyx_v_self->view.strides; + __pyx_v_info->strides = __pyx_t_3; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + goto __pyx_L7; + } + + /* "View.MemoryView":534 + * info.strides = self.view.strides + * else: + * info.strides = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_INDIRECT: + */ + /*else*/ { + __pyx_v_info->strides = NULL; + } + __pyx_L7:; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_INDIRECT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":537 + * + * if flags & PyBUF_INDIRECT: + * info.suboffsets = self.view.suboffsets # <<<<<<<<<<<<<< + * else: + * info.suboffsets = NULL + */ + __pyx_t_3 = __pyx_v_self->view.suboffsets; + __pyx_v_info->suboffsets = __pyx_t_3; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":539 + * info.suboffsets = self.view.suboffsets + * else: + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + /*else*/ { + __pyx_v_info->suboffsets = NULL; + } + __pyx_L8:; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":542 + * + * if flags & PyBUF_FORMAT: + * info.format = self.view.format # <<<<<<<<<<<<<< + * else: + * info.format = NULL + */ + __pyx_t_4 = __pyx_v_self->view.format; + __pyx_v_info->format = __pyx_t_4; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":544 + * info.format = self.view.format + * else: + * info.format = NULL # <<<<<<<<<<<<<< + * + * info.buf = self.view.buf + */ + /*else*/ { + __pyx_v_info->format = NULL; + } + __pyx_L9:; + + /* "View.MemoryView":546 + * info.format = NULL + * + * info.buf = self.view.buf # <<<<<<<<<<<<<< + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + */ + __pyx_t_5 = __pyx_v_self->view.buf; + __pyx_v_info->buf = __pyx_t_5; + + /* "View.MemoryView":547 + * + * info.buf = self.view.buf + * info.ndim = self.view.ndim # <<<<<<<<<<<<<< + * info.itemsize = self.view.itemsize + * info.len = self.view.len + */ + __pyx_t_6 = __pyx_v_self->view.ndim; + __pyx_v_info->ndim = __pyx_t_6; + + /* "View.MemoryView":548 + * info.buf = self.view.buf + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize # <<<<<<<<<<<<<< + * info.len = self.view.len + * info.readonly = self.view.readonly + */ + __pyx_t_7 = __pyx_v_self->view.itemsize; + __pyx_v_info->itemsize = __pyx_t_7; + + /* "View.MemoryView":549 + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + * info.len = self.view.len # <<<<<<<<<<<<<< + * info.readonly = self.view.readonly + * info.obj = self + */ + __pyx_t_7 = __pyx_v_self->view.len; + __pyx_v_info->len = __pyx_t_7; + + /* "View.MemoryView":550 + * info.itemsize = self.view.itemsize + * info.len = self.view.len + * info.readonly = self.view.readonly # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_v_info->readonly = __pyx_t_1; + + /* "View.MemoryView":551 + * info.len = self.view.len + * info.readonly = self.view.readonly + * info.obj = self # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":556 + * @property + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) # <<<<<<<<<<<<<< + * transpose_memslice(&result.from_slice) + * return result + */ + __pyx_t_1 = __pyx_memoryview_copy_object(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 556, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_memoryviewslice_type))))) __PYX_ERR(1, 556, __pyx_L1_error) + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":557 + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_t_2 = __pyx_memslice_transpose((&__pyx_v_result->from_slice)); if (unlikely(__pyx_t_2 == ((int)-1))) __PYX_ERR(1, 557, __pyx_L1_error) + + /* "View.MemoryView":558 + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) + * return result # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.T.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":562 + * @property + * def base(self): + * return self._get_base() # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->_get_base(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 562, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.base.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":565 + * + * cdef _get_base(self): + * return self.obj # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->obj); + __pyx_r = __pyx_v_self->obj; + goto __pyx_L0; + + /* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_7genexpr__pyx_v_length; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":569 + * @property + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_7genexpr__pyx_v_length = (__pyx_t_2[0]); + __pyx_t_5 = PyInt_FromSsize_t(__pyx_7genexpr__pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_1, (PyObject*)__pyx_t_5))) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + } /* exit inner scope */ + __pyx_t_5 = PyList_AsTuple(((PyObject*)__pyx_t_1)); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_5; + __pyx_t_5 = 0; + goto __pyx_L0; + + /* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.shape.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr1__pyx_v_stride; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + __pyx_t_1 = (__pyx_v_self->view.strides == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":575 + * if self.view.strides == NULL: + * + * raise ValueError, "Buffer view does not expose strides" # <<<<<<<<<<<<<< + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Buffer_view_does_not_expose_stri, 0, 0); + __PYX_ERR(1, 575, __pyx_L1_error) + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + } + + /* "View.MemoryView":577 + * raise ValueError, "Buffer view does not expose strides" + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.strides + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.strides; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr1__pyx_v_stride = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr1__pyx_v_stride); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.strides.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr2__pyx_v_suboffset; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + __pyx_t_1 = (__pyx_v_self->view.suboffsets == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PySequence_Multiply(__pyx_tuple__4, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + } + + /* "View.MemoryView":584 + * return (-1,) * self.view.ndim + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.suboffsets + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.suboffsets; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr2__pyx_v_suboffset = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr2__pyx_v_suboffset); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.suboffsets.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":588 + * @property + * def ndim(self): + * return self.view.ndim # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 588, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.ndim.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":592 + * @property + * def itemsize(self): + * return self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 592, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.itemsize.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":596 + * @property + * def nbytes(self): + * return self.size * self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_size); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_Multiply(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.nbytes.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + __pyx_t_1 = (__pyx_v_self->_size == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":601 + * def size(self): + * if self._size is None: + * result = 1 # <<<<<<<<<<<<<< + * + * for length in self.view.shape[:self.view.ndim]: + */ + __Pyx_INCREF(__pyx_int_1); + __pyx_v_result = __pyx_int_1; + + /* "View.MemoryView":603 + * result = 1 + * + * for length in self.view.shape[:self.view.ndim]: # <<<<<<<<<<<<<< + * result *= length + * + */ + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_t_5 = PyInt_FromSsize_t((__pyx_t_2[0])); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 603, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_5); + __pyx_t_5 = 0; + + /* "View.MemoryView":604 + * + * for length in self.view.shape[:self.view.ndim]: + * result *= length # <<<<<<<<<<<<<< + * + * self._size = result + */ + __pyx_t_5 = PyNumber_InPlaceMultiply(__pyx_v_result, __pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 604, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF_SET(__pyx_v_result, __pyx_t_5); + __pyx_t_5 = 0; + } + + /* "View.MemoryView":606 + * result *= length + * + * self._size = result # <<<<<<<<<<<<<< + * + * return self._size + */ + __Pyx_INCREF(__pyx_v_result); + __Pyx_GIVEREF(__pyx_v_result); + __Pyx_GOTREF(__pyx_v_self->_size); + __Pyx_DECREF(__pyx_v_self->_size); + __pyx_v_self->_size = __pyx_v_result; + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + } + + /* "View.MemoryView":608 + * self._size = result + * + * return self._size # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->_size); + __pyx_r = __pyx_v_self->_size; + goto __pyx_L0; + + /* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.size.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + __pyx_t_1 = (__pyx_v_self->view.ndim >= 1); + if (__pyx_t_1) { + + /* "View.MemoryView":612 + * def __len__(self): + * if self.view.ndim >= 1: + * return self.view.shape[0] # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_r = (__pyx_v_self->view.shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + } + + /* "View.MemoryView":614 + * return self.view.shape[0] + * + * return 0 # <<<<<<<<<<<<<< + * + * def __repr__(self): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":618 + * def __repr__(self): + * return "" % (self.base.__class__.__name__, + * id(self)) # <<<<<<<<<<<<<< + * + * def __str__(self): + */ + __pyx_t_2 = __Pyx_PyObject_CallOneArg(__pyx_builtin_id, ((PyObject *)__pyx_v_self)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 618, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_t_3); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__repr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__str__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__str__", 1); + + /* "View.MemoryView":621 + * + * def __str__(self): + * return "" % (self.base.__class__.__name__,) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = PyTuple_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_object, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.__str__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_c_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_c_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_c_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_c_contig", 1); + + /* "View.MemoryView":627 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 627, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":628 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'C', self.view.ndim) # <<<<<<<<<<<<<< + * + * def is_f_contig(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'C', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 628, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_c_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_f_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_f_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_f_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_f_contig", 1); + + /* "View.MemoryView":633 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 633, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":634 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'F', self.view.ndim) # <<<<<<<<<<<<<< + * + * def copy(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'F', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 634, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_f_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_mslice; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy", 1); + + /* "View.MemoryView":638 + * def copy(self): + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &mslice) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_F_CONTIGUOUS)); + + /* "View.MemoryView":640 + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + * + * slice_copy(self, &mslice) # <<<<<<<<<<<<<< + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_mslice)); + + /* "View.MemoryView":641 + * + * slice_copy(self, &mslice) + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_C_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_mslice), ((char *)"c"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_C_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 641, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":646 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &mslice) # <<<<<<<<<<<<<< + * + * def copy_fortran(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_mslice)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 646, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy_fortran (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy_fortran", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy_fortran", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy_fortran", 1); + + /* "View.MemoryView":650 + * def copy_fortran(self): + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &src) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_C_CONTIGUOUS)); + + /* "View.MemoryView":652 + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + * + * slice_copy(self, &src) # <<<<<<<<<<<<<< + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_src)); + + /* "View.MemoryView":653 + * + * slice_copy(self, &src) + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_F_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_src), ((char *)"fortran"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_F_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 653, __pyx_L1_error) + __pyx_v_dst = __pyx_t_1; + + /* "View.MemoryView":658 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &dst) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_dst)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 658, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy_fortran", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryview___reduce_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryview_2__setstate_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + +static PyObject *__pyx_memoryview_new(PyObject *__pyx_v_o, int __pyx_v_flags, int __pyx_v_dtype_is_object, __Pyx_TypeInfo *__pyx_v_typeinfo) { + struct __pyx_memoryview_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_cwrapper", 1); + + /* "View.MemoryView":663 + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) # <<<<<<<<<<<<<< + * result.typeinfo = typeinfo + * return result + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_o); + __Pyx_GIVEREF(__pyx_v_o); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_o)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":664 + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_v_result->typeinfo = __pyx_v_typeinfo; + + /* "View.MemoryView":665 + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_check') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *__pyx_v_o) { + int __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":669 + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: + * return isinstance(o, memoryview) # <<<<<<<<<<<<<< + * + * cdef tuple _unellipsify(object index, int ndim): + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_o, __pyx_memoryview_type); + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + +static PyObject *_unellipsify(PyObject *__pyx_v_index, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_idx; + PyObject *__pyx_v_tup = NULL; + PyObject *__pyx_v_result = NULL; + int __pyx_v_have_slices; + int __pyx_v_seen_ellipsis; + PyObject *__pyx_v_item = NULL; + Py_ssize_t __pyx_v_nslices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_UCS4 __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_unellipsify", 1); + + /* "View.MemoryView":677 + * """ + * cdef Py_ssize_t idx + * tup = index if isinstance(index, tuple) else (index,) # <<<<<<<<<<<<<< + * + * result = [slice(None)] * ndim + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_index); + if (__pyx_t_2) { + __Pyx_INCREF(((PyObject*)__pyx_v_index)); + __pyx_t_1 = __pyx_v_index; + } else { + __pyx_t_3 = PyTuple_New(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 677, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_index); + __Pyx_GIVEREF(__pyx_v_index); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_index)) __PYX_ERR(1, 677, __pyx_L1_error); + __pyx_t_1 = __pyx_t_3; + __pyx_t_3 = 0; + } + __pyx_v_tup = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_t_1 = PyList_New(1 * ((__pyx_v_ndim<0) ? 0:__pyx_v_ndim)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + { Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < __pyx_v_ndim; __pyx_temp++) { + __Pyx_INCREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, __pyx_temp, __pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error); + } + } + __pyx_v_result = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":680 + * + * result = [slice(None)] * ndim + * have_slices = False # <<<<<<<<<<<<<< + * seen_ellipsis = False + * idx = 0 + */ + __pyx_v_have_slices = 0; + + /* "View.MemoryView":681 + * result = [slice(None)] * ndim + * have_slices = False + * seen_ellipsis = False # <<<<<<<<<<<<<< + * idx = 0 + * for item in tup: + */ + __pyx_v_seen_ellipsis = 0; + + /* "View.MemoryView":682 + * have_slices = False + * seen_ellipsis = False + * idx = 0 # <<<<<<<<<<<<<< + * for item in tup: + * if item is Ellipsis: + */ + __pyx_v_idx = 0; + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); + __PYX_ERR(1, 683, __pyx_L1_error) + } + __pyx_t_1 = __pyx_v_tup; __Pyx_INCREF(__pyx_t_1); + __pyx_t_4 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_1); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #endif + if (__pyx_t_4 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_3); __pyx_t_4++; if (unlikely((0 < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #else + __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 683, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_XDECREF_SET(__pyx_v_item, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + __pyx_t_2 = (__pyx_v_item == __pyx_builtin_Ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + __pyx_t_2 = (!__pyx_v_seen_ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":686 + * if item is Ellipsis: + * if not seen_ellipsis: + * idx += ndim - len(tup) # <<<<<<<<<<<<<< + * seen_ellipsis = True + * have_slices = True + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 686, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_PyTuple_GET_SIZE(__pyx_v_tup); if (unlikely(__pyx_t_5 == ((Py_ssize_t)-1))) __PYX_ERR(1, 686, __pyx_L1_error) + __pyx_v_idx = (__pyx_v_idx + (__pyx_v_ndim - __pyx_t_5)); + + /* "View.MemoryView":687 + * if not seen_ellipsis: + * idx += ndim - len(tup) + * seen_ellipsis = True # <<<<<<<<<<<<<< + * have_slices = True + * else: + */ + __pyx_v_seen_ellipsis = 1; + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + } + + /* "View.MemoryView":688 + * idx += ndim - len(tup) + * seen_ellipsis = True + * have_slices = True # <<<<<<<<<<<<<< + * else: + * if isinstance(item, slice): + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + /*else*/ { + __pyx_t_2 = PySlice_Check(__pyx_v_item); + if (__pyx_t_2) { + + /* "View.MemoryView":691 + * else: + * if isinstance(item, slice): + * have_slices = True # <<<<<<<<<<<<<< + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + goto __pyx_L7; + } + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + __pyx_t_2 = (!(PyIndex_Check(__pyx_v_item) != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":693 + * have_slices = True + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" # <<<<<<<<<<<<<< + * result[idx] = item + * idx += 1 + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = 0; + __pyx_t_6 = 127; + __Pyx_INCREF(__pyx_kp_u_Cannot_index_with_type); + __pyx_t_5 += 24; + __Pyx_GIVEREF(__pyx_kp_u_Cannot_index_with_type); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Cannot_index_with_type); + __pyx_t_7 = __Pyx_PyObject_FormatSimple(((PyObject *)Py_TYPE(__pyx_v_item)), __pyx_empty_unicode); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_6 = (__Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) > __pyx_t_6) ? __Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) : __pyx_t_6; + __pyx_t_5 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7); + __pyx_t_7 = 0; + __Pyx_INCREF(__pyx_kp_u__6); + __pyx_t_5 += 1; + __Pyx_GIVEREF(__pyx_kp_u__6); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__6); + __pyx_t_7 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_5, __pyx_t_6); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_t_7, 0, 0); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __PYX_ERR(1, 693, __pyx_L1_error) + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + } + __pyx_L7:; + + /* "View.MemoryView":694 + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item # <<<<<<<<<<<<<< + * idx += 1 + * + */ + if (unlikely((__Pyx_SetItemInt(__pyx_v_result, __pyx_v_idx, __pyx_v_item, Py_ssize_t, 1, PyInt_FromSsize_t, 1, 1, 1) < 0))) __PYX_ERR(1, 694, __pyx_L1_error) + } + __pyx_L5:; + + /* "View.MemoryView":695 + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + * idx += 1 # <<<<<<<<<<<<<< + * + * nslices = ndim - idx + */ + __pyx_v_idx = (__pyx_v_idx + 1); + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":697 + * idx += 1 + * + * nslices = ndim - idx # <<<<<<<<<<<<<< + * return have_slices or nslices, tuple(result) + * + */ + __pyx_v_nslices = (__pyx_v_ndim - __pyx_v_idx); + + /* "View.MemoryView":698 + * + * nslices = ndim - idx + * return have_slices or nslices, tuple(result) # <<<<<<<<<<<<<< + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + */ + __Pyx_XDECREF(__pyx_r); + if (!__pyx_v_have_slices) { + } else { + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_have_slices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_7 = PyInt_FromSsize_t(__pyx_v_nslices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + __pyx_L9_bool_binop_done:; + __pyx_t_7 = PyList_AsTuple(__pyx_v_result); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 698, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_7 = 0; + __pyx_r = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView._unellipsify", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_tup); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_item); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + +static int assert_direct_dimensions(Py_ssize_t *__pyx_v_suboffsets, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_suboffset; + int __pyx_r; + Py_ssize_t *__pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":701 + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + */ + __pyx_t_2 = (__pyx_v_suboffsets + __pyx_v_ndim); + for (__pyx_t_3 = __pyx_v_suboffsets; __pyx_t_3 < __pyx_t_2; __pyx_t_3++) { + __pyx_t_1 = __pyx_t_3; + __pyx_v_suboffset = (__pyx_t_1[0]); + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + __pyx_t_4 = (__pyx_v_suboffset >= 0); + if (unlikely(__pyx_t_4)) { + + /* "View.MemoryView":703 + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" # <<<<<<<<<<<<<< + * return 0 # return type just used as an error flag + * + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Indirect_dimensions_not_supporte, 0, 0); + __PYX_ERR(1, 703, __pyx_L1_error) + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + } + } + + /* "View.MemoryView":704 + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.assert_direct_dimensions", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *__pyx_v_memview, PyObject *__pyx_v_indices) { + int __pyx_v_new_ndim; + int __pyx_v_suboffset_dim; + int __pyx_v_dim; + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + __Pyx_memviewslice *__pyx_v_p_src; + struct __pyx_memoryviewslice_obj *__pyx_v_memviewsliceobj = 0; + __Pyx_memviewslice *__pyx_v_p_dst; + int *__pyx_v_p_suboffset_dim; + Py_ssize_t __pyx_v_start; + Py_ssize_t __pyx_v_stop; + Py_ssize_t __pyx_v_step; + Py_ssize_t __pyx_v_cindex; + int __pyx_v_have_start; + int __pyx_v_have_stop; + int __pyx_v_have_step; + PyObject *__pyx_v_index = NULL; + struct __pyx_memoryview_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + struct __pyx_memoryview_obj *__pyx_t_3; + char *__pyx_t_4; + int __pyx_t_5; + Py_ssize_t __pyx_t_6; + PyObject *(*__pyx_t_7)(PyObject *); + PyObject *__pyx_t_8 = NULL; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + Py_ssize_t __pyx_t_11; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memview_slice", 1); + + /* "View.MemoryView":712 + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): + * cdef int new_ndim = 0, suboffset_dim = -1, dim # <<<<<<<<<<<<<< + * cdef bint negative_step + * cdef __Pyx_memviewslice src, dst + */ + __pyx_v_new_ndim = 0; + __pyx_v_suboffset_dim = -1; + + /* "View.MemoryView":719 + * + * + * memset(&dst, 0, sizeof(dst)) # <<<<<<<<<<<<<< + * + * cdef _memoryviewslice memviewsliceobj + */ + (void)(memset((&__pyx_v_dst), 0, (sizeof(__pyx_v_dst)))); + + /* "View.MemoryView":723 + * cdef _memoryviewslice memviewsliceobj + * + * assert memview.view.ndim > 0 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_1 = (__pyx_v_memview->view.ndim > 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 723, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 723, __pyx_L1_error) + #endif + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":726 + * + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview # <<<<<<<<<<<<<< + * p_src = &memviewsliceobj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 726, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_memviewsliceobj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":727 + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, &src) + */ + __pyx_v_p_src = (&__pyx_v_memviewsliceobj->from_slice); + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + goto __pyx_L3; + } + + /* "View.MemoryView":729 + * p_src = &memviewsliceobj.from_slice + * else: + * slice_copy(memview, &src) # <<<<<<<<<<<<<< + * p_src = &src + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_src)); + + /* "View.MemoryView":730 + * else: + * slice_copy(memview, &src) + * p_src = &src # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_p_src = (&__pyx_v_src); + } + __pyx_L3:; + + /* "View.MemoryView":736 + * + * + * dst.memview = p_src.memview # <<<<<<<<<<<<<< + * dst.data = p_src.data + * + */ + __pyx_t_3 = __pyx_v_p_src->memview; + __pyx_v_dst.memview = __pyx_t_3; + + /* "View.MemoryView":737 + * + * dst.memview = p_src.memview + * dst.data = p_src.data # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_v_p_src->data; + __pyx_v_dst.data = __pyx_t_4; + + /* "View.MemoryView":742 + * + * + * cdef __Pyx_memviewslice *p_dst = &dst # <<<<<<<<<<<<<< + * cdef int *p_suboffset_dim = &suboffset_dim + * cdef Py_ssize_t start, stop, step, cindex + */ + __pyx_v_p_dst = (&__pyx_v_dst); + + /* "View.MemoryView":743 + * + * cdef __Pyx_memviewslice *p_dst = &dst + * cdef int *p_suboffset_dim = &suboffset_dim # <<<<<<<<<<<<<< + * cdef Py_ssize_t start, stop, step, cindex + * cdef bint have_start, have_stop, have_step + */ + __pyx_v_p_suboffset_dim = (&__pyx_v_suboffset_dim); + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + __pyx_t_5 = 0; + if (likely(PyList_CheckExact(__pyx_v_indices)) || PyTuple_CheckExact(__pyx_v_indices)) { + __pyx_t_2 = __pyx_v_indices; __Pyx_INCREF(__pyx_t_2); + __pyx_t_6 = 0; + __pyx_t_7 = NULL; + } else { + __pyx_t_6 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_indices); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_7 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 747, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_7)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } + } else { + __pyx_t_8 = __pyx_t_7(__pyx_t_2); + if (unlikely(!__pyx_t_8)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 747, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_8); + } + __Pyx_XDECREF_SET(__pyx_v_index, __pyx_t_8); + __pyx_t_8 = 0; + __pyx_v_dim = __pyx_t_5; + __pyx_t_5 = (__pyx_t_5 + 1); + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + __pyx_t_1 = (PyIndex_Check(__pyx_v_index) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":749 + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): + * cindex = index # <<<<<<<<<<<<<< + * slice_memviewslice( + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + */ + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_v_index); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 749, __pyx_L1_error) + __pyx_v_cindex = __pyx_t_9; + + /* "View.MemoryView":750 + * if PyIndex_Check(index): + * cindex = index + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_cindex, 0, 0, 0, 0, 0, 0); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 750, __pyx_L1_error) + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + goto __pyx_L6; + } + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + __pyx_t_1 = (__pyx_v_index == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":757 + * False) + * elif index is None: + * p_dst.shape[new_ndim] = 1 # <<<<<<<<<<<<<< + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + */ + (__pyx_v_p_dst->shape[__pyx_v_new_ndim]) = 1; + + /* "View.MemoryView":758 + * elif index is None: + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 # <<<<<<<<<<<<<< + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 + */ + (__pyx_v_p_dst->strides[__pyx_v_new_ndim]) = 0; + + /* "View.MemoryView":759 + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 # <<<<<<<<<<<<<< + * new_ndim += 1 + * else: + */ + (__pyx_v_p_dst->suboffsets[__pyx_v_new_ndim]) = -1L; + + /* "View.MemoryView":760 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 # <<<<<<<<<<<<<< + * else: + * start = index.start or 0 + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + goto __pyx_L6; + } + + /* "View.MemoryView":762 + * new_ndim += 1 + * else: + * start = index.start or 0 # <<<<<<<<<<<<<< + * stop = index.stop or 0 + * step = index.step or 0 + */ + /*else*/ { + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 762, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 762, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 762, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L7_bool_binop_done:; + __pyx_v_start = __pyx_t_9; + + /* "View.MemoryView":763 + * else: + * start = index.start or 0 + * stop = index.stop or 0 # <<<<<<<<<<<<<< + * step = index.step or 0 + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 763, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 763, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 763, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L9_bool_binop_done:; + __pyx_v_stop = __pyx_t_9; + + /* "View.MemoryView":764 + * start = index.start or 0 + * stop = index.stop or 0 + * step = index.step or 0 # <<<<<<<<<<<<<< + * + * have_start = index.start is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 764, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 764, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 764, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L11_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L11_bool_binop_done:; + __pyx_v_step = __pyx_t_9; + + /* "View.MemoryView":766 + * step = index.step or 0 + * + * have_start = index.start is not None # <<<<<<<<<<<<<< + * have_stop = index.stop is not None + * have_step = index.step is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 766, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_start = __pyx_t_1; + + /* "View.MemoryView":767 + * + * have_start = index.start is not None + * have_stop = index.stop is not None # <<<<<<<<<<<<<< + * have_step = index.step is not None + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 767, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_stop = __pyx_t_1; + + /* "View.MemoryView":768 + * have_start = index.start is not None + * have_stop = index.stop is not None + * have_step = index.step is not None # <<<<<<<<<<<<<< + * + * slice_memviewslice( + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 768, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_step = __pyx_t_1; + + /* "View.MemoryView":770 + * have_step = index.step is not None + * + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_start, __pyx_v_stop, __pyx_v_step, __pyx_v_have_start, __pyx_v_have_stop, __pyx_v_have_step, 1); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 770, __pyx_L1_error) + + /* "View.MemoryView":776 + * have_start, have_stop, have_step, + * True) + * new_ndim += 1 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + } + __pyx_L6:; + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":780 + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, # <<<<<<<<<<<<<< + * memviewsliceobj.to_dtype_func, + * memview.dtype_is_object) + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 780, __pyx_L1_error) } + + /* "View.MemoryView":781 + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * else: + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 781, __pyx_L1_error) } + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, __pyx_v_memviewsliceobj->to_object_func, __pyx_v_memviewsliceobj->to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 779, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 779, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + } + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + /*else*/ { + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":785 + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, NULL, NULL, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 784, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 784, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memview_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_memviewsliceobj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *__pyx_v_dst, Py_ssize_t __pyx_v_shape, Py_ssize_t __pyx_v_stride, Py_ssize_t __pyx_v_suboffset, int __pyx_v_dim, int __pyx_v_new_ndim, int *__pyx_v_suboffset_dim, Py_ssize_t __pyx_v_start, Py_ssize_t __pyx_v_stop, Py_ssize_t __pyx_v_step, int __pyx_v_have_start, int __pyx_v_have_stop, int __pyx_v_have_step, int __pyx_v_is_slice) { + Py_ssize_t __pyx_v_new_shape; + int __pyx_v_negative_step; + int __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + __pyx_t_1 = (!__pyx_v_is_slice); + if (__pyx_t_1) { + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + __pyx_t_1 = (__pyx_v_start < 0); + if (__pyx_t_1) { + + /* "View.MemoryView":816 + * + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + } + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + __pyx_t_1 = (0 <= __pyx_v_start); + if (__pyx_t_1) { + __pyx_t_1 = (__pyx_v_start < __pyx_v_shape); + } + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":818 + * start += shape + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 818, __pyx_L1_error) + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_have_step != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":822 + * + * if have_step: + * negative_step = step < 0 # <<<<<<<<<<<<<< + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + */ + __pyx_v_negative_step = (__pyx_v_step < 0); + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + __pyx_t_2 = (__pyx_v_step == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":824 + * negative_step = step < 0 + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * negative_step = False + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 824, __pyx_L1_error) + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":826 + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + * negative_step = False # <<<<<<<<<<<<<< + * step = 1 + * + */ + /*else*/ { + __pyx_v_negative_step = 0; + + /* "View.MemoryView":827 + * else: + * negative_step = False + * step = 1 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_step = 1; + } + __pyx_L6:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + __pyx_t_2 = (__pyx_v_have_start != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":832 + * if have_start: + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if start < 0: + * start = 0 + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":834 + * start += shape + * if start < 0: + * start = 0 # <<<<<<<<<<<<<< + * elif start >= shape: + * if negative_step: + */ + __pyx_v_start = 0; + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + } + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + __pyx_t_2 = (__pyx_v_start >= __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + if (__pyx_v_negative_step) { + + /* "View.MemoryView":837 + * elif start >= shape: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = shape + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":839 + * start = shape - 1 + * else: + * start = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + /*else*/ { + __pyx_v_start = __pyx_v_shape; + } + __pyx_L11:; + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + } + __pyx_L9:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + goto __pyx_L8; + } + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":842 + * else: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = 0 + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L12; + } + + /* "View.MemoryView":844 + * start = shape - 1 + * else: + * start = 0 # <<<<<<<<<<<<<< + * + * if have_stop: + */ + /*else*/ { + __pyx_v_start = 0; + } + __pyx_L12:; + } + __pyx_L8:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + __pyx_t_2 = (__pyx_v_have_stop != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":848 + * if have_stop: + * if stop < 0: + * stop += shape # <<<<<<<<<<<<<< + * if stop < 0: + * stop = 0 + */ + __pyx_v_stop = (__pyx_v_stop + __pyx_v_shape); + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":850 + * stop += shape + * if stop < 0: + * stop = 0 # <<<<<<<<<<<<<< + * elif stop > shape: + * stop = shape + */ + __pyx_v_stop = 0; + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + } + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + goto __pyx_L14; + } + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + __pyx_t_2 = (__pyx_v_stop > __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":852 + * stop = 0 + * elif stop > shape: + * stop = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + __pyx_v_stop = __pyx_v_shape; + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + } + __pyx_L14:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + goto __pyx_L13; + } + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":855 + * else: + * if negative_step: + * stop = -1 # <<<<<<<<<<<<<< + * else: + * stop = shape + */ + __pyx_v_stop = -1L; + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + goto __pyx_L16; + } + + /* "View.MemoryView":857 + * stop = -1 + * else: + * stop = shape # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_v_stop = __pyx_v_shape; + } + __pyx_L16:; + } + __pyx_L13:; + + /* "View.MemoryView":861 + * + * with cython.cdivision(True): + * new_shape = (stop - start) // step # <<<<<<<<<<<<<< + * + * if (stop - start) - step * new_shape: + */ + __pyx_v_new_shape = ((__pyx_v_stop - __pyx_v_start) / __pyx_v_step); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + __pyx_t_2 = (((__pyx_v_stop - __pyx_v_start) - (__pyx_v_step * __pyx_v_new_shape)) != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":864 + * + * if (stop - start) - step * new_shape: + * new_shape += 1 # <<<<<<<<<<<<<< + * + * if new_shape < 0: + */ + __pyx_v_new_shape = (__pyx_v_new_shape + 1); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + } + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + __pyx_t_2 = (__pyx_v_new_shape < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":867 + * + * if new_shape < 0: + * new_shape = 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_new_shape = 0; + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + } + + /* "View.MemoryView":870 + * + * + * dst.strides[new_ndim] = stride * step # <<<<<<<<<<<<<< + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset + */ + (__pyx_v_dst->strides[__pyx_v_new_ndim]) = (__pyx_v_stride * __pyx_v_step); + + /* "View.MemoryView":871 + * + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape # <<<<<<<<<<<<<< + * dst.suboffsets[new_ndim] = suboffset + * + */ + (__pyx_v_dst->shape[__pyx_v_new_ndim]) = __pyx_v_new_shape; + + /* "View.MemoryView":872 + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_dst->suboffsets[__pyx_v_new_ndim]) = __pyx_v_suboffset; + } + __pyx_L3:; + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + __pyx_t_2 = ((__pyx_v_suboffset_dim[0]) < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":876 + * + * if suboffset_dim[0] < 0: + * dst.data += start * stride # <<<<<<<<<<<<<< + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride + */ + __pyx_v_dst->data = (__pyx_v_dst->data + (__pyx_v_start * __pyx_v_stride)); + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + goto __pyx_L19; + } + + /* "View.MemoryView":878 + * dst.data += start * stride + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride # <<<<<<<<<<<<<< + * + * if suboffset >= 0: + */ + /*else*/ { + __pyx_t_3 = (__pyx_v_suboffset_dim[0]); + (__pyx_v_dst->suboffsets[__pyx_t_3]) = ((__pyx_v_dst->suboffsets[__pyx_t_3]) + (__pyx_v_start * __pyx_v_stride)); + } + __pyx_L19:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + __pyx_t_2 = (!__pyx_v_is_slice); + if (__pyx_t_2) { + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + __pyx_t_2 = (__pyx_v_new_ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":883 + * if not is_slice: + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset # <<<<<<<<<<<<<< + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + */ + __pyx_v_dst->data = ((((char **)__pyx_v_dst->data)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + goto __pyx_L22; + } + + /* "View.MemoryView":885 + * dst.data = ( dst.data)[0] + suboffset + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " # <<<<<<<<<<<<<< + * "must be indexed and not sliced", dim) + * else: + */ + /*else*/ { + + /* "View.MemoryView":886 + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + * "must be indexed and not sliced", dim) # <<<<<<<<<<<<<< + * else: + * suboffset_dim[0] = new_ndim + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 885, __pyx_L1_error) + } + __pyx_L22:; + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + goto __pyx_L21; + } + + /* "View.MemoryView":888 + * "must be indexed and not sliced", dim) + * else: + * suboffset_dim[0] = new_ndim # <<<<<<<<<<<<<< + * + * return 0 + */ + /*else*/ { + (__pyx_v_suboffset_dim[0]) = __pyx_v_new_ndim; + } + __pyx_L21:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + } + + /* "View.MemoryView":890 + * suboffset_dim[0] = new_ndim + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.slice_memviewslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + +static char *__pyx_pybuffer_index(Py_buffer *__pyx_v_view, char *__pyx_v_bufp, Py_ssize_t __pyx_v_index, Py_ssize_t __pyx_v_dim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_suboffset; + Py_ssize_t __pyx_v_itemsize; + char *__pyx_v_resultp; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_UCS4 __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("pybuffer_index", 1); + + /* "View.MemoryView":898 + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 # <<<<<<<<<<<<<< + * cdef Py_ssize_t itemsize = view.itemsize + * cdef char *resultp + */ + __pyx_v_suboffset = -1L; + + /* "View.MemoryView":899 + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + * cdef Py_ssize_t itemsize = view.itemsize # <<<<<<<<<<<<<< + * cdef char *resultp + * + */ + __pyx_t_1 = __pyx_v_view->itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + __pyx_t_2 = (__pyx_v_view->ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":903 + * + * if view.ndim == 0: + * shape = view.len // itemsize # <<<<<<<<<<<<<< + * stride = itemsize + * else: + */ + if (unlikely(__pyx_v_itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_view->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + __pyx_v_shape = __Pyx_div_Py_ssize_t(__pyx_v_view->len, __pyx_v_itemsize); + + /* "View.MemoryView":904 + * if view.ndim == 0: + * shape = view.len // itemsize + * stride = itemsize # <<<<<<<<<<<<<< + * else: + * shape = view.shape[dim] + */ + __pyx_v_stride = __pyx_v_itemsize; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + goto __pyx_L3; + } + + /* "View.MemoryView":906 + * stride = itemsize + * else: + * shape = view.shape[dim] # <<<<<<<<<<<<<< + * stride = view.strides[dim] + * if view.suboffsets != NULL: + */ + /*else*/ { + __pyx_v_shape = (__pyx_v_view->shape[__pyx_v_dim]); + + /* "View.MemoryView":907 + * else: + * shape = view.shape[dim] + * stride = view.strides[dim] # <<<<<<<<<<<<<< + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] + */ + __pyx_v_stride = (__pyx_v_view->strides[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + __pyx_t_2 = (__pyx_v_view->suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":909 + * stride = view.strides[dim] + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] # <<<<<<<<<<<<<< + * + * if index < 0: + */ + __pyx_v_suboffset = (__pyx_v_view->suboffsets[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + } + } + __pyx_L3:; + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":912 + * + * if index < 0: + * index += view.shape[dim] # <<<<<<<<<<<<<< + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + */ + __pyx_v_index = (__pyx_v_index + (__pyx_v_view->shape[__pyx_v_dim])); + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":914 + * index += view.shape[dim] + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * if index >= shape: + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_5 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_5); + __pyx_t_5 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__7); + __pyx_t_5 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_5, 0, 0); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __PYX_ERR(1, 914, __pyx_L1_error) + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + } + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index >= __pyx_v_shape); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":917 + * + * if index >= shape: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * resultp = bufp + index * stride + */ + __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_3 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_3); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_3); + __pyx_t_3 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u__7); + __pyx_t_3 = __Pyx_PyUnicode_Join(__pyx_t_5, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_3, 0, 0); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __PYX_ERR(1, 917, __pyx_L1_error) + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":919 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * resultp = bufp + index * stride # <<<<<<<<<<<<<< + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset + */ + __pyx_v_resultp = (__pyx_v_bufp + (__pyx_v_index * __pyx_v_stride)); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":921 + * resultp = bufp + index * stride + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset # <<<<<<<<<<<<<< + * + * return resultp + */ + __pyx_v_resultp = ((((char **)__pyx_v_resultp)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + } + + /* "View.MemoryView":923 + * resultp = ( resultp)[0] + suboffset + * + * return resultp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_resultp; + goto __pyx_L0; + + /* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.pybuffer_index", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + +static int __pyx_memslice_transpose(__Pyx_memviewslice *__pyx_v_memslice) { + int __pyx_v_ndim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + int __pyx_v_i; + int __pyx_v_j; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + long __pyx_t_3; + long __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_ssize_t __pyx_t_6; + int __pyx_t_7; + int __pyx_t_8; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":930 + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: + * cdef int ndim = memslice.memview.view.ndim # <<<<<<<<<<<<<< + * + * cdef Py_ssize_t *shape = memslice.shape + */ + __pyx_t_1 = __pyx_v_memslice->memview->view.ndim; + __pyx_v_ndim = __pyx_t_1; + + /* "View.MemoryView":932 + * cdef int ndim = memslice.memview.view.ndim + * + * cdef Py_ssize_t *shape = memslice.shape # <<<<<<<<<<<<<< + * cdef Py_ssize_t *strides = memslice.strides + * + */ + __pyx_t_2 = __pyx_v_memslice->shape; + __pyx_v_shape = __pyx_t_2; + + /* "View.MemoryView":933 + * + * cdef Py_ssize_t *shape = memslice.shape + * cdef Py_ssize_t *strides = memslice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_v_memslice->strides; + __pyx_v_strides = __pyx_t_2; + + /* "View.MemoryView":937 + * + * cdef int i, j + * for i in range(ndim // 2): # <<<<<<<<<<<<<< + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + */ + __pyx_t_3 = __Pyx_div_long(__pyx_v_ndim, 2); + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_1 = 0; __pyx_t_1 < __pyx_t_4; __pyx_t_1+=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":938 + * cdef int i, j + * for i in range(ndim // 2): + * j = ndim - 1 - i # <<<<<<<<<<<<<< + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] + */ + __pyx_v_j = ((__pyx_v_ndim - 1) - __pyx_v_i); + + /* "View.MemoryView":939 + * for i in range(ndim // 2): + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] # <<<<<<<<<<<<<< + * shape[i], shape[j] = shape[j], shape[i] + * + */ + __pyx_t_5 = (__pyx_v_strides[__pyx_v_j]); + __pyx_t_6 = (__pyx_v_strides[__pyx_v_i]); + (__pyx_v_strides[__pyx_v_i]) = __pyx_t_5; + (__pyx_v_strides[__pyx_v_j]) = __pyx_t_6; + + /* "View.MemoryView":940 + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] # <<<<<<<<<<<<<< + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + */ + __pyx_t_6 = (__pyx_v_shape[__pyx_v_j]); + __pyx_t_5 = (__pyx_v_shape[__pyx_v_i]); + (__pyx_v_shape[__pyx_v_i]) = __pyx_t_6; + (__pyx_v_shape[__pyx_v_j]) = __pyx_t_5; + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_i]) >= 0); + if (!__pyx_t_8) { + } else { + __pyx_t_7 = __pyx_t_8; + goto __pyx_L6_bool_binop_done; + } + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_j]) >= 0); + __pyx_t_7 = __pyx_t_8; + __pyx_L6_bool_binop_done:; + if (__pyx_t_7) { + + /* "View.MemoryView":943 + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_t_9 = __pyx_memoryview_err(PyExc_ValueError, __pyx_kp_s_Cannot_transpose_memoryview_with); if (unlikely(__pyx_t_9 == ((int)-1))) __PYX_ERR(1, 943, __pyx_L1_error) + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + } + } + + /* "View.MemoryView":945 + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.transpose_memslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + +/* Python wrapper */ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + + /* "View.MemoryView":964 + * + * def __dealloc__(self): + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __PYX_XCLEAR_MEMVIEW((&__pyx_v_self->from_slice), 1); + + /* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + + /* function exit code */ +} + +/* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_object_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":968 + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) # <<<<<<<<<<<<<< + * else: + * return memoryview.convert_item_to_object(self, itemp) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_v_self->to_object_func(__pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 968, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + } + + /* "View.MemoryView":970 + * return self.to_object_func(itemp) + * else: + * return memoryview.convert_item_to_object(self, itemp) # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_convert_item_to_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 970, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_dtype_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":974 + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) # <<<<<<<<<<<<<< + * else: + * memoryview.assign_item_from_object(self, itemp, value) + */ + __pyx_t_2 = __pyx_v_self->to_dtype_func(__pyx_v_itemp, __pyx_v_value); if (unlikely(__pyx_t_2 == ((int)0))) __PYX_ERR(1, 974, __pyx_L1_error) + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":976 + * self.to_dtype_func(itemp, value) + * else: + * memoryview.assign_item_from_object(self, itemp, value) # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + /*else*/ { + __pyx_t_3 = __pyx_memoryview_assign_item_from_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 976, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":979 + * + * cdef _get_base(self): + * return self.from_object # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->from_object); + __pyx_r = __pyx_v_self->from_object; + goto __pyx_L0; + + /* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryviewslice___reduce_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryviewslice_2__setstate_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice __pyx_v_memviewslice, int __pyx_v_ndim, PyObject *(*__pyx_v_to_object_func)(char *), int (*__pyx_v_to_dtype_func)(char *, PyObject *), int __pyx_v_dtype_is_object) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + Py_ssize_t __pyx_v_suboffset; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + __Pyx_TypeInfo *__pyx_t_4; + Py_buffer __pyx_t_5; + Py_ssize_t *__pyx_t_6; + Py_ssize_t *__pyx_t_7; + Py_ssize_t *__pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_fromslice", 1); + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_1 = (((PyObject *)__pyx_v_memviewslice.memview) == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":1008 + * + * if memviewslice.memview == Py_None: + * return None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + } + + /* "View.MemoryView":1013 + * + * + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) # <<<<<<<<<<<<<< + * + * result.from_slice = memviewslice + */ + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, Py_None)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_0); + __Pyx_GIVEREF(__pyx_int_0); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_0)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_2 = ((PyObject *)__pyx_tp_new__memoryviewslice(((PyTypeObject *)__pyx_memoryviewslice_type), __pyx_t_3, NULL)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1015 + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) + * + * result.from_slice = memviewslice # <<<<<<<<<<<<<< + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + */ + __pyx_v_result->from_slice = __pyx_v_memviewslice; + + /* "View.MemoryView":1016 + * + * result.from_slice = memviewslice + * __PYX_INC_MEMVIEW(&memviewslice, 1) # <<<<<<<<<<<<<< + * + * result.from_object = ( memviewslice.memview)._get_base() + */ + __PYX_INC_MEMVIEW((&__pyx_v_memviewslice), 1); + + /* "View.MemoryView":1018 + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + * result.from_object = ( memviewslice.memview)._get_base() # <<<<<<<<<<<<<< + * result.typeinfo = memviewslice.memview.typeinfo + * + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->__pyx_vtab)->_get_base(((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1018, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_2); + __Pyx_GOTREF(__pyx_v_result->from_object); + __Pyx_DECREF(__pyx_v_result->from_object); + __pyx_v_result->from_object = __pyx_t_2; + __pyx_t_2 = 0; + + /* "View.MemoryView":1019 + * + * result.from_object = ( memviewslice.memview)._get_base() + * result.typeinfo = memviewslice.memview.typeinfo # <<<<<<<<<<<<<< + * + * result.view = memviewslice.memview.view + */ + __pyx_t_4 = __pyx_v_memviewslice.memview->typeinfo; + __pyx_v_result->__pyx_base.typeinfo = __pyx_t_4; + + /* "View.MemoryView":1021 + * result.typeinfo = memviewslice.memview.typeinfo + * + * result.view = memviewslice.memview.view # <<<<<<<<<<<<<< + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + */ + __pyx_t_5 = __pyx_v_memviewslice.memview->view; + __pyx_v_result->__pyx_base.view = __pyx_t_5; + + /* "View.MemoryView":1022 + * + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data # <<<<<<<<<<<<<< + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + */ + __pyx_v_result->__pyx_base.view.buf = ((void *)__pyx_v_memviewslice.data); + + /* "View.MemoryView":1023 + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data + * result.view.ndim = ndim # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_v_result->__pyx_base.view.ndim = __pyx_v_ndim; + + /* "View.MemoryView":1024 + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_result->__pyx_base.view))->obj = Py_None; + + /* "View.MemoryView":1025 + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + __pyx_t_1 = ((((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1028 + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + * result.flags = PyBUF_RECORDS # <<<<<<<<<<<<<< + * else: + * result.flags = PyBUF_RECORDS_RO + */ + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS; + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1030 + * result.flags = PyBUF_RECORDS + * else: + * result.flags = PyBUF_RECORDS_RO # <<<<<<<<<<<<<< + * + * result.view.shape = result.from_slice.shape + */ + /*else*/ { + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS_RO; + } + __pyx_L4:; + + /* "View.MemoryView":1032 + * result.flags = PyBUF_RECORDS_RO + * + * result.view.shape = result.from_slice.shape # <<<<<<<<<<<<<< + * result.view.strides = result.from_slice.strides + * + */ + __pyx_v_result->__pyx_base.view.shape = ((Py_ssize_t *)__pyx_v_result->from_slice.shape); + + /* "View.MemoryView":1033 + * + * result.view.shape = result.from_slice.shape + * result.view.strides = result.from_slice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_result->__pyx_base.view.strides = ((Py_ssize_t *)__pyx_v_result->from_slice.strides); + + /* "View.MemoryView":1036 + * + * + * result.view.suboffsets = NULL # <<<<<<<<<<<<<< + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + */ + __pyx_v_result->__pyx_base.view.suboffsets = NULL; + + /* "View.MemoryView":1037 + * + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + */ + __pyx_t_7 = (__pyx_v_result->from_slice.suboffsets + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->from_slice.suboffsets; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_v_suboffset = (__pyx_t_6[0]); + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + __pyx_t_1 = (__pyx_v_suboffset >= 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1039 + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_result->__pyx_base.view.suboffsets = ((Py_ssize_t *)__pyx_v_result->from_slice.suboffsets); + + /* "View.MemoryView":1040 + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + * break # <<<<<<<<<<<<<< + * + * result.view.len = result.view.itemsize + */ + goto __pyx_L6_break; + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + } + } + __pyx_L6_break:; + + /* "View.MemoryView":1042 + * break + * + * result.view.len = result.view.itemsize # <<<<<<<<<<<<<< + * for length in result.view.shape[:ndim]: + * result.view.len *= length + */ + __pyx_t_9 = __pyx_v_result->__pyx_base.view.itemsize; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + + /* "View.MemoryView":1043 + * + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: # <<<<<<<<<<<<<< + * result.view.len *= length + * + */ + __pyx_t_7 = (__pyx_v_result->__pyx_base.view.shape + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->__pyx_base.view.shape; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_t_2 = PyInt_FromSsize_t((__pyx_t_6[0])); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1043, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1044 + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: + * result.view.len *= length # <<<<<<<<<<<<<< + * + * result.to_object_func = to_object_func + */ + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_result->__pyx_base.view.len); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_InPlaceMultiply(__pyx_t_2, __pyx_v_length); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_t_3); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + } + + /* "View.MemoryView":1046 + * result.view.len *= length + * + * result.to_object_func = to_object_func # <<<<<<<<<<<<<< + * result.to_dtype_func = to_dtype_func + * + */ + __pyx_v_result->to_object_func = __pyx_v_to_object_func; + + /* "View.MemoryView":1047 + * + * result.to_object_func = to_object_func + * result.to_dtype_func = to_dtype_func # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->to_dtype_func = __pyx_v_to_dtype_func; + + /* "View.MemoryView":1049 + * result.to_dtype_func = to_dtype_func + * + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_fromslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_mslice) { + struct __pyx_memoryviewslice_obj *__pyx_v_obj = 0; + __Pyx_memviewslice *__pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_slice_from_memview", 1); + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1056 + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): + * obj = memview # <<<<<<<<<<<<<< + * return &obj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 1056, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_obj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1057 + * if isinstance(memview, _memoryviewslice): + * obj = memview + * return &obj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, mslice) + */ + __pyx_r = (&__pyx_v_obj->from_slice); + goto __pyx_L0; + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + } + + /* "View.MemoryView":1059 + * return &obj.from_slice + * else: + * slice_copy(memview, mslice) # <<<<<<<<<<<<<< + * return mslice + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, __pyx_v_mslice); + + /* "View.MemoryView":1060 + * else: + * slice_copy(memview, mslice) + * return mslice # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_slice_copy') + */ + __pyx_r = __pyx_v_mslice; + goto __pyx_L0; + } + + /* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.get_slice_from_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_obj); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_dst) { + int __pyx_v_dim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + Py_ssize_t *__pyx_v_suboffsets; + Py_ssize_t *__pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + Py_ssize_t __pyx_t_5; + int __pyx_t_6; + + /* "View.MemoryView":1067 + * cdef (Py_ssize_t*) shape, strides, suboffsets + * + * shape = memview.view.shape # <<<<<<<<<<<<<< + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets + */ + __pyx_t_1 = __pyx_v_memview->view.shape; + __pyx_v_shape = __pyx_t_1; + + /* "View.MemoryView":1068 + * + * shape = memview.view.shape + * strides = memview.view.strides # <<<<<<<<<<<<<< + * suboffsets = memview.view.suboffsets + * + */ + __pyx_t_1 = __pyx_v_memview->view.strides; + __pyx_v_strides = __pyx_t_1; + + /* "View.MemoryView":1069 + * shape = memview.view.shape + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets # <<<<<<<<<<<<<< + * + * dst.memview = <__pyx_memoryview *> memview + */ + __pyx_t_1 = __pyx_v_memview->view.suboffsets; + __pyx_v_suboffsets = __pyx_t_1; + + /* "View.MemoryView":1071 + * suboffsets = memview.view.suboffsets + * + * dst.memview = <__pyx_memoryview *> memview # <<<<<<<<<<<<<< + * dst.data = memview.view.buf + * + */ + __pyx_v_dst->memview = ((struct __pyx_memoryview_obj *)__pyx_v_memview); + + /* "View.MemoryView":1072 + * + * dst.memview = <__pyx_memoryview *> memview + * dst.data = memview.view.buf # <<<<<<<<<<<<<< + * + * for dim in range(memview.view.ndim): + */ + __pyx_v_dst->data = ((char *)__pyx_v_memview->view.buf); + + /* "View.MemoryView":1074 + * dst.data = memview.view.buf + * + * for dim in range(memview.view.ndim): # <<<<<<<<<<<<<< + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + */ + __pyx_t_2 = __pyx_v_memview->view.ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_dim = __pyx_t_4; + + /* "View.MemoryView":1075 + * + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] # <<<<<<<<<<<<<< + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + */ + (__pyx_v_dst->shape[__pyx_v_dim]) = (__pyx_v_shape[__pyx_v_dim]); + + /* "View.MemoryView":1076 + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] # <<<<<<<<<<<<<< + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + * + */ + (__pyx_v_dst->strides[__pyx_v_dim]) = (__pyx_v_strides[__pyx_v_dim]); + + /* "View.MemoryView":1077 + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object') + */ + __pyx_t_6 = (__pyx_v_suboffsets != 0); + if (__pyx_t_6) { + __pyx_t_5 = (__pyx_v_suboffsets[__pyx_v_dim]); + } else { + __pyx_t_5 = -1L; + } + (__pyx_v_dst->suboffsets[__pyx_v_dim]) = __pyx_t_5; + } + + /* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + + /* function exit code */ +} + +/* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *__pyx_v_memview) { + __Pyx_memviewslice __pyx_v_memviewslice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy", 1); + + /* "View.MemoryView":1083 + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) # <<<<<<<<<<<<<< + * return memoryview_copy_from_slice(memview, &memviewslice) + * + */ + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_memviewslice)); + + /* "View.MemoryView":1084 + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) + * return memoryview_copy_from_slice(memview, &memviewslice) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object_from_slice') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_memoryview_copy_object_from_slice(__pyx_v_memview, (&__pyx_v_memviewslice)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1084, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_memviewslice) { + PyObject *(*__pyx_v_to_object_func)(char *); + int (*__pyx_v_to_dtype_func)(char *, PyObject *); + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *(*__pyx_t_2)(char *); + int (*__pyx_t_3)(char *, PyObject *); + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy_from_slice", 1); + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1095 + * + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func # <<<<<<<<<<<<<< + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + */ + __pyx_t_2 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_object_func; + __pyx_v_to_object_func = __pyx_t_2; + + /* "View.MemoryView":1096 + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func # <<<<<<<<<<<<<< + * else: + * to_object_func = NULL + */ + __pyx_t_3 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_dtype_func; + __pyx_v_to_dtype_func = __pyx_t_3; + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1098 + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + * to_object_func = NULL # <<<<<<<<<<<<<< + * to_dtype_func = NULL + * + */ + /*else*/ { + __pyx_v_to_object_func = NULL; + + /* "View.MemoryView":1099 + * else: + * to_object_func = NULL + * to_dtype_func = NULL # <<<<<<<<<<<<<< + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + */ + __pyx_v_to_dtype_func = NULL; + } + __pyx_L3:; + + /* "View.MemoryView":1101 + * to_dtype_func = NULL + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, # <<<<<<<<<<<<<< + * to_object_func, to_dtype_func, + * memview.dtype_is_object) + */ + __Pyx_XDECREF(__pyx_r); + + /* "View.MemoryView":1103 + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + * to_object_func, to_dtype_func, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_memoryview_fromslice((__pyx_v_memviewslice[0]), __pyx_v_memview->view.ndim, __pyx_v_to_object_func, __pyx_v_to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_from_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + +static Py_ssize_t abs_py_ssize_t(Py_ssize_t __pyx_v_arg) { + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":1110 + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: + * return -arg if arg < 0 else arg # <<<<<<<<<<<<<< + * + * @cname('__pyx_get_best_slice_order') + */ + __pyx_t_2 = (__pyx_v_arg < 0); + if (__pyx_t_2) { + __pyx_t_1 = (-__pyx_v_arg); + } else { + __pyx_t_1 = __pyx_v_arg; + } + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + +static char __pyx_get_best_slice_order(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim) { + int __pyx_v_i; + Py_ssize_t __pyx_v_c_stride; + Py_ssize_t __pyx_v_f_stride; + char __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1118 + * """ + * cdef int i + * cdef Py_ssize_t c_stride = 0 # <<<<<<<<<<<<<< + * cdef Py_ssize_t f_stride = 0 + * + */ + __pyx_v_c_stride = 0; + + /* "View.MemoryView":1119 + * cdef int i + * cdef Py_ssize_t c_stride = 0 + * cdef Py_ssize_t f_stride = 0 # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_f_stride = 0; + + /* "View.MemoryView":1121 + * cdef Py_ssize_t f_stride = 0 + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1123 + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_c_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1124 + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + goto __pyx_L4_break; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L4_break:; + + /* "View.MemoryView":1126 + * break + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + */ + __pyx_t_1 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_1; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1128 + * for i in range(ndim): + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_f_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1129 + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + */ + goto __pyx_L7_break; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L7_break:; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + __pyx_t_2 = (abs_py_ssize_t(__pyx_v_c_stride) <= abs_py_ssize_t(__pyx_v_f_stride)); + if (__pyx_t_2) { + + /* "View.MemoryView":1132 + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + * return 'C' # <<<<<<<<<<<<<< + * else: + * return 'F' + */ + __pyx_r = 'C'; + goto __pyx_L0; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + } + + /* "View.MemoryView":1134 + * return 'C' + * else: + * return 'F' # <<<<<<<<<<<<<< + * + * @cython.cdivision(True) + */ + /*else*/ { + __pyx_r = 'F'; + goto __pyx_L0; + } + + /* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + +static void _copy_strided_to_strided(char *__pyx_v_src_data, Py_ssize_t *__pyx_v_src_strides, char *__pyx_v_dst_data, Py_ssize_t *__pyx_v_dst_strides, Py_ssize_t *__pyx_v_src_shape, Py_ssize_t *__pyx_v_dst_shape, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + CYTHON_UNUSED Py_ssize_t __pyx_v_src_extent; + Py_ssize_t __pyx_v_dst_extent; + Py_ssize_t __pyx_v_src_stride; + Py_ssize_t __pyx_v_dst_stride; + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + + /* "View.MemoryView":1144 + * + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + */ + __pyx_v_src_extent = (__pyx_v_src_shape[0]); + + /* "View.MemoryView":1145 + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] + */ + __pyx_v_dst_extent = (__pyx_v_dst_shape[0]); + + /* "View.MemoryView":1146 + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + */ + __pyx_v_src_stride = (__pyx_v_src_strides[0]); + + /* "View.MemoryView":1147 + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_dst_stride = (__pyx_v_dst_strides[0]); + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + __pyx_t_2 = (__pyx_v_src_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_dst_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + + /* "View.MemoryView":1151 + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + */ + __pyx_t_2 = (((size_t)__pyx_v_src_stride) == __pyx_v_itemsize); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_itemsize == ((size_t)__pyx_v_dst_stride)); + } + __pyx_t_1 = __pyx_t_2; + __pyx_L5_bool_binop_done:; + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + if (__pyx_t_1) { + + /* "View.MemoryView":1152 + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, (__pyx_v_itemsize * __pyx_v_dst_extent))); + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1154 + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1155 + * else: + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) # <<<<<<<<<<<<<< + * src_data += src_stride + * dst_data += dst_stride + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, __pyx_v_itemsize)); + + /* "View.MemoryView":1156 + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * else: + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1157 + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L4:; + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1159 + * dst_data += dst_stride + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * _copy_strided_to_strided(src_data, src_strides + 1, + * dst_data, dst_strides + 1, + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1160 + * else: + * for i in range(dst_extent): + * _copy_strided_to_strided(src_data, src_strides + 1, # <<<<<<<<<<<<<< + * dst_data, dst_strides + 1, + * src_shape + 1, dst_shape + 1, + */ + _copy_strided_to_strided(__pyx_v_src_data, (__pyx_v_src_strides + 1), __pyx_v_dst_data, (__pyx_v_dst_strides + 1), (__pyx_v_src_shape + 1), (__pyx_v_dst_shape + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize); + + /* "View.MemoryView":1164 + * src_shape + 1, dst_shape + 1, + * ndim - 1, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1165 + * ndim - 1, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + + /* function exit code */ +} + +/* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + +static void copy_strided_to_strided(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + + /* "View.MemoryView":1170 + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + * _copy_strided_to_strided(src.data, src.strides, dst.data, dst.strides, # <<<<<<<<<<<<<< + * src.shape, dst.shape, ndim, itemsize) + * + */ + _copy_strided_to_strided(__pyx_v_src->data, __pyx_v_src->strides, __pyx_v_dst->data, __pyx_v_dst->strides, __pyx_v_src->shape, __pyx_v_dst->shape, __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *__pyx_v_src, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_size; + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + + /* "View.MemoryView":1176 + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize # <<<<<<<<<<<<<< + * + * for shape in src.shape[:ndim]: + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_size = __pyx_t_1; + + /* "View.MemoryView":1178 + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + * + * for shape in src.shape[:ndim]: # <<<<<<<<<<<<<< + * size *= shape + * + */ + __pyx_t_3 = (__pyx_v_src->shape + __pyx_v_ndim); + for (__pyx_t_4 = __pyx_v_src->shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_v_shape = (__pyx_t_2[0]); + + /* "View.MemoryView":1179 + * + * for shape in src.shape[:ndim]: + * size *= shape # <<<<<<<<<<<<<< + * + * return size + */ + __pyx_v_size = (__pyx_v_size * __pyx_v_shape); + } + + /* "View.MemoryView":1181 + * size *= shape + * + * return size # <<<<<<<<<<<<<< + * + * @cname('__pyx_fill_contig_strides_array') + */ + __pyx_r = __pyx_v_size; + goto __pyx_L0; + + /* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, Py_ssize_t __pyx_v_stride, int __pyx_v_ndim, char __pyx_v_order) { + int __pyx_v_idx; + Py_ssize_t __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + __pyx_t_1 = (__pyx_v_order == 'F'); + if (__pyx_t_1) { + + /* "View.MemoryView":1194 + * + * if order == 'F': + * for idx in range(ndim): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + __pyx_t_2 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_idx = __pyx_t_4; + + /* "View.MemoryView":1195 + * if order == 'F': + * for idx in range(ndim): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * else: + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1196 + * for idx in range(ndim): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * else: + * for idx in range(ndim - 1, -1, -1): + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1198 + * stride *= shape[idx] + * else: + * for idx in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + /*else*/ { + for (__pyx_t_2 = (__pyx_v_ndim - 1); __pyx_t_2 > -1; __pyx_t_2-=1) { + __pyx_v_idx = __pyx_t_2; + + /* "View.MemoryView":1199 + * else: + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1200 + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * + * return stride + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + } + __pyx_L3:; + + /* "View.MemoryView":1202 + * stride *= shape[idx] + * + * return stride # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_data_to_temp') + */ + __pyx_r = __pyx_v_stride; + goto __pyx_L0; + + /* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_tmpslice, char __pyx_v_order, int __pyx_v_ndim) { + int __pyx_v_i; + void *__pyx_v_result; + size_t __pyx_v_itemsize; + size_t __pyx_v_size; + void *__pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + struct __pyx_memoryview_obj *__pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1216 + * cdef void *result + * + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef size_t size = slice_get_size(src, ndim) + * + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1217 + * + * cdef size_t itemsize = src.memview.view.itemsize + * cdef size_t size = slice_get_size(src, ndim) # <<<<<<<<<<<<<< + * + * result = malloc(size) + */ + __pyx_v_size = __pyx_memoryview_slice_get_size(__pyx_v_src, __pyx_v_ndim); + + /* "View.MemoryView":1219 + * cdef size_t size = slice_get_size(src, ndim) + * + * result = malloc(size) # <<<<<<<<<<<<<< + * if not result: + * _err_no_memory() + */ + __pyx_v_result = malloc(__pyx_v_size); + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + __pyx_t_2 = (!(__pyx_v_result != 0)); + if (__pyx_t_2) { + + /* "View.MemoryView":1221 + * result = malloc(size) + * if not result: + * _err_no_memory() # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_3 = __pyx_memoryview_err_no_memory(); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 1221, __pyx_L1_error) + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + } + + /* "View.MemoryView":1224 + * + * + * tmpslice.data = result # <<<<<<<<<<<<<< + * tmpslice.memview = src.memview + * for i in range(ndim): + */ + __pyx_v_tmpslice->data = ((char *)__pyx_v_result); + + /* "View.MemoryView":1225 + * + * tmpslice.data = result + * tmpslice.memview = src.memview # <<<<<<<<<<<<<< + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + */ + __pyx_t_4 = __pyx_v_src->memview; + __pyx_v_tmpslice->memview = __pyx_t_4; + + /* "View.MemoryView":1226 + * tmpslice.data = result + * tmpslice.memview = src.memview + * for i in range(ndim): # <<<<<<<<<<<<<< + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1227 + * tmpslice.memview = src.memview + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] # <<<<<<<<<<<<<< + * tmpslice.suboffsets[i] = -1 + * + */ + (__pyx_v_tmpslice->shape[__pyx_v_i]) = (__pyx_v_src->shape[__pyx_v_i]); + + /* "View.MemoryView":1228 + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) + */ + (__pyx_v_tmpslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1230 + * tmpslice.suboffsets[i] = -1 + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) # <<<<<<<<<<<<<< + * + * + */ + (void)(__pyx_fill_contig_strides_array((&(__pyx_v_tmpslice->shape[0])), (&(__pyx_v_tmpslice->strides[0])), __pyx_v_itemsize, __pyx_v_ndim, __pyx_v_order)); + + /* "View.MemoryView":1233 + * + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + __pyx_t_2 = ((__pyx_v_tmpslice->shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1235 + * for i in range(ndim): + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 # <<<<<<<<<<<<<< + * + * if slice_is_contig(src[0], order, ndim): + */ + (__pyx_v_tmpslice->strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + } + } + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + __pyx_t_2 = __pyx_memviewslice_is_contig((__pyx_v_src[0]), __pyx_v_order, __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1238 + * + * if slice_is_contig(src[0], order, ndim): + * memcpy(result, src.data, size) # <<<<<<<<<<<<<< + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + */ + (void)(memcpy(__pyx_v_result, __pyx_v_src->data, __pyx_v_size)); + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":1240 + * memcpy(result, src.data, size) + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) # <<<<<<<<<<<<<< + * + * return result + */ + /*else*/ { + copy_strided_to_strided(__pyx_v_src, __pyx_v_tmpslice, __pyx_v_ndim, __pyx_v_itemsize); + } + __pyx_L9:; + + /* "View.MemoryView":1242 + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.copy_data_to_temp", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + +static int __pyx_memoryview_err_extents(int __pyx_v_i, Py_ssize_t __pyx_v_extent1, Py_ssize_t __pyx_v_extent2) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t __pyx_t_2; + Py_UCS4 __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_extents", 0); + + /* "View.MemoryView":1249 + * cdef int _err_extents(int i, Py_ssize_t extent1, + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_dim') + */ + __pyx_t_1 = PyTuple_New(7); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = 0; + __pyx_t_3 = 127; + __Pyx_INCREF(__pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_2 += 35; + __Pyx_GIVEREF(__pyx_kp_u_got_differing_extents_in_dimensi); + PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_4 = __Pyx_PyUnicode_From_int(__pyx_v_i, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_got); + __pyx_t_2 += 6; + __Pyx_GIVEREF(__pyx_kp_u_got); + PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_kp_u_got); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent1, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_and); + __pyx_t_2 += 5; + __Pyx_GIVEREF(__pyx_kp_u_and); + PyTuple_SET_ITEM(__pyx_t_1, 4, __pyx_kp_u_and); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent2, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 5, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_2 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_1, 6, __pyx_kp_u__7); + __pyx_t_4 = __Pyx_PyUnicode_Join(__pyx_t_1, 7, __pyx_t_2, __pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_4, 0, 0); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __PYX_ERR(1, 1249, __pyx_L1_error) + + /* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView._err_extents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + +static int __pyx_memoryview_err_dim(PyObject *__pyx_v_error, PyObject *__pyx_v_msg, int __pyx_v_dim) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_dim", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1253 + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: + * raise error, msg % dim # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err') + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_dim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyString_FormatSafe(__pyx_v_msg, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_t_2, 0, 0); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __PYX_ERR(1, 1253, __pyx_L1_error) + + /* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._err_dim", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + +static int __pyx_memoryview_err(PyObject *__pyx_v_error, PyObject *__pyx_v_msg) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1257 + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: + * raise error, msg # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_no_memory') + */ + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_v_msg, 0, 0); + __PYX_ERR(1, 1257, __pyx_L1_error) + + /* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + +static int __pyx_memoryview_err_no_memory(void) { + int __pyx_r; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1261 + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: + * raise MemoryError # <<<<<<<<<<<<<< + * + * + */ + PyErr_NoMemory(); __PYX_ERR(1, 1261, __pyx_L1_error) + + /* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err_no_memory", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice __pyx_v_src, __Pyx_memviewslice __pyx_v_dst, int __pyx_v_src_ndim, int __pyx_v_dst_ndim, int __pyx_v_dtype_is_object) { + void *__pyx_v_tmpdata; + size_t __pyx_v_itemsize; + int __pyx_v_i; + char __pyx_v_order; + int __pyx_v_broadcasting; + int __pyx_v_direct_copy; + __Pyx_memviewslice __pyx_v_tmp; + int __pyx_v_ndim; + int __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + void *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1273 + * Check for overlapping memory and verify the shapes. + * """ + * cdef void *tmpdata = NULL # <<<<<<<<<<<<<< + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + */ + __pyx_v_tmpdata = NULL; + + /* "View.MemoryView":1274 + * """ + * cdef void *tmpdata = NULL + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + */ + __pyx_t_1 = __pyx_v_src.memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1276 + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) # <<<<<<<<<<<<<< + * cdef bint broadcasting = False + * cdef bint direct_copy = False + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_src), __pyx_v_src_ndim); + + /* "View.MemoryView":1277 + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False # <<<<<<<<<<<<<< + * cdef bint direct_copy = False + * cdef __Pyx_memviewslice tmp + */ + __pyx_v_broadcasting = 0; + + /* "View.MemoryView":1278 + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False + * cdef bint direct_copy = False # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice tmp + * + */ + __pyx_v_direct_copy = 0; + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + __pyx_t_2 = (__pyx_v_src_ndim < __pyx_v_dst_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1282 + * + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_src), __pyx_v_src_ndim, __pyx_v_dst_ndim); + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + __pyx_t_2 = (__pyx_v_dst_ndim < __pyx_v_src_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1284 + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) # <<<<<<<<<<<<<< + * + * cdef int ndim = max(src_ndim, dst_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_dst), __pyx_v_dst_ndim, __pyx_v_src_ndim); + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + } + __pyx_L3:; + + /* "View.MemoryView":1286 + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + * cdef int ndim = max(src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + __pyx_t_3 = __pyx_v_dst_ndim; + __pyx_t_4 = __pyx_v_src_ndim; + __pyx_t_2 = (__pyx_t_3 > __pyx_t_4); + if (__pyx_t_2) { + __pyx_t_5 = __pyx_t_3; + } else { + __pyx_t_5 = __pyx_t_4; + } + __pyx_v_ndim = __pyx_t_5; + + /* "View.MemoryView":1288 + * cdef int ndim = max(src_ndim, dst_ndim) + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + */ + __pyx_t_5 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_5; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) != (__pyx_v_dst.shape[__pyx_v_i])); + if (__pyx_t_2) { + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1291 + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + * broadcasting = True # <<<<<<<<<<<<<< + * src.strides[i] = 0 + * else: + */ + __pyx_v_broadcasting = 1; + + /* "View.MemoryView":1292 + * if src.shape[i] == 1: + * broadcasting = True + * src.strides[i] = 0 # <<<<<<<<<<<<<< + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) + */ + (__pyx_v_src.strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + goto __pyx_L7; + } + + /* "View.MemoryView":1294 + * src.strides[i] = 0 + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) # <<<<<<<<<<<<<< + * + * if src.suboffsets[i] >= 0: + */ + /*else*/ { + __pyx_t_6 = __pyx_memoryview_err_extents(__pyx_v_i, (__pyx_v_dst.shape[__pyx_v_i]), (__pyx_v_src.shape[__pyx_v_i])); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1294, __pyx_L1_error) + } + __pyx_L7:; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + } + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + __pyx_t_2 = ((__pyx_v_src.suboffsets[__pyx_v_i]) >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":1297 + * + * if src.suboffsets[i] >= 0: + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) # <<<<<<<<<<<<<< + * + * if slices_overlap(&src, &dst, ndim, itemsize): + */ + __pyx_t_6 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Dimension_d_is_not_direct, __pyx_v_i); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1297, __pyx_L1_error) + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + } + } + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + __pyx_t_2 = __pyx_slices_overlap((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + if (__pyx_t_2) { + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + __pyx_t_2 = (!__pyx_memviewslice_is_contig(__pyx_v_src, __pyx_v_order, __pyx_v_ndim)); + if (__pyx_t_2) { + + /* "View.MemoryView":1302 + * + * if not slice_is_contig(src, order, ndim): + * order = get_best_order(&dst, ndim) # <<<<<<<<<<<<<< + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim); + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + } + + /* "View.MemoryView":1304 + * order = get_best_order(&dst, ndim) + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) # <<<<<<<<<<<<<< + * src = tmp + * + */ + __pyx_t_7 = __pyx_memoryview_copy_data_to_temp((&__pyx_v_src), (&__pyx_v_tmp), __pyx_v_order, __pyx_v_ndim); if (unlikely(__pyx_t_7 == ((void *)NULL))) __PYX_ERR(1, 1304, __pyx_L1_error) + __pyx_v_tmpdata = __pyx_t_7; + + /* "View.MemoryView":1305 + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + * src = tmp # <<<<<<<<<<<<<< + * + * if not broadcasting: + */ + __pyx_v_src = __pyx_v_tmp; + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (!__pyx_v_broadcasting); + if (__pyx_t_2) { + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'C', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1311 + * + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) # <<<<<<<<<<<<<< + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'C', __pyx_v_ndim); + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + goto __pyx_L12; + } + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'F', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1313 + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) # <<<<<<<<<<<<<< + * + * if direct_copy: + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'F', __pyx_v_ndim); + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + } + __pyx_L12:; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + if (__pyx_v_direct_copy) { + + /* "View.MemoryView":1317 + * if direct_copy: + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1318 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + */ + (void)(memcpy(__pyx_v_dst.data, __pyx_v_src.data, __pyx_memoryview_slice_get_size((&__pyx_v_src), __pyx_v_ndim))); + + /* "View.MemoryView":1319 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * free(tmpdata) + * return 0 + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1320 + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1321 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * if order == 'F' == get_best_order(&dst, ndim): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (__pyx_v_order == 'F'); + if (__pyx_t_2) { + __pyx_t_2 = ('F' == __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim)); + } + if (__pyx_t_2) { + + /* "View.MemoryView":1326 + * + * + * transpose_memslice(&src) # <<<<<<<<<<<<<< + * transpose_memslice(&dst) + * + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_src)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1326, __pyx_L1_error) + + /* "View.MemoryView":1327 + * + * transpose_memslice(&src) + * transpose_memslice(&dst) # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_dst)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1327, __pyx_L1_error) + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1329 + * transpose_memslice(&dst) + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1330 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + */ + copy_strided_to_strided((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1331 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * free(tmpdata) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1333 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1334 + * + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_broadcast_leading') + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_contents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim, int __pyx_v_ndim_other) { + int __pyx_v_i; + int __pyx_v_offset; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + + /* "View.MemoryView":1341 + * int ndim_other) noexcept nogil: + * cdef int i + * cdef int offset = ndim_other - ndim # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_offset = (__pyx_v_ndim_other - __pyx_v_ndim); + + /* "View.MemoryView":1343 + * cdef int offset = ndim_other - ndim + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1344 + * + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] # <<<<<<<<<<<<<< + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + */ + (__pyx_v_mslice->shape[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->shape[__pyx_v_i]); + + /* "View.MemoryView":1345 + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] # <<<<<<<<<<<<<< + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + */ + (__pyx_v_mslice->strides[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1346 + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] # <<<<<<<<<<<<<< + * + * for i in range(offset): + */ + (__pyx_v_mslice->suboffsets[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->suboffsets[__pyx_v_i]); + } + + /* "View.MemoryView":1348 + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + * for i in range(offset): # <<<<<<<<<<<<<< + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + */ + __pyx_t_1 = __pyx_v_offset; + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1349 + * + * for i in range(offset): + * mslice.shape[i] = 1 # <<<<<<<<<<<<<< + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 + */ + (__pyx_v_mslice->shape[__pyx_v_i]) = 1; + + /* "View.MemoryView":1350 + * for i in range(offset): + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] # <<<<<<<<<<<<<< + * mslice.suboffsets[i] = -1 + * + */ + (__pyx_v_mslice->strides[__pyx_v_i]) = (__pyx_v_mslice->strides[0]); + + /* "View.MemoryView":1351 + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_mslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_dtype_is_object, int __pyx_v_ndim, int __pyx_v_inc) { + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + if (__pyx_v_dtype_is_object) { + + /* "View.MemoryView":1362 + * + * if dtype_is_object: + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + */ + __pyx_memoryview_refcount_objects_in_slice_with_gil(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + } + + /* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1368 + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + * refcount_objects_in_slice(data, shape, strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, __pyx_v_shape, __pyx_v_strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + + /* function exit code */ + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif +} + +/* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + +static void __pyx_memoryview_refcount_objects_in_slice(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1374 + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * + * for i in range(shape[0]): + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1376 + * cdef Py_ssize_t stride = strides[0] + * + * for i in range(shape[0]): # <<<<<<<<<<<<<< + * if ndim == 1: + * if inc: + */ + __pyx_t_1 = (__pyx_v_shape[0]); + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + __pyx_t_4 = (__pyx_v_ndim == 1); + if (__pyx_t_4) { + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + if (__pyx_v_inc) { + + /* "View.MemoryView":1379 + * if ndim == 1: + * if inc: + * Py_INCREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * Py_DECREF(( data)[0]) + */ + Py_INCREF((((PyObject **)__pyx_v_data)[0])); + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":1381 + * Py_INCREF(( data)[0]) + * else: + * Py_DECREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + */ + /*else*/ { + Py_DECREF((((PyObject **)__pyx_v_data)[0])); + } + __pyx_L6:; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":1383 + * Py_DECREF(( data)[0]) + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) # <<<<<<<<<<<<<< + * + * data += stride + */ + /*else*/ { + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_inc); + } + __pyx_L5:; + + /* "View.MemoryView":1385 + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + * + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + + /* function exit code */ +} + +/* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item, int __pyx_v_dtype_is_object) { + + /* "View.MemoryView":1394 + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1395 + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) # <<<<<<<<<<<<<< + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1396 + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + +static void __pyx_memoryview__slice_assign_scalar(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_extent; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + + /* "View.MemoryView":1404 + * size_t itemsize, void *item) noexcept nogil: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t extent = shape[0] + * + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1405 + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] + * cdef Py_ssize_t extent = shape[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_extent = (__pyx_v_shape[0]); + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1408 + * + * if ndim == 1: + * for i in range(extent): # <<<<<<<<<<<<<< + * memcpy(data, item, itemsize) + * data += stride + */ + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1409 + * if ndim == 1: + * for i in range(extent): + * memcpy(data, item, itemsize) # <<<<<<<<<<<<<< + * data += stride + * else: + */ + (void)(memcpy(__pyx_v_data, __pyx_v_item, __pyx_v_itemsize)); + + /* "View.MemoryView":1410 + * for i in range(extent): + * memcpy(data, item, itemsize) + * data += stride # <<<<<<<<<<<<<< + * else: + * for i in range(extent): + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1412 + * data += stride + * else: + * for i in range(extent): # <<<<<<<<<<<<<< + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride + */ + /*else*/ { + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1413 + * else: + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) # <<<<<<<<<<<<<< + * data += stride + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1414 + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + + /* function exit code */ +} + +/* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum = {"__pyx_unpickle_Enum", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_type = 0; + long __pyx_v___pyx_checksum; + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_type)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_checksum)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__pyx_unpickle_Enum") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 3)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + } + __pyx_v___pyx_type = values[0]; + __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_v___pyx_state = values[2]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, __pyx_nargs); __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_v___pyx_PickleError = 0; + PyObject *__pyx_v___pyx_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum", 1); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_t_1 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = (__Pyx_PySequence_ContainsTF(__pyx_t_1, __pyx_tuple__8, Py_NE)); if (unlikely((__pyx_t_2 < 0))) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (__pyx_t_2) { + + /* "(tree fragment)":5 + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + */ + __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_n_s_PickleError); + __Pyx_GIVEREF(__pyx_n_s_PickleError); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_PickleError)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_1, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_t_1); + __pyx_v___pyx_PickleError = __pyx_t_1; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "(tree fragment)":6 + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum # <<<<<<<<<<<<<< + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + */ + __pyx_t_3 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_v___pyx_PickleError, __pyx_t_1, 0, 0); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __PYX_ERR(1, 6, __pyx_L1_error) + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + } + + /* "(tree fragment)":7 + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) # <<<<<<<<<<<<<< + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + */ + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_MemviewEnum_type), __pyx_n_s_new); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = NULL; + __pyx_t_5 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_5 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_4, __pyx_v___pyx_type}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_5, 1+__pyx_t_5); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_v___pyx_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + __pyx_t_2 = (__pyx_v___pyx_state != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":9 + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 9, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(((struct __pyx_MemviewEnum_obj *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + } + + /* "(tree fragment)":10 + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result # <<<<<<<<<<<<<< + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v___pyx_result); + __pyx_r = __pyx_v___pyx_result; + goto __pyx_L0; + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v___pyx_PickleError); + __Pyx_XDECREF(__pyx_v___pyx_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + int __pyx_t_8; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum__set_state", 1); + + /* "(tree fragment)":12 + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] # <<<<<<<<<<<<<< + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_GIVEREF(__pyx_t_1); + __Pyx_GOTREF(__pyx_v___pyx_result->name); + __Pyx_DECREF(__pyx_v___pyx_result->name); + __pyx_v___pyx_result->name = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 13, __pyx_L1_error) + } + __pyx_t_3 = __Pyx_PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_4 = (__pyx_t_3 > 1); + if (__pyx_t_4) { + } else { + __pyx_t_2 = __pyx_t_4; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_4 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_2 = __pyx_t_4; + __pyx_L4_bool_binop_done:; + if (__pyx_t_2) { + + /* "(tree fragment)":14 + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) # <<<<<<<<<<<<<< + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_update); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 14, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_6))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_6); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_6, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_7, __pyx_t_5}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_6, __pyx_callargs+1-__pyx_t_8, 1+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + } + + /* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245 + * + * @property + * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self) { + PyObject *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":248 + * """Returns a borrowed reference to the object owning the data/memory. + * """ + * return PyArray_BASE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_BASE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245 + * + * @property + * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self) { + PyArray_Descr *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyArray_Descr *__pyx_t_1; + __Pyx_RefNannySetupContext("descr", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":254 + * """Returns an owned reference to the dtype of the array. + * """ + * return PyArray_DESCR(self) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __pyx_t_1 = PyArray_DESCR(__pyx_v_self); + __Pyx_INCREF((PyObject *)((PyArray_Descr *)__pyx_t_1)); + __pyx_r = ((PyArray_Descr *)__pyx_t_1); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257 + * + * @property + * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":260 + * """Returns the number of dimensions in the array. + * """ + * return PyArray_NDIM(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_NDIM(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257 + * + * @property + * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263 + * + * @property + * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":268 + * Can return NULL for 0-dimensional arrays. + * """ + * return PyArray_DIMS(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_DIMS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263 + * + * @property + * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271 + * + * @property + * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":275 + * The number of elements matches the number of dimensions of the array (ndim). + * """ + * return PyArray_STRIDES(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_STRIDES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271 + * + * @property + * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278 + * + * @property + * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":281 + * """Returns the total size (in number of elements) of the array. + * """ + * return PyArray_SIZE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_SIZE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278 + * + * @property + * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284 + * + * @property + * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self) { + char *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":290 + * of `PyArray_DATA()` instead, which returns a 'void*'. + * """ + * return PyArray_BYTES(self) # <<<<<<<<<<<<<< + * + * ctypedef unsigned char npy_bool + */ + __pyx_r = PyArray_BYTES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284 + * + * @property + * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":773 + * ctypedef npy_cdouble complex_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew1(PyObject *__pyx_v_a) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew1", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":774 + * + * cdef inline object PyArray_MultiIterNew1(a): + * return PyArray_MultiIterNew(1, a) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew2(a, b): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(1, ((void *)__pyx_v_a)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 774, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":773 + * ctypedef npy_cdouble complex_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew1", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew2(PyObject *__pyx_v_a, PyObject *__pyx_v_b) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew2", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":777 + * + * cdef inline object PyArray_MultiIterNew2(a, b): + * return PyArray_MultiIterNew(2, a, b) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(2, ((void *)__pyx_v_a), ((void *)__pyx_v_b)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 777, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew2", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew3(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew3", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":780 + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + * return PyArray_MultiIterNew(3, a, b, c) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(3, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 780, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew3", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew4(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew4", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":783 + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + * return PyArray_MultiIterNew(4, a, b, c, d) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(4, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 783, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew4", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew5(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d, PyObject *__pyx_v_e) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew5", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":786 + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + * return PyArray_MultiIterNew(5, a, b, c, d, e) # <<<<<<<<<<<<<< + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(5, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d), ((void *)__pyx_v_e)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 786, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew5", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyDataType_SHAPE(PyArray_Descr *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + __Pyx_RefNannySetupContext("PyDataType_SHAPE", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":789 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + __pyx_t_1 = PyDataType_HASSUBARRAY(__pyx_v_d); + if (__pyx_t_1) { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":790 + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape # <<<<<<<<<<<<<< + * else: + * return () + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(((PyObject*)__pyx_v_d->subarray->shape)); + __pyx_r = ((PyObject*)__pyx_v_d->subarray->shape); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":789 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":792 + * return d.subarray.shape + * else: + * return () # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_empty_tuple); + __pyx_r = __pyx_empty_tuple; + goto __pyx_L0; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":968 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + +static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_arr, PyObject *__pyx_v_base) { + int __pyx_t_1; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":969 + * + * cdef inline void set_array_base(ndarray arr, object base): + * Py_INCREF(base) # important to do this before stealing the reference below! # <<<<<<<<<<<<<< + * PyArray_SetBaseObject(arr, base) + * + */ + Py_INCREF(__pyx_v_base); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":970 + * cdef inline void set_array_base(ndarray arr, object base): + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) # <<<<<<<<<<<<<< + * + * cdef inline object get_array_base(ndarray arr): + */ + __pyx_t_1 = PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_base); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(2, 970, __pyx_L1_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":968 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("numpy.set_array_base", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_L0:; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":972 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__pyx_v_arr) { + PyObject *__pyx_v_base; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + __Pyx_RefNannySetupContext("get_array_base", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":973 + * + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) # <<<<<<<<<<<<<< + * if base is NULL: + * return None + */ + __pyx_v_base = PyArray_BASE(__pyx_v_arr); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + __pyx_t_1 = (__pyx_v_base == NULL); + if (__pyx_t_1) { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":975 + * base = PyArray_BASE(arr) + * if base is NULL: + * return None # <<<<<<<<<<<<<< + * return base + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":976 + * if base is NULL: + * return None + * return base # <<<<<<<<<<<<<< + * + * # Versions of the import_* functions which are more suitable for + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(((PyObject *)__pyx_v_base)); + __pyx_r = ((PyObject *)__pyx_v_base); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":972 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":980 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_array(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_array", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":982 + * cdef inline int import_array() except -1: + * try: + * __pyx_import_array() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") + */ + __pyx_t_4 = _import_array(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 982, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":983 + * try: + * __pyx_import_array() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.multiarray failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 983, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":984 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__9, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 984, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 984, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":980 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986 + * raise ImportError("numpy.core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_umath(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_umath", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":988 + * cdef inline int import_umath() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 988, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":989 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 989, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":990 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 990, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 990, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986 + * raise ImportError("numpy.core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992 + * raise ImportError("numpy.core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_ufunc(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_ufunc", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":994 + * cdef inline int import_ufunc() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 994, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":995 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 995, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":996 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 996, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 996, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992 + * raise ImportError("numpy.core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":999 + * + * + * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_timedelta64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1011 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyTimedeltaArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyTimedeltaArrType_Type)); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":999 + * + * + * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1014 + * + * + * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_datetime64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1026 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyDatetimeArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyDatetimeArrType_Type)); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1014 + * + * + * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + +static CYTHON_INLINE npy_datetime __pyx_f_5numpy_get_datetime64_value(PyObject *__pyx_v_obj) { + npy_datetime __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1036 + * also needed. That can be found using `get_datetime64_unit`. + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyDatetimeScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1039 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + +static CYTHON_INLINE npy_timedelta __pyx_f_5numpy_get_timedelta64_value(PyObject *__pyx_v_obj) { + npy_timedelta __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1043 + * returns the int64 value underlying scalar numpy timedelta64 object + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyTimedeltaScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1039 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1046 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + +static CYTHON_INLINE NPY_DATETIMEUNIT __pyx_f_5numpy_get_datetime64_unit(PyObject *__pyx_v_obj) { + NPY_DATETIMEUNIT __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1050 + * returns the unit part of the dtype for a numpy datetime64 object. + * """ + * return (obj).obmeta.base # <<<<<<<<<<<<<< + */ + __pyx_r = ((NPY_DATETIMEUNIT)((PyDatetimeScalarObject *)__pyx_v_obj)->obmeta.base); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1046 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "fairseq/data/data_utils_fast.pyx":20 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_vec( # <<<<<<<<<<<<<< + * np.ndarray[int64_t, ndim=1] indices, + * np.ndarray[int64_t, ndim=1] num_tokens_vec, + */ + +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_1batch_by_size_vec(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_vec(PyArrayObject *__pyx_v_indices, PyArrayObject *__pyx_v_num_tokens_vec, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult, CYTHON_UNUSED int __pyx_skip_dispatch) { + int32_t __pyx_v_indices_len; + PyArrayObject *__pyx_v_batches_ends = 0; + __Pyx_memviewslice __pyx_v_batches_ends_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_num_tokens_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + int32_t __pyx_v_pos; + int32_t __pyx_v_new_batch_end; + int64_t __pyx_v_new_batch_max_tokens; + int32_t __pyx_v_new_batch_sentences; + int64_t __pyx_v_new_batch_num_tokens; + bool __pyx_v_overflow; + bool __pyx_v_size_matches_with_bsz_mult; + int32_t __pyx_v_batches_count; + int32_t __pyx_v_batch_start; + int64_t __pyx_v_tail_max_tokens; + int64_t __pyx_v_batch_max_tokens; + int64_t __pyx_v_tail_num_tokens; + PyObject *__pyx_v_tail_overflow = NULL; + __Pyx_LocalBuf_ND __pyx_pybuffernd_batches_ends; + __Pyx_Buffer __pyx_pybuffer_batches_ends; + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + __Pyx_LocalBuf_ND __pyx_pybuffernd_num_tokens_vec; + __Pyx_Buffer __pyx_pybuffer_num_tokens_vec; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + npy_intp *__pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + int __pyx_t_7; + PyObject *__pyx_t_8 = NULL; + PyObject *__pyx_t_9 = NULL; + PyArrayObject *__pyx_t_10 = NULL; + __Pyx_memviewslice __pyx_t_11 = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_t_12 = { 0, 0, { 0 }, { 0 }, { 0 } }; + int32_t __pyx_t_13; + int32_t __pyx_t_14; + int32_t __pyx_t_15; + int64_t __pyx_t_16; + Py_ssize_t __pyx_t_17; + bool __pyx_t_18; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_by_size_vec", 1); + __pyx_pybuffer_batches_ends.pybuffer.buf = NULL; + __pyx_pybuffer_batches_ends.refcount = 0; + __pyx_pybuffernd_batches_ends.data = NULL; + __pyx_pybuffernd_batches_ends.rcbuffer = &__pyx_pybuffer_batches_ends; + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + __pyx_pybuffer_num_tokens_vec.pybuffer.buf = NULL; + __pyx_pybuffer_num_tokens_vec.refcount = 0; + __pyx_pybuffernd_num_tokens_vec.data = NULL; + __pyx_pybuffernd_num_tokens_vec.rcbuffer = &__pyx_pybuffer_num_tokens_vec; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn_int64_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 20, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer, (PyObject*)__pyx_v_num_tokens_vec, &__Pyx_TypeInfo_nn_int64_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 20, __pyx_L1_error) + } + __pyx_pybuffernd_num_tokens_vec.diminfo[0].strides = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_num_tokens_vec.diminfo[0].shape = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.shape[0]; + + /* "fairseq/data/data_utils_fast.pyx":27 + * int32_t bsz_mult, + * ): + * if indices.shape[0] == 0: # <<<<<<<<<<<<<< + * return [] + * + */ + __pyx_t_1 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_indices)); if (unlikely(__pyx_t_1 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 27, __pyx_L1_error) + __pyx_t_2 = ((__pyx_t_1[0]) == 0); + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":28 + * ): + * if indices.shape[0] == 0: + * return [] # <<<<<<<<<<<<<< + * + * assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_3 = PyList_New(0); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 28, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_r = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":27 + * int32_t bsz_mult, + * ): + * if indices.shape[0] == 0: # <<<<<<<<<<<<<< + * return [] + * + */ + } + + /* "fairseq/data/data_utils_fast.pyx":30 + * return [] + * + * assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( # <<<<<<<<<<<<<< + * f"Sentences lengths should not exceed max_tokens={max_tokens}" + * ) + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_4 = (__pyx_v_max_tokens <= 0); + if (!__pyx_t_4) { + } else { + __pyx_t_2 = __pyx_t_4; + goto __pyx_L4_bool_binop_done; + } + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_max); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = NULL; + __pyx_t_7 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_6))) { + __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_6); + if (likely(__pyx_t_5)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); + __Pyx_INCREF(__pyx_t_5); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_6, function); + __pyx_t_7 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_5, ((PyObject *)__pyx_v_num_tokens_vec)}; + __pyx_t_3 = __Pyx_PyObject_FastCall(__pyx_t_6, __pyx_callargs+1-__pyx_t_7, 1+__pyx_t_7); + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + __pyx_t_6 = __Pyx_PyInt_From_int64_t(__pyx_v_max_tokens); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_5 = PyObject_RichCompare(__pyx_t_3, __pyx_t_6, Py_LE); __Pyx_XGOTREF(__pyx_t_5); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_t_5); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(0, 30, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_2 = __pyx_t_4; + __pyx_L4_bool_binop_done:; + if (unlikely(!__pyx_t_2)) { + + /* "fairseq/data/data_utils_fast.pyx":31 + * + * assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( + * f"Sentences lengths should not exceed max_tokens={max_tokens}" # <<<<<<<<<<<<<< + * ) + * + */ + __pyx_t_5 = __Pyx_PyInt_From_int64_t(__pyx_v_max_tokens); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 31, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_FormatSimple(__pyx_t_5, __pyx_empty_unicode); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 31, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyUnicode_Concat(__pyx_kp_u_Sentences_lengths_should_not_exc, __pyx_t_6); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 31, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_Raise(__pyx_builtin_AssertionError, __pyx_t_5, 0, 0); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __PYX_ERR(0, 30, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(0, 30, __pyx_L1_error) + #endif + + /* "fairseq/data/data_utils_fast.pyx":34 + * ) + * + * cdef int32_t indices_len = indices.shape[0] # <<<<<<<<<<<<<< + * cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + * np.zeros(indices_len, dtype=np.int32) + */ + __pyx_t_1 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_indices)); if (unlikely(__pyx_t_1 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 34, __pyx_L1_error) + __pyx_v_indices_len = (__pyx_t_1[0]); + + /* "fairseq/data/data_utils_fast.pyx":36 + * cdef int32_t indices_len = indices.shape[0] + * cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + * np.zeros(indices_len, dtype=np.int32) # <<<<<<<<<<<<<< + * cdef int32_t[:] batches_ends_view = batches_ends + * cdef int64_t[:] num_tokens_view = num_tokens_vec + */ + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_zeros); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyInt_From_int32_t(__pyx_v_indices_len); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_3 = PyTuple_New(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_5); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_5)) __PYX_ERR(0, 36, __pyx_L1_error); + __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_GetModuleGlobalName(__pyx_t_8, __pyx_n_s_np); if (unlikely(!__pyx_t_8)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_9 = __Pyx_PyObject_GetAttrStr(__pyx_t_8, __pyx_n_s_int32); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + if (PyDict_SetItem(__pyx_t_5, __pyx_n_s_dtype, __pyx_t_9) < 0) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; + __pyx_t_9 = __Pyx_PyObject_Call(__pyx_t_6, __pyx_t_3, __pyx_t_5); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 36, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (!(likely(((__pyx_t_9) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_9, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 36, __pyx_L1_error) + __pyx_t_10 = ((PyArrayObject *)__pyx_t_9); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_batches_ends.rcbuffer->pybuffer, (PyObject*)__pyx_t_10, &__Pyx_TypeInfo_nn_int32_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) { + __pyx_v_batches_ends = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_batches_ends.rcbuffer->pybuffer.buf = NULL; + __PYX_ERR(0, 35, __pyx_L1_error) + } else {__pyx_pybuffernd_batches_ends.diminfo[0].strides = __pyx_pybuffernd_batches_ends.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_batches_ends.diminfo[0].shape = __pyx_pybuffernd_batches_ends.rcbuffer->pybuffer.shape[0]; + } + } + __pyx_t_10 = 0; + __pyx_v_batches_ends = ((PyArrayObject *)__pyx_t_9); + __pyx_t_9 = 0; + + /* "fairseq/data/data_utils_fast.pyx":37 + * cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + * np.zeros(indices_len, dtype=np.int32) + * cdef int32_t[:] batches_ends_view = batches_ends # <<<<<<<<<<<<<< + * cdef int64_t[:] num_tokens_view = num_tokens_vec + * + */ + __pyx_t_11 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int32_t(((PyObject *)__pyx_v_batches_ends), PyBUF_WRITABLE); if (unlikely(!__pyx_t_11.memview)) __PYX_ERR(0, 37, __pyx_L1_error) + __pyx_v_batches_ends_view = __pyx_t_11; + __pyx_t_11.memview = NULL; + __pyx_t_11.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":38 + * np.zeros(indices_len, dtype=np.int32) + * cdef int32_t[:] batches_ends_view = batches_ends + * cdef int64_t[:] num_tokens_view = num_tokens_vec # <<<<<<<<<<<<<< + * + * cdef int32_t pos = 0 + */ + __pyx_t_12 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int64_t(((PyObject *)__pyx_v_num_tokens_vec), PyBUF_WRITABLE); if (unlikely(!__pyx_t_12.memview)) __PYX_ERR(0, 38, __pyx_L1_error) + __pyx_v_num_tokens_view = __pyx_t_12; + __pyx_t_12.memview = NULL; + __pyx_t_12.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":40 + * cdef int64_t[:] num_tokens_view = num_tokens_vec + * + * cdef int32_t pos = 0 # <<<<<<<<<<<<<< + * cdef int32_t new_batch_end = 0 + * + */ + __pyx_v_pos = 0; + + /* "fairseq/data/data_utils_fast.pyx":41 + * + * cdef int32_t pos = 0 + * cdef int32_t new_batch_end = 0 # <<<<<<<<<<<<<< + * + * cdef int64_t new_batch_max_tokens = 0 + */ + __pyx_v_new_batch_end = 0; + + /* "fairseq/data/data_utils_fast.pyx":43 + * cdef int32_t new_batch_end = 0 + * + * cdef int64_t new_batch_max_tokens = 0 # <<<<<<<<<<<<<< + * cdef int32_t new_batch_sentences = 0 + * cdef int64_t new_batch_num_tokens = 0 + */ + __pyx_v_new_batch_max_tokens = 0; + + /* "fairseq/data/data_utils_fast.pyx":44 + * + * cdef int64_t new_batch_max_tokens = 0 + * cdef int32_t new_batch_sentences = 0 # <<<<<<<<<<<<<< + * cdef int64_t new_batch_num_tokens = 0 + * + */ + __pyx_v_new_batch_sentences = 0; + + /* "fairseq/data/data_utils_fast.pyx":45 + * cdef int64_t new_batch_max_tokens = 0 + * cdef int32_t new_batch_sentences = 0 + * cdef int64_t new_batch_num_tokens = 0 # <<<<<<<<<<<<<< + * + * cdef bool_t overflow = False + */ + __pyx_v_new_batch_num_tokens = 0; + + /* "fairseq/data/data_utils_fast.pyx":47 + * cdef int64_t new_batch_num_tokens = 0 + * + * cdef bool_t overflow = False # <<<<<<<<<<<<<< + * cdef bool_t size_matches_with_bsz_mult = False + * + */ + __pyx_v_overflow = 0; + + /* "fairseq/data/data_utils_fast.pyx":48 + * + * cdef bool_t overflow = False + * cdef bool_t size_matches_with_bsz_mult = False # <<<<<<<<<<<<<< + * + * cdef int32_t batches_count = 0 + */ + __pyx_v_size_matches_with_bsz_mult = 0; + + /* "fairseq/data/data_utils_fast.pyx":50 + * cdef bool_t size_matches_with_bsz_mult = False + * + * cdef int32_t batches_count = 0 # <<<<<<<<<<<<<< + * cdef int32_t batch_start = 0 + * cdef int64_t tail_max_tokens = 0 + */ + __pyx_v_batches_count = 0; + + /* "fairseq/data/data_utils_fast.pyx":51 + * + * cdef int32_t batches_count = 0 + * cdef int32_t batch_start = 0 # <<<<<<<<<<<<<< + * cdef int64_t tail_max_tokens = 0 + * cdef int64_t batch_max_tokens = 0 + */ + __pyx_v_batch_start = 0; + + /* "fairseq/data/data_utils_fast.pyx":52 + * cdef int32_t batches_count = 0 + * cdef int32_t batch_start = 0 + * cdef int64_t tail_max_tokens = 0 # <<<<<<<<<<<<<< + * cdef int64_t batch_max_tokens = 0 + * + */ + __pyx_v_tail_max_tokens = 0; + + /* "fairseq/data/data_utils_fast.pyx":53 + * cdef int32_t batch_start = 0 + * cdef int64_t tail_max_tokens = 0 + * cdef int64_t batch_max_tokens = 0 # <<<<<<<<<<<<<< + * + * for pos in range(indices_len): + */ + __pyx_v_batch_max_tokens = 0; + + /* "fairseq/data/data_utils_fast.pyx":55 + * cdef int64_t batch_max_tokens = 0 + * + * for pos in range(indices_len): # <<<<<<<<<<<<<< + * # At every pos we keep stats about the last complete batch [batch_start:batch_end), + * # and tail [batch_end:pos]. + */ + __pyx_t_13 = __pyx_v_indices_len; + __pyx_t_14 = __pyx_t_13; + for (__pyx_t_15 = 0; __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) { + __pyx_v_pos = __pyx_t_15; + + /* "fairseq/data/data_utils_fast.pyx":69 + * + * tail_max_tokens = tail_max_tokens \ + * if tail_max_tokens > num_tokens_view[pos] \ # <<<<<<<<<<<<<< + * else num_tokens_view[pos] + * new_batch_end = pos + 1 + */ + __pyx_t_17 = __pyx_v_pos; + __pyx_t_2 = (__pyx_v_tail_max_tokens > (*((int64_t *) ( /* dim=0 */ (__pyx_v_num_tokens_view.data + __pyx_t_17 * __pyx_v_num_tokens_view.strides[0]) )))); + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":68 + * # Important: For the sake of performance try to avoid using function calls within this loop. + * + * tail_max_tokens = tail_max_tokens \ # <<<<<<<<<<<<<< + * if tail_max_tokens > num_tokens_view[pos] \ + * else num_tokens_view[pos] + */ + __pyx_t_16 = __pyx_v_tail_max_tokens; + } else { + + /* "fairseq/data/data_utils_fast.pyx":70 + * tail_max_tokens = tail_max_tokens \ + * if tail_max_tokens > num_tokens_view[pos] \ + * else num_tokens_view[pos] # <<<<<<<<<<<<<< + * new_batch_end = pos + 1 + * new_batch_max_tokens = batch_max_tokens \ + */ + __pyx_t_17 = __pyx_v_pos; + __pyx_t_16 = (*((int64_t *) ( /* dim=0 */ (__pyx_v_num_tokens_view.data + __pyx_t_17 * __pyx_v_num_tokens_view.strides[0]) ))); + } + __pyx_v_tail_max_tokens = __pyx_t_16; + + /* "fairseq/data/data_utils_fast.pyx":71 + * if tail_max_tokens > num_tokens_view[pos] \ + * else num_tokens_view[pos] + * new_batch_end = pos + 1 # <<<<<<<<<<<<<< + * new_batch_max_tokens = batch_max_tokens \ + * if batch_max_tokens > tail_max_tokens \ + */ + __pyx_v_new_batch_end = (__pyx_v_pos + 1); + + /* "fairseq/data/data_utils_fast.pyx":73 + * new_batch_end = pos + 1 + * new_batch_max_tokens = batch_max_tokens \ + * if batch_max_tokens > tail_max_tokens \ # <<<<<<<<<<<<<< + * else tail_max_tokens + * new_batch_sentences = new_batch_end - batch_start + */ + __pyx_t_2 = (__pyx_v_batch_max_tokens > __pyx_v_tail_max_tokens); + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":72 + * else num_tokens_view[pos] + * new_batch_end = pos + 1 + * new_batch_max_tokens = batch_max_tokens \ # <<<<<<<<<<<<<< + * if batch_max_tokens > tail_max_tokens \ + * else tail_max_tokens + */ + __pyx_t_16 = __pyx_v_batch_max_tokens; + } else { + + /* "fairseq/data/data_utils_fast.pyx":74 + * new_batch_max_tokens = batch_max_tokens \ + * if batch_max_tokens > tail_max_tokens \ + * else tail_max_tokens # <<<<<<<<<<<<<< + * new_batch_sentences = new_batch_end - batch_start + * new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + */ + __pyx_t_16 = __pyx_v_tail_max_tokens; + } + __pyx_v_new_batch_max_tokens = __pyx_t_16; + + /* "fairseq/data/data_utils_fast.pyx":75 + * if batch_max_tokens > tail_max_tokens \ + * else tail_max_tokens + * new_batch_sentences = new_batch_end - batch_start # <<<<<<<<<<<<<< + * new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + * + */ + __pyx_v_new_batch_sentences = (__pyx_v_new_batch_end - __pyx_v_batch_start); + + /* "fairseq/data/data_utils_fast.pyx":76 + * else tail_max_tokens + * new_batch_sentences = new_batch_end - batch_start + * new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens # <<<<<<<<<<<<<< + * + * overflow = (new_batch_sentences > max_sentences > 0 or + */ + __pyx_v_new_batch_num_tokens = (__pyx_v_new_batch_sentences * __pyx_v_new_batch_max_tokens); + + /* "fairseq/data/data_utils_fast.pyx":78 + * new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + * + * overflow = (new_batch_sentences > max_sentences > 0 or # <<<<<<<<<<<<<< + * new_batch_num_tokens > max_tokens > 0) + * size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + */ + __pyx_t_2 = (__pyx_v_new_batch_sentences > __pyx_v_max_sentences); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_max_sentences > 0); + } + if (!__pyx_t_2) { + } else { + __pyx_t_18 = __pyx_t_2; + goto __pyx_L8_bool_binop_done; + } + + /* "fairseq/data/data_utils_fast.pyx":79 + * + * overflow = (new_batch_sentences > max_sentences > 0 or + * new_batch_num_tokens > max_tokens > 0) # <<<<<<<<<<<<<< + * size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + * new_batch_sentences % bsz_mult == 0) + */ + __pyx_t_2 = (__pyx_v_new_batch_num_tokens > __pyx_v_max_tokens); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_max_tokens > 0); + } + __pyx_t_18 = __pyx_t_2; + __pyx_L8_bool_binop_done:; + __pyx_v_overflow = __pyx_t_18; + + /* "fairseq/data/data_utils_fast.pyx":80 + * overflow = (new_batch_sentences > max_sentences > 0 or + * new_batch_num_tokens > max_tokens > 0) + * size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or # <<<<<<<<<<<<<< + * new_batch_sentences % bsz_mult == 0) + * + */ + __pyx_t_2 = (__pyx_v_new_batch_sentences < __pyx_v_bsz_mult); + if (!__pyx_t_2) { + } else { + __pyx_t_18 = __pyx_t_2; + goto __pyx_L10_bool_binop_done; + } + + /* "fairseq/data/data_utils_fast.pyx":81 + * new_batch_num_tokens > max_tokens > 0) + * size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + * new_batch_sentences % bsz_mult == 0) # <<<<<<<<<<<<<< + * + * if overflow: + */ + __pyx_t_2 = ((__pyx_v_new_batch_sentences % __pyx_v_bsz_mult) == 0); + __pyx_t_18 = __pyx_t_2; + __pyx_L10_bool_binop_done:; + __pyx_v_size_matches_with_bsz_mult = __pyx_t_18; + + /* "fairseq/data/data_utils_fast.pyx":83 + * new_batch_sentences % bsz_mult == 0) + * + * if overflow: # <<<<<<<<<<<<<< + * tail_num_tokens = tail_max_tokens * \ + * (new_batch_end - batches_ends_view[batches_count]) + */ + __pyx_t_2 = (__pyx_v_overflow != 0); + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":85 + * if overflow: + * tail_num_tokens = tail_max_tokens * \ + * (new_batch_end - batches_ends_view[batches_count]) # <<<<<<<<<<<<<< + * tail_overflow = tail_num_tokens > max_tokens > 0 + * # In case of a tail overflow finalize two batches + */ + __pyx_t_17 = __pyx_v_batches_count; + + /* "fairseq/data/data_utils_fast.pyx":84 + * + * if overflow: + * tail_num_tokens = tail_max_tokens * \ # <<<<<<<<<<<<<< + * (new_batch_end - batches_ends_view[batches_count]) + * tail_overflow = tail_num_tokens > max_tokens > 0 + */ + __pyx_v_tail_num_tokens = (__pyx_v_tail_max_tokens * (__pyx_v_new_batch_end - (*((int32_t *) ( /* dim=0 */ (__pyx_v_batches_ends_view.data + __pyx_t_17 * __pyx_v_batches_ends_view.strides[0]) ))))); + + /* "fairseq/data/data_utils_fast.pyx":86 + * tail_num_tokens = tail_max_tokens * \ + * (new_batch_end - batches_ends_view[batches_count]) + * tail_overflow = tail_num_tokens > max_tokens > 0 # <<<<<<<<<<<<<< + * # In case of a tail overflow finalize two batches + * if tail_overflow: + */ + __pyx_t_2 = (__pyx_v_tail_num_tokens > __pyx_v_max_tokens); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_max_tokens > 0); + } + __pyx_t_9 = __Pyx_PyBool_FromLong(__pyx_t_2); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 86, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __Pyx_XDECREF_SET(__pyx_v_tail_overflow, __pyx_t_9); + __pyx_t_9 = 0; + + /* "fairseq/data/data_utils_fast.pyx":88 + * tail_overflow = tail_num_tokens > max_tokens > 0 + * # In case of a tail overflow finalize two batches + * if tail_overflow: # <<<<<<<<<<<<<< + * batches_count += 1 + * batches_ends_view[batches_count] = pos + */ + __pyx_t_2 = __Pyx_PyObject_IsTrue(__pyx_v_tail_overflow); if (unlikely((__pyx_t_2 < 0))) __PYX_ERR(0, 88, __pyx_L1_error) + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":89 + * # In case of a tail overflow finalize two batches + * if tail_overflow: + * batches_count += 1 # <<<<<<<<<<<<<< + * batches_ends_view[batches_count] = pos + * tail_max_tokens = num_tokens_view[pos] + */ + __pyx_v_batches_count = (__pyx_v_batches_count + 1); + + /* "fairseq/data/data_utils_fast.pyx":90 + * if tail_overflow: + * batches_count += 1 + * batches_ends_view[batches_count] = pos # <<<<<<<<<<<<<< + * tail_max_tokens = num_tokens_view[pos] + * batch_start = batches_ends_view[batches_count] + */ + __pyx_t_17 = __pyx_v_batches_count; + *((int32_t *) ( /* dim=0 */ (__pyx_v_batches_ends_view.data + __pyx_t_17 * __pyx_v_batches_ends_view.strides[0]) )) = __pyx_v_pos; + + /* "fairseq/data/data_utils_fast.pyx":91 + * batches_count += 1 + * batches_ends_view[batches_count] = pos + * tail_max_tokens = num_tokens_view[pos] # <<<<<<<<<<<<<< + * batch_start = batches_ends_view[batches_count] + * batches_count += 1 + */ + __pyx_t_17 = __pyx_v_pos; + __pyx_v_tail_max_tokens = (*((int64_t *) ( /* dim=0 */ (__pyx_v_num_tokens_view.data + __pyx_t_17 * __pyx_v_num_tokens_view.strides[0]) ))); + + /* "fairseq/data/data_utils_fast.pyx":88 + * tail_overflow = tail_num_tokens > max_tokens > 0 + * # In case of a tail overflow finalize two batches + * if tail_overflow: # <<<<<<<<<<<<<< + * batches_count += 1 + * batches_ends_view[batches_count] = pos + */ + } + + /* "fairseq/data/data_utils_fast.pyx":92 + * batches_ends_view[batches_count] = pos + * tail_max_tokens = num_tokens_view[pos] + * batch_start = batches_ends_view[batches_count] # <<<<<<<<<<<<<< + * batches_count += 1 + * new_batch_max_tokens = tail_max_tokens + */ + __pyx_t_17 = __pyx_v_batches_count; + __pyx_v_batch_start = (*((int32_t *) ( /* dim=0 */ (__pyx_v_batches_ends_view.data + __pyx_t_17 * __pyx_v_batches_ends_view.strides[0]) ))); + + /* "fairseq/data/data_utils_fast.pyx":93 + * tail_max_tokens = num_tokens_view[pos] + * batch_start = batches_ends_view[batches_count] + * batches_count += 1 # <<<<<<<<<<<<<< + * new_batch_max_tokens = tail_max_tokens + * + */ + __pyx_v_batches_count = (__pyx_v_batches_count + 1); + + /* "fairseq/data/data_utils_fast.pyx":94 + * batch_start = batches_ends_view[batches_count] + * batches_count += 1 + * new_batch_max_tokens = tail_max_tokens # <<<<<<<<<<<<<< + * + * if overflow or size_matches_with_bsz_mult: + */ + __pyx_v_new_batch_max_tokens = __pyx_v_tail_max_tokens; + + /* "fairseq/data/data_utils_fast.pyx":83 + * new_batch_sentences % bsz_mult == 0) + * + * if overflow: # <<<<<<<<<<<<<< + * tail_num_tokens = tail_max_tokens * \ + * (new_batch_end - batches_ends_view[batches_count]) + */ + } + + /* "fairseq/data/data_utils_fast.pyx":96 + * new_batch_max_tokens = tail_max_tokens + * + * if overflow or size_matches_with_bsz_mult: # <<<<<<<<<<<<<< + * batches_ends_view[batches_count] = new_batch_end + * batch_max_tokens = new_batch_max_tokens + */ + __pyx_t_4 = (__pyx_v_overflow != 0); + if (!__pyx_t_4) { + } else { + __pyx_t_2 = __pyx_t_4; + goto __pyx_L15_bool_binop_done; + } + __pyx_t_4 = (__pyx_v_size_matches_with_bsz_mult != 0); + __pyx_t_2 = __pyx_t_4; + __pyx_L15_bool_binop_done:; + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":97 + * + * if overflow or size_matches_with_bsz_mult: + * batches_ends_view[batches_count] = new_batch_end # <<<<<<<<<<<<<< + * batch_max_tokens = new_batch_max_tokens + * tail_max_tokens = 0 + */ + __pyx_t_17 = __pyx_v_batches_count; + *((int32_t *) ( /* dim=0 */ (__pyx_v_batches_ends_view.data + __pyx_t_17 * __pyx_v_batches_ends_view.strides[0]) )) = __pyx_v_new_batch_end; + + /* "fairseq/data/data_utils_fast.pyx":98 + * if overflow or size_matches_with_bsz_mult: + * batches_ends_view[batches_count] = new_batch_end + * batch_max_tokens = new_batch_max_tokens # <<<<<<<<<<<<<< + * tail_max_tokens = 0 + * if batches_ends_view[batches_count] != indices_len: + */ + __pyx_v_batch_max_tokens = __pyx_v_new_batch_max_tokens; + + /* "fairseq/data/data_utils_fast.pyx":99 + * batches_ends_view[batches_count] = new_batch_end + * batch_max_tokens = new_batch_max_tokens + * tail_max_tokens = 0 # <<<<<<<<<<<<<< + * if batches_ends_view[batches_count] != indices_len: + * batches_count += 1 + */ + __pyx_v_tail_max_tokens = 0; + + /* "fairseq/data/data_utils_fast.pyx":96 + * new_batch_max_tokens = tail_max_tokens + * + * if overflow or size_matches_with_bsz_mult: # <<<<<<<<<<<<<< + * batches_ends_view[batches_count] = new_batch_end + * batch_max_tokens = new_batch_max_tokens + */ + } + } + + /* "fairseq/data/data_utils_fast.pyx":100 + * batch_max_tokens = new_batch_max_tokens + * tail_max_tokens = 0 + * if batches_ends_view[batches_count] != indices_len: # <<<<<<<<<<<<<< + * batches_count += 1 + * # Memory and time-efficient split + */ + __pyx_t_17 = __pyx_v_batches_count; + __pyx_t_2 = ((*((int32_t *) ( /* dim=0 */ (__pyx_v_batches_ends_view.data + __pyx_t_17 * __pyx_v_batches_ends_view.strides[0]) ))) != __pyx_v_indices_len); + if (__pyx_t_2) { + + /* "fairseq/data/data_utils_fast.pyx":101 + * tail_max_tokens = 0 + * if batches_ends_view[batches_count] != indices_len: + * batches_count += 1 # <<<<<<<<<<<<<< + * # Memory and time-efficient split + * return np.split(indices, batches_ends[:batches_count]) + */ + __pyx_v_batches_count = (__pyx_v_batches_count + 1); + + /* "fairseq/data/data_utils_fast.pyx":100 + * batch_max_tokens = new_batch_max_tokens + * tail_max_tokens = 0 + * if batches_ends_view[batches_count] != indices_len: # <<<<<<<<<<<<<< + * batches_count += 1 + * # Memory and time-efficient split + */ + } + + /* "fairseq/data/data_utils_fast.pyx":103 + * batches_count += 1 + * # Memory and time-efficient split + * return np.split(indices, batches_ends[:batches_count]) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_split); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyInt_From_int32_t(__pyx_v_batches_count); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = PySlice_New(Py_None, __pyx_t_5, Py_None); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_batches_ends), __pyx_t_6); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_t_6 = NULL; + __pyx_t_7 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_6)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_6); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_7 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_6, ((PyObject *)__pyx_v_indices), __pyx_t_5}; + __pyx_t_9 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_7, 2+__pyx_t_7); + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + if (!(likely(PyList_CheckExact(__pyx_t_9))||((__pyx_t_9) == Py_None) || __Pyx_RaiseUnexpectedTypeError("list", __pyx_t_9))) __PYX_ERR(0, 103, __pyx_L1_error) + __pyx_r = ((PyObject*)__pyx_t_9); + __pyx_t_9 = 0; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":20 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_vec( # <<<<<<<<<<<<<< + * np.ndarray[int64_t, ndim=1] indices, + * np.ndarray[int64_t, ndim=1] num_tokens_vec, + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_XDECREF(__pyx_t_9); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_11, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_12, 1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_batches_ends.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_vec", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_batches_ends.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF((PyObject *)__pyx_v_batches_ends); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_batches_ends_view, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_num_tokens_view, 1); + __Pyx_XDECREF(__pyx_v_tail_overflow); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_1batch_by_size_vec(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_15data_utils_fast_1batch_by_size_vec = {"batch_by_size_vec", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_15data_utils_fast_1batch_by_size_vec, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_1batch_by_size_vec(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyArrayObject *__pyx_v_indices = 0; + PyArrayObject *__pyx_v_num_tokens_vec = 0; + int64_t __pyx_v_max_tokens; + int64_t __pyx_v_max_sentences; + int32_t __pyx_v_bsz_mult; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("batch_by_size_vec (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_indices,&__pyx_n_s_num_tokens_vec,&__pyx_n_s_max_tokens,&__pyx_n_s_max_sentences,&__pyx_n_s_bsz_mult,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_indices)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_num_tokens_vec)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_vec", 1, 5, 5, 1); __PYX_ERR(0, 20, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_max_tokens)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_vec", 1, 5, 5, 2); __PYX_ERR(0, 20, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (likely((values[3] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_max_sentences)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[3]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_vec", 1, 5, 5, 3); __PYX_ERR(0, 20, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (likely((values[4] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_bsz_mult)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[4]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 20, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_vec", 1, 5, 5, 4); __PYX_ERR(0, 20, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "batch_by_size_vec") < 0)) __PYX_ERR(0, 20, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 5)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + } + __pyx_v_indices = ((PyArrayObject *)values[0]); + __pyx_v_num_tokens_vec = ((PyArrayObject *)values[1]); + __pyx_v_max_tokens = __Pyx_PyInt_As_int64_t(values[2]); if (unlikely((__pyx_v_max_tokens == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 23, __pyx_L3_error) + __pyx_v_max_sentences = __Pyx_PyInt_As_int64_t(values[3]); if (unlikely((__pyx_v_max_sentences == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 24, __pyx_L3_error) + __pyx_v_bsz_mult = __Pyx_PyInt_As_int32_t(values[4]); if (unlikely((__pyx_v_bsz_mult == ((int32_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 25, __pyx_L3_error) + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("batch_by_size_vec", 1, 5, 5, __pyx_nargs); __PYX_ERR(0, 20, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_vec", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_indices), __pyx_ptype_5numpy_ndarray, 1, "indices", 0))) __PYX_ERR(0, 21, __pyx_L1_error) + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_num_tokens_vec), __pyx_ptype_5numpy_ndarray, 1, "num_tokens_vec", 0))) __PYX_ERR(0, 22, __pyx_L1_error) + __pyx_r = __pyx_pf_7fairseq_4data_15data_utils_fast_batch_by_size_vec(__pyx_self, __pyx_v_indices, __pyx_v_num_tokens_vec, __pyx_v_max_tokens, __pyx_v_max_sentences, __pyx_v_bsz_mult); + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = NULL; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_batch_by_size_vec(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyArrayObject *__pyx_v_num_tokens_vec, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult) { + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + __Pyx_LocalBuf_ND __pyx_pybuffernd_num_tokens_vec; + __Pyx_Buffer __pyx_pybuffer_num_tokens_vec; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_by_size_vec", 1); + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + __pyx_pybuffer_num_tokens_vec.pybuffer.buf = NULL; + __pyx_pybuffer_num_tokens_vec.refcount = 0; + __pyx_pybuffernd_num_tokens_vec.data = NULL; + __pyx_pybuffernd_num_tokens_vec.rcbuffer = &__pyx_pybuffer_num_tokens_vec; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn_int64_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 20, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer, (PyObject*)__pyx_v_num_tokens_vec, &__Pyx_TypeInfo_nn_int64_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 20, __pyx_L1_error) + } + __pyx_pybuffernd_num_tokens_vec.diminfo[0].strides = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_num_tokens_vec.diminfo[0].shape = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.shape[0]; + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_vec(__pyx_v_indices, __pyx_v_num_tokens_vec, __pyx_v_max_tokens, __pyx_v_max_sentences, __pyx_v_bsz_mult, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 20, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_vec", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/data_utils_fast.pyx":108 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_fn( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_3batch_by_size_fn(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_fn(PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult, CYTHON_UNUSED int __pyx_skip_dispatch) { + int32_t __pyx_v_indices_len; + PyArrayObject *__pyx_v_num_tokens_vec = 0; + __Pyx_memviewslice __pyx_v_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + CYTHON_UNUSED __Pyx_memviewslice __pyx_v_num_tokens_vec_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + int64_t __pyx_v_pos; + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + __Pyx_LocalBuf_ND __pyx_pybuffernd_num_tokens_vec; + __Pyx_Buffer __pyx_pybuffer_num_tokens_vec; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + npy_intp *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyArrayObject *__pyx_t_7 = NULL; + __Pyx_memviewslice __pyx_t_8 = { 0, 0, { 0 }, { 0 }, { 0 } }; + int32_t __pyx_t_9; + int32_t __pyx_t_10; + int64_t __pyx_t_11; + Py_ssize_t __pyx_t_12; + int __pyx_t_13; + int64_t __pyx_t_14; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_by_size_fn", 1); + __pyx_pybuffer_num_tokens_vec.pybuffer.buf = NULL; + __pyx_pybuffer_num_tokens_vec.refcount = 0; + __pyx_pybuffernd_num_tokens_vec.data = NULL; + __pyx_pybuffernd_num_tokens_vec.rcbuffer = &__pyx_pybuffer_num_tokens_vec; + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 108, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + + /* "fairseq/data/data_utils_fast.pyx":115 + * int32_t bsz_mult, + * ): + * cdef int32_t indices_len = indices.shape[0] # <<<<<<<<<<<<<< + * cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + * dtype=np.int64) + */ + __pyx_t_1 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_indices)); if (unlikely(__pyx_t_1 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 115, __pyx_L1_error) + __pyx_v_indices_len = (__pyx_t_1[0]); + + /* "fairseq/data/data_utils_fast.pyx":116 + * ): + * cdef int32_t indices_len = indices.shape[0] + * cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, # <<<<<<<<<<<<<< + * dtype=np.int64) + * cdef DTYPE_t[:] indices_view = indices + */ + __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 116, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 116, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyInt_From_int32_t(__pyx_v_indices_len); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 116, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 116, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_2)) __PYX_ERR(0, 116, __pyx_L1_error); + __pyx_t_2 = 0; + + /* "fairseq/data/data_utils_fast.pyx":117 + * cdef int32_t indices_len = indices.shape[0] + * cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + * dtype=np.int64) # <<<<<<<<<<<<<< + * cdef DTYPE_t[:] indices_view = indices + * cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + */ + __pyx_t_2 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 117, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_np); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 117, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_int64); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 117, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (PyDict_SetItem(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_6) < 0) __PYX_ERR(0, 117, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + + /* "fairseq/data/data_utils_fast.pyx":116 + * ): + * cdef int32_t indices_len = indices.shape[0] + * cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, # <<<<<<<<<<<<<< + * dtype=np.int64) + * cdef DTYPE_t[:] indices_view = indices + */ + __pyx_t_6 = __Pyx_PyObject_Call(__pyx_t_3, __pyx_t_4, __pyx_t_2); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 116, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + if (!(likely(((__pyx_t_6) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_6, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 116, __pyx_L1_error) + __pyx_t_7 = ((PyArrayObject *)__pyx_t_6); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer, (PyObject*)__pyx_t_7, &__Pyx_TypeInfo_nn_int64_t, PyBUF_FORMAT| PyBUF_STRIDES| PyBUF_WRITABLE, 1, 0, __pyx_stack) == -1)) { + __pyx_v_num_tokens_vec = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.buf = NULL; + __PYX_ERR(0, 116, __pyx_L1_error) + } else {__pyx_pybuffernd_num_tokens_vec.diminfo[0].strides = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_num_tokens_vec.diminfo[0].shape = __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.shape[0]; + } + } + __pyx_t_7 = 0; + __pyx_v_num_tokens_vec = ((PyArrayObject *)__pyx_t_6); + __pyx_t_6 = 0; + + /* "fairseq/data/data_utils_fast.pyx":118 + * cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + * dtype=np.int64) + * cdef DTYPE_t[:] indices_view = indices # <<<<<<<<<<<<<< + * cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + * cdef int64_t pos + */ + __pyx_t_8 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(((PyObject *)__pyx_v_indices), PyBUF_WRITABLE); if (unlikely(!__pyx_t_8.memview)) __PYX_ERR(0, 118, __pyx_L1_error) + __pyx_v_indices_view = __pyx_t_8; + __pyx_t_8.memview = NULL; + __pyx_t_8.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":119 + * dtype=np.int64) + * cdef DTYPE_t[:] indices_view = indices + * cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec # <<<<<<<<<<<<<< + * cdef int64_t pos + * for pos in range(indices_len): + */ + __pyx_t_8 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(((PyObject *)__pyx_v_num_tokens_vec), PyBUF_WRITABLE); if (unlikely(!__pyx_t_8.memview)) __PYX_ERR(0, 119, __pyx_L1_error) + __pyx_v_num_tokens_vec_view = __pyx_t_8; + __pyx_t_8.memview = NULL; + __pyx_t_8.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":121 + * cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + * cdef int64_t pos + * for pos in range(indices_len): # <<<<<<<<<<<<<< + * num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + * return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + */ + __pyx_t_9 = __pyx_v_indices_len; + __pyx_t_10 = __pyx_t_9; + for (__pyx_t_11 = 0; __pyx_t_11 < __pyx_t_10; __pyx_t_11+=1) { + __pyx_v_pos = __pyx_t_11; + + /* "fairseq/data/data_utils_fast.pyx":122 + * cdef int64_t pos + * for pos in range(indices_len): + * num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) # <<<<<<<<<<<<<< + * return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + * max_sentences, bsz_mult,) + */ + __pyx_t_12 = __pyx_v_pos; + __pyx_t_2 = __Pyx_PyInt_From_int64_t((*((__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_indices_view.data + __pyx_t_12 * __pyx_v_indices_view.strides[0]) )))); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 122, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_INCREF(__pyx_v_num_tokens_fn); + __pyx_t_4 = __pyx_v_num_tokens_fn; __pyx_t_3 = NULL; + __pyx_t_13 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_4))) { + __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_4); + if (likely(__pyx_t_3)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_4); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_4, function); + __pyx_t_13 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_3, __pyx_t_2}; + __pyx_t_6 = __Pyx_PyObject_FastCall(__pyx_t_4, __pyx_callargs+1-__pyx_t_13, 1+__pyx_t_13); + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 122, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + } + __pyx_t_14 = __Pyx_PyInt_As_int64_t(__pyx_t_6); if (unlikely((__pyx_t_14 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 122, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_t_12 = __pyx_v_pos; + *__Pyx_BufPtrStrided1d(int64_t *, __pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer.buf, __pyx_t_12, __pyx_pybuffernd_num_tokens_vec.diminfo[0].strides) = __pyx_t_14; + } + + /* "fairseq/data/data_utils_fast.pyx":123 + * for pos in range(indices_len): + * num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + * return batch_by_size_vec(indices, num_tokens_vec, max_tokens, # <<<<<<<<<<<<<< + * max_sentences, bsz_mult,) + * + */ + __Pyx_XDECREF(__pyx_r); + + /* "fairseq/data/data_utils_fast.pyx":124 + * num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + * return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + * max_sentences, bsz_mult,) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_6 = __pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_vec(((PyArrayObject *)__pyx_v_indices), ((PyArrayObject *)__pyx_v_num_tokens_vec), __pyx_v_max_tokens, __pyx_v_max_sentences, __pyx_v_bsz_mult, 0); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 123, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_r = ((PyObject*)__pyx_t_6); + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":108 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_fn( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_8, 1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_fn", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_num_tokens_vec.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF((PyObject *)__pyx_v_num_tokens_vec); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_indices_view, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_num_tokens_vec_view, 1); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_3batch_by_size_fn(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_15data_utils_fast_3batch_by_size_fn = {"batch_by_size_fn", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_15data_utils_fast_3batch_by_size_fn, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_3batch_by_size_fn(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyArrayObject *__pyx_v_indices = 0; + PyObject *__pyx_v_num_tokens_fn = 0; + int64_t __pyx_v_max_tokens; + int64_t __pyx_v_max_sentences; + int32_t __pyx_v_bsz_mult; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("batch_by_size_fn (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_indices,&__pyx_n_s_num_tokens_fn,&__pyx_n_s_max_tokens,&__pyx_n_s_max_sentences,&__pyx_n_s_bsz_mult,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_indices)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 108, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_num_tokens_fn)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 108, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_fn", 1, 5, 5, 1); __PYX_ERR(0, 108, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_max_tokens)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 108, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_fn", 1, 5, 5, 2); __PYX_ERR(0, 108, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (likely((values[3] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_max_sentences)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[3]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 108, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_fn", 1, 5, 5, 3); __PYX_ERR(0, 108, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (likely((values[4] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_bsz_mult)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[4]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 108, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_by_size_fn", 1, 5, 5, 4); __PYX_ERR(0, 108, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "batch_by_size_fn") < 0)) __PYX_ERR(0, 108, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 5)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + } + __pyx_v_indices = ((PyArrayObject *)values[0]); + __pyx_v_num_tokens_fn = values[1]; + __pyx_v_max_tokens = __Pyx_PyInt_As_int64_t(values[2]); if (unlikely((__pyx_v_max_tokens == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 111, __pyx_L3_error) + __pyx_v_max_sentences = __Pyx_PyInt_As_int64_t(values[3]); if (unlikely((__pyx_v_max_sentences == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 112, __pyx_L3_error) + __pyx_v_bsz_mult = __Pyx_PyInt_As_int32_t(values[4]); if (unlikely((__pyx_v_bsz_mult == ((int32_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 113, __pyx_L3_error) + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("batch_by_size_fn", 1, 5, 5, __pyx_nargs); __PYX_ERR(0, 108, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_fn", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_indices), __pyx_ptype_5numpy_ndarray, 1, "indices", 0))) __PYX_ERR(0, 109, __pyx_L1_error) + __pyx_r = __pyx_pf_7fairseq_4data_15data_utils_fast_2batch_by_size_fn(__pyx_self, __pyx_v_indices, __pyx_v_num_tokens_fn, __pyx_v_max_tokens, __pyx_v_max_sentences, __pyx_v_bsz_mult); + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = NULL; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_2batch_by_size_fn(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, int64_t __pyx_v_max_tokens, int64_t __pyx_v_max_sentences, int32_t __pyx_v_bsz_mult) { + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_by_size_fn", 1); + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 108, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_f_7fairseq_4data_15data_utils_fast_batch_by_size_fn(__pyx_v_indices, __pyx_v_num_tokens_fn, __pyx_v_max_tokens, __pyx_v_max_sentences, __pyx_v_bsz_mult, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 108, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_by_size_fn", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/data_utils_fast.pyx":127 + * + * + * cdef _find_valid_shape( # <<<<<<<<<<<<<< + * DTYPE_t[:, :] shapes_view, + * int64_t num_sentences, + */ + +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast__find_valid_shape(__Pyx_memviewslice __pyx_v_shapes_view, int64_t __pyx_v_num_sentences, int64_t __pyx_v_num_tokens) { + Py_ssize_t __pyx_v_i; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_ssize_t __pyx_t_6; + int __pyx_t_7; + int __pyx_t_8; + PyObject *__pyx_t_9 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_find_valid_shape", 1); + + /* "fairseq/data/data_utils_fast.pyx":133 + * ): + * """Return index of first valid shape of -1 if none is found.""" + * for i in range(shapes_view.shape[0]): # <<<<<<<<<<<<<< + * if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + * return i + */ + __pyx_t_1 = (__pyx_v_shapes_view.shape[0]); + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "fairseq/data/data_utils_fast.pyx":134 + * """Return index of first valid shape of -1 if none is found.""" + * for i in range(shapes_view.shape[0]): + * if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: # <<<<<<<<<<<<<< + * return i + * return -1 + */ + __pyx_t_5 = __pyx_v_i; + __pyx_t_6 = 0; + __pyx_t_7 = -1; + if (__pyx_t_5 < 0) { + __pyx_t_5 += __pyx_v_shapes_view.shape[0]; + if (unlikely(__pyx_t_5 < 0)) __pyx_t_7 = 0; + } else if (unlikely(__pyx_t_5 >= __pyx_v_shapes_view.shape[0])) __pyx_t_7 = 0; + if (__pyx_t_6 < 0) { + __pyx_t_6 += __pyx_v_shapes_view.shape[1]; + if (unlikely(__pyx_t_6 < 0)) __pyx_t_7 = 1; + } else if (unlikely(__pyx_t_6 >= __pyx_v_shapes_view.shape[1])) __pyx_t_7 = 1; + if (unlikely(__pyx_t_7 != -1)) { + __Pyx_RaiseBufferIndexError(__pyx_t_7); + __PYX_ERR(0, 134, __pyx_L1_error) + } + __pyx_t_8 = (__pyx_v_num_sentences <= (*((__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_shapes_view.data + __pyx_t_5 * __pyx_v_shapes_view.strides[0]) ) + __pyx_t_6 * __pyx_v_shapes_view.strides[1]) )))); + if (__pyx_t_8) { + } else { + __pyx_t_4 = __pyx_t_8; + goto __pyx_L6_bool_binop_done; + } + __pyx_t_6 = __pyx_v_i; + __pyx_t_5 = 1; + __pyx_t_7 = -1; + if (__pyx_t_6 < 0) { + __pyx_t_6 += __pyx_v_shapes_view.shape[0]; + if (unlikely(__pyx_t_6 < 0)) __pyx_t_7 = 0; + } else if (unlikely(__pyx_t_6 >= __pyx_v_shapes_view.shape[0])) __pyx_t_7 = 0; + if (__pyx_t_5 < 0) { + __pyx_t_5 += __pyx_v_shapes_view.shape[1]; + if (unlikely(__pyx_t_5 < 0)) __pyx_t_7 = 1; + } else if (unlikely(__pyx_t_5 >= __pyx_v_shapes_view.shape[1])) __pyx_t_7 = 1; + if (unlikely(__pyx_t_7 != -1)) { + __Pyx_RaiseBufferIndexError(__pyx_t_7); + __PYX_ERR(0, 134, __pyx_L1_error) + } + __pyx_t_8 = (__pyx_v_num_tokens <= (*((__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_shapes_view.data + __pyx_t_6 * __pyx_v_shapes_view.strides[0]) ) + __pyx_t_5 * __pyx_v_shapes_view.strides[1]) )))); + __pyx_t_4 = __pyx_t_8; + __pyx_L6_bool_binop_done:; + if (__pyx_t_4) { + + /* "fairseq/data/data_utils_fast.pyx":135 + * for i in range(shapes_view.shape[0]): + * if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + * return i # <<<<<<<<<<<<<< + * return -1 + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_9 = PyInt_FromSsize_t(__pyx_v_i); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 135, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __pyx_r = __pyx_t_9; + __pyx_t_9 = 0; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":134 + * """Return index of first valid shape of -1 if none is found.""" + * for i in range(shapes_view.shape[0]): + * if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: # <<<<<<<<<<<<<< + * return i + * return -1 + */ + } + } + + /* "fairseq/data/data_utils_fast.pyx":136 + * if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + * return i + * return -1 # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_int_neg_1); + __pyx_r = __pyx_int_neg_1; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":127 + * + * + * cdef _find_valid_shape( # <<<<<<<<<<<<<< + * DTYPE_t[:, :] shapes_view, + * int64_t num_sentences, + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_9); + __Pyx_AddTraceback("fairseq.data.data_utils_fast._find_valid_shape", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/data_utils_fast.pyx":140 + * + * @cython.cdivision(True) + * cpdef list batch_fixed_shapes_fast( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_15data_utils_fast_batch_fixed_shapes_fast(PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, PyArrayObject *__pyx_v_fixed_shapes_sorted, CYTHON_UNUSED int __pyx_skip_dispatch) { + int64_t __pyx_v_sample_len; + PyObject *__pyx_v_sample_lens = 0; + PyObject *__pyx_v_batch = 0; + PyObject *__pyx_v_batches = 0; + int64_t __pyx_v_i; + int64_t __pyx_v_idx; + int64_t __pyx_v_num_tokens; + __Pyx_memviewslice __pyx_v_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_shapes_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + PyObject *__pyx_v_shape_idx = NULL; + __Pyx_LocalBuf_ND __pyx_pybuffernd_fixed_shapes_sorted; + __Pyx_Buffer __pyx_pybuffer_fixed_shapes_sorted; + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + __Pyx_memviewslice __pyx_t_2 = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_t_3 = { 0, 0, { 0 }, { 0 }, { 0 } }; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + int64_t __pyx_t_6; + Py_ssize_t __pyx_t_7; + int __pyx_t_8; + PyObject *__pyx_t_9 = NULL; + PyObject *__pyx_t_10 = NULL; + PyObject *__pyx_t_11 = NULL; + int64_t __pyx_t_12; + int __pyx_t_13; + int64_t __pyx_t_14; + int64_t __pyx_t_15; + int __pyx_t_16; + Py_ssize_t __pyx_t_17; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_fixed_shapes_fast", 1); + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + __pyx_pybuffer_fixed_shapes_sorted.pybuffer.buf = NULL; + __pyx_pybuffer_fixed_shapes_sorted.refcount = 0; + __pyx_pybuffernd_fixed_shapes_sorted.data = NULL; + __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer = &__pyx_pybuffer_fixed_shapes_sorted; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 140, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer, (PyObject*)__pyx_v_fixed_shapes_sorted, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 140, __pyx_L1_error) + } + __pyx_pybuffernd_fixed_shapes_sorted.diminfo[0].strides = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[0].shape = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[1].strides = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[1].shape = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.shape[1]; + + /* "fairseq/data/data_utils_fast.pyx":145 + * np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, + * ): + * cdef int64_t sample_len = 0 # <<<<<<<<<<<<<< + * cdef list sample_lens = [] + * cdef list batch = [] + */ + __pyx_v_sample_len = 0; + + /* "fairseq/data/data_utils_fast.pyx":146 + * ): + * cdef int64_t sample_len = 0 + * cdef list sample_lens = [] # <<<<<<<<<<<<<< + * cdef list batch = [] + * cdef list batches = [] + */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 146, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_sample_lens = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":147 + * cdef int64_t sample_len = 0 + * cdef list sample_lens = [] + * cdef list batch = [] # <<<<<<<<<<<<<< + * cdef list batches = [] + * cdef int64_t mod_len + */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_batch = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":148 + * cdef list sample_lens = [] + * cdef list batch = [] + * cdef list batches = [] # <<<<<<<<<<<<<< + * cdef int64_t mod_len + * cdef int64_t i + */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 148, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_batches = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":153 + * cdef int64_t idx + * cdef int64_t num_tokens + * cdef DTYPE_t[:] indices_view = indices # <<<<<<<<<<<<<< + * cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + * + */ + __pyx_t_2 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(((PyObject *)__pyx_v_indices), PyBUF_WRITABLE); if (unlikely(!__pyx_t_2.memview)) __PYX_ERR(0, 153, __pyx_L1_error) + __pyx_v_indices_view = __pyx_t_2; + __pyx_t_2.memview = NULL; + __pyx_t_2.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":154 + * cdef int64_t num_tokens + * cdef DTYPE_t[:] indices_view = indices + * cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted # <<<<<<<<<<<<<< + * + * for i in range(len(indices_view)): + */ + __pyx_t_3 = __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(((PyObject *)__pyx_v_fixed_shapes_sorted), PyBUF_WRITABLE); if (unlikely(!__pyx_t_3.memview)) __PYX_ERR(0, 154, __pyx_L1_error) + __pyx_v_shapes_view = __pyx_t_3; + __pyx_t_3.memview = NULL; + __pyx_t_3.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":156 + * cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + * + * for i in range(len(indices_view)): # <<<<<<<<<<<<<< + * idx = indices_view[i] + * num_tokens = num_tokens_fn(idx) + */ + __pyx_t_4 = __Pyx_MemoryView_Len(__pyx_v_indices_view); + __pyx_t_5 = __pyx_t_4; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "fairseq/data/data_utils_fast.pyx":157 + * + * for i in range(len(indices_view)): + * idx = indices_view[i] # <<<<<<<<<<<<<< + * num_tokens = num_tokens_fn(idx) + * sample_lens.append(num_tokens) + */ + __pyx_t_7 = __pyx_v_i; + __pyx_t_8 = -1; + if (__pyx_t_7 < 0) { + __pyx_t_7 += __pyx_v_indices_view.shape[0]; + if (unlikely(__pyx_t_7 < 0)) __pyx_t_8 = 0; + } else if (unlikely(__pyx_t_7 >= __pyx_v_indices_view.shape[0])) __pyx_t_8 = 0; + if (unlikely(__pyx_t_8 != -1)) { + __Pyx_RaiseBufferIndexError(__pyx_t_8); + __PYX_ERR(0, 157, __pyx_L1_error) + } + __pyx_v_idx = (*((__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_indices_view.data + __pyx_t_7 * __pyx_v_indices_view.strides[0]) ))); + + /* "fairseq/data/data_utils_fast.pyx":158 + * for i in range(len(indices_view)): + * idx = indices_view[i] + * num_tokens = num_tokens_fn(idx) # <<<<<<<<<<<<<< + * sample_lens.append(num_tokens) + * sample_len = max(sample_len, num_tokens) + */ + __pyx_t_9 = __Pyx_PyInt_From_int64_t(__pyx_v_idx); if (unlikely(!__pyx_t_9)) __PYX_ERR(0, 158, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_9); + __Pyx_INCREF(__pyx_v_num_tokens_fn); + __pyx_t_10 = __pyx_v_num_tokens_fn; __pyx_t_11 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_10))) { + __pyx_t_11 = PyMethod_GET_SELF(__pyx_t_10); + if (likely(__pyx_t_11)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_10); + __Pyx_INCREF(__pyx_t_11); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_10, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_11, __pyx_t_9}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_10, __pyx_callargs+1-__pyx_t_8, 1+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_11); __pyx_t_11 = 0; + __Pyx_DECREF(__pyx_t_9); __pyx_t_9 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 158, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_10); __pyx_t_10 = 0; + } + __pyx_t_12 = __Pyx_PyInt_As_int64_t(__pyx_t_1); if (unlikely((__pyx_t_12 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 158, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_v_num_tokens = __pyx_t_12; + + /* "fairseq/data/data_utils_fast.pyx":159 + * idx = indices_view[i] + * num_tokens = num_tokens_fn(idx) + * sample_lens.append(num_tokens) # <<<<<<<<<<<<<< + * sample_len = max(sample_len, num_tokens) + * + */ + __pyx_t_1 = __Pyx_PyInt_From_int64_t(__pyx_v_num_tokens); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 159, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_sample_lens, __pyx_t_1); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 159, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":160 + * num_tokens = num_tokens_fn(idx) + * sample_lens.append(num_tokens) + * sample_len = max(sample_len, num_tokens) # <<<<<<<<<<<<<< + * + * shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + */ + __pyx_t_12 = __pyx_v_num_tokens; + __pyx_t_14 = __pyx_v_sample_len; + __pyx_t_16 = (__pyx_t_12 > __pyx_t_14); + if (__pyx_t_16) { + __pyx_t_15 = __pyx_t_12; + } else { + __pyx_t_15 = __pyx_t_14; + } + __pyx_v_sample_len = __pyx_t_15; + + /* "fairseq/data/data_utils_fast.pyx":162 + * sample_len = max(sample_len, num_tokens) + * + * shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) # <<<<<<<<<<<<<< + * if shape_idx == -1: + * batches.append(batch) + */ + __pyx_t_17 = __Pyx_PyList_GET_SIZE(__pyx_v_batch); if (unlikely(__pyx_t_17 == ((Py_ssize_t)-1))) __PYX_ERR(0, 162, __pyx_L1_error) + __pyx_t_1 = __pyx_f_7fairseq_4data_15data_utils_fast__find_valid_shape(__pyx_v_shapes_view, (__pyx_t_17 + 1), __pyx_v_sample_len); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 162, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_XDECREF_SET(__pyx_v_shape_idx, __pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":163 + * + * shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + * if shape_idx == -1: # <<<<<<<<<<<<<< + * batches.append(batch) + * batch = [] + */ + __pyx_t_16 = (__Pyx_PyInt_BoolEqObjC(__pyx_v_shape_idx, __pyx_int_neg_1, -1L, 0)); if (unlikely((__pyx_t_16 < 0))) __PYX_ERR(0, 163, __pyx_L1_error) + if (__pyx_t_16) { + + /* "fairseq/data/data_utils_fast.pyx":164 + * shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + * if shape_idx == -1: + * batches.append(batch) # <<<<<<<<<<<<<< + * batch = [] + * sample_lens = [] + */ + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_batches, __pyx_v_batch); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 164, __pyx_L1_error) + + /* "fairseq/data/data_utils_fast.pyx":165 + * if shape_idx == -1: + * batches.append(batch) + * batch = [] # <<<<<<<<<<<<<< + * sample_lens = [] + * sample_len = 0 + */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 165, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF_SET(__pyx_v_batch, ((PyObject*)__pyx_t_1)); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":166 + * batches.append(batch) + * batch = [] + * sample_lens = [] # <<<<<<<<<<<<<< + * sample_len = 0 + * shapes_view = fixed_shapes_sorted + */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 166, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF_SET(__pyx_v_sample_lens, ((PyObject*)__pyx_t_1)); + __pyx_t_1 = 0; + + /* "fairseq/data/data_utils_fast.pyx":167 + * batch = [] + * sample_lens = [] + * sample_len = 0 # <<<<<<<<<<<<<< + * shapes_view = fixed_shapes_sorted + * elif shape_idx > 0: + */ + __pyx_v_sample_len = 0; + + /* "fairseq/data/data_utils_fast.pyx":168 + * sample_lens = [] + * sample_len = 0 + * shapes_view = fixed_shapes_sorted # <<<<<<<<<<<<<< + * elif shape_idx > 0: + * # small optimization for the next call to _find_valid_shape + */ + __pyx_t_3 = __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(((PyObject *)__pyx_v_fixed_shapes_sorted), PyBUF_WRITABLE); if (unlikely(!__pyx_t_3.memview)) __PYX_ERR(0, 168, __pyx_L1_error) + __PYX_XCLEAR_MEMVIEW(&__pyx_v_shapes_view, 1); + __pyx_v_shapes_view = __pyx_t_3; + __pyx_t_3.memview = NULL; + __pyx_t_3.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":163 + * + * shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + * if shape_idx == -1: # <<<<<<<<<<<<<< + * batches.append(batch) + * batch = [] + */ + goto __pyx_L5; + } + + /* "fairseq/data/data_utils_fast.pyx":169 + * sample_len = 0 + * shapes_view = fixed_shapes_sorted + * elif shape_idx > 0: # <<<<<<<<<<<<<< + * # small optimization for the next call to _find_valid_shape + * shapes_view = shapes_view[shape_idx:] + */ + __pyx_t_1 = PyObject_RichCompare(__pyx_v_shape_idx, __pyx_int_0, Py_GT); __Pyx_XGOTREF(__pyx_t_1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 169, __pyx_L1_error) + __pyx_t_16 = __Pyx_PyObject_IsTrue(__pyx_t_1); if (unlikely((__pyx_t_16 < 0))) __PYX_ERR(0, 169, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (__pyx_t_16) { + + /* "fairseq/data/data_utils_fast.pyx":171 + * elif shape_idx > 0: + * # small optimization for the next call to _find_valid_shape + * shapes_view = shapes_view[shape_idx:] # <<<<<<<<<<<<<< + * + * batch.append(idx) + */ + __pyx_t_17 = __Pyx_PyIndex_AsSsize_t(__pyx_v_shape_idx); if (unlikely((__pyx_t_17 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(0, 171, __pyx_L1_error) + __pyx_t_3.data = __pyx_v_shapes_view.data; + __pyx_t_3.memview = __pyx_v_shapes_view.memview; + __PYX_INC_MEMVIEW(&__pyx_t_3, 1); + __pyx_t_8 = -1; + if (unlikely(__pyx_memoryview_slice_memviewslice( + &__pyx_t_3, + __pyx_v_shapes_view.shape[0], __pyx_v_shapes_view.strides[0], __pyx_v_shapes_view.suboffsets[0], + 0, + 0, + &__pyx_t_8, + __pyx_t_17, + 0, + 0, + 1, + 0, + 0, + 1) < 0)) +{ + __PYX_ERR(0, 171, __pyx_L1_error) +} + +__pyx_t_3.shape[1] = __pyx_v_shapes_view.shape[1]; +__pyx_t_3.strides[1] = __pyx_v_shapes_view.strides[1]; + __pyx_t_3.suboffsets[1] = -1; + +__PYX_XCLEAR_MEMVIEW(&__pyx_v_shapes_view, 1); + __pyx_v_shapes_view = __pyx_t_3; + __pyx_t_3.memview = NULL; + __pyx_t_3.data = NULL; + + /* "fairseq/data/data_utils_fast.pyx":169 + * sample_len = 0 + * shapes_view = fixed_shapes_sorted + * elif shape_idx > 0: # <<<<<<<<<<<<<< + * # small optimization for the next call to _find_valid_shape + * shapes_view = shapes_view[shape_idx:] + */ + } + __pyx_L5:; + + /* "fairseq/data/data_utils_fast.pyx":173 + * shapes_view = shapes_view[shape_idx:] + * + * batch.append(idx) # <<<<<<<<<<<<<< + * + * if len(batch) > 0: + */ + __pyx_t_1 = __Pyx_PyInt_From_int64_t(__pyx_v_idx); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 173, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_batch, __pyx_t_1); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 173, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + + /* "fairseq/data/data_utils_fast.pyx":175 + * batch.append(idx) + * + * if len(batch) > 0: # <<<<<<<<<<<<<< + * batches.append(batch) + * + */ + __pyx_t_4 = __Pyx_PyList_GET_SIZE(__pyx_v_batch); if (unlikely(__pyx_t_4 == ((Py_ssize_t)-1))) __PYX_ERR(0, 175, __pyx_L1_error) + __pyx_t_16 = (__pyx_t_4 > 0); + if (__pyx_t_16) { + + /* "fairseq/data/data_utils_fast.pyx":176 + * + * if len(batch) > 0: + * batches.append(batch) # <<<<<<<<<<<<<< + * + * return batches + */ + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_batches, __pyx_v_batch); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 176, __pyx_L1_error) + + /* "fairseq/data/data_utils_fast.pyx":175 + * batch.append(idx) + * + * if len(batch) > 0: # <<<<<<<<<<<<<< + * batches.append(batch) + * + */ + } + + /* "fairseq/data/data_utils_fast.pyx":178 + * batches.append(batch) + * + * return batches # <<<<<<<<<<<<<< + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_batches); + __pyx_r = __pyx_v_batches; + goto __pyx_L0; + + /* "fairseq/data/data_utils_fast.pyx":140 + * + * @cython.cdivision(True) + * cpdef list batch_fixed_shapes_fast( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_2, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_3, 1); + __Pyx_XDECREF(__pyx_t_9); + __Pyx_XDECREF(__pyx_t_10); + __Pyx_XDECREF(__pyx_t_11); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_fixed_shapes_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF(__pyx_v_sample_lens); + __Pyx_XDECREF(__pyx_v_batch); + __Pyx_XDECREF(__pyx_v_batches); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_indices_view, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_shapes_view, 1); + __Pyx_XDECREF(__pyx_v_shape_idx); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast = {"batch_fixed_shapes_fast", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyArrayObject *__pyx_v_indices = 0; + PyObject *__pyx_v_num_tokens_fn = 0; + PyArrayObject *__pyx_v_fixed_shapes_sorted = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("batch_fixed_shapes_fast (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_indices,&__pyx_n_s_num_tokens_fn,&__pyx_n_s_fixed_shapes_sorted,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_indices)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 140, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_num_tokens_fn)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 140, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_fixed_shapes_fast", 1, 3, 3, 1); __PYX_ERR(0, 140, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_fixed_shapes_sorted)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 140, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("batch_fixed_shapes_fast", 1, 3, 3, 2); __PYX_ERR(0, 140, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "batch_fixed_shapes_fast") < 0)) __PYX_ERR(0, 140, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 3)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + } + __pyx_v_indices = ((PyArrayObject *)values[0]); + __pyx_v_num_tokens_fn = values[1]; + __pyx_v_fixed_shapes_sorted = ((PyArrayObject *)values[2]); + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("batch_fixed_shapes_fast", 1, 3, 3, __pyx_nargs); __PYX_ERR(0, 140, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_fixed_shapes_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_indices), __pyx_ptype_5numpy_ndarray, 1, "indices", 0))) __PYX_ERR(0, 141, __pyx_L1_error) + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_fixed_shapes_sorted), __pyx_ptype_5numpy_ndarray, 1, "fixed_shapes_sorted", 0))) __PYX_ERR(0, 143, __pyx_L1_error) + __pyx_r = __pyx_pf_7fairseq_4data_15data_utils_fast_4batch_fixed_shapes_fast(__pyx_self, __pyx_v_indices, __pyx_v_num_tokens_fn, __pyx_v_fixed_shapes_sorted); + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = NULL; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_15data_utils_fast_4batch_fixed_shapes_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_indices, PyObject *__pyx_v_num_tokens_fn, PyArrayObject *__pyx_v_fixed_shapes_sorted) { + __Pyx_LocalBuf_ND __pyx_pybuffernd_fixed_shapes_sorted; + __Pyx_Buffer __pyx_pybuffer_fixed_shapes_sorted; + __Pyx_LocalBuf_ND __pyx_pybuffernd_indices; + __Pyx_Buffer __pyx_pybuffer_indices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("batch_fixed_shapes_fast", 1); + __pyx_pybuffer_indices.pybuffer.buf = NULL; + __pyx_pybuffer_indices.refcount = 0; + __pyx_pybuffernd_indices.data = NULL; + __pyx_pybuffernd_indices.rcbuffer = &__pyx_pybuffer_indices; + __pyx_pybuffer_fixed_shapes_sorted.pybuffer.buf = NULL; + __pyx_pybuffer_fixed_shapes_sorted.refcount = 0; + __pyx_pybuffernd_fixed_shapes_sorted.data = NULL; + __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer = &__pyx_pybuffer_fixed_shapes_sorted; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 140, __pyx_L1_error) + } + __pyx_pybuffernd_indices.diminfo[0].strides = __pyx_pybuffernd_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_indices.diminfo[0].shape = __pyx_pybuffernd_indices.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer, (PyObject*)__pyx_v_fixed_shapes_sorted, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 140, __pyx_L1_error) + } + __pyx_pybuffernd_fixed_shapes_sorted.diminfo[0].strides = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[0].shape = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[1].strides = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_fixed_shapes_sorted.diminfo[1].shape = __pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer.shape[1]; + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_f_7fairseq_4data_15data_utils_fast_batch_fixed_shapes_fast(__pyx_v_indices, __pyx_v_num_tokens_fn, __pyx_v_fixed_shapes_sorted, 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 140, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.data_utils_fast.batch_fixed_shapes_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_fixed_shapes_sorted.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} +static struct __pyx_vtabstruct_array __pyx_vtable_array; + +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_array_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_array_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_array; + p->mode = ((PyObject*)Py_None); Py_INCREF(Py_None); + p->_format = ((PyObject*)Py_None); Py_INCREF(Py_None); + if (unlikely(__pyx_array___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_array(PyObject *o) { + struct __pyx_array_obj *p = (struct __pyx_array_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && (!PyType_IS_GC(Py_TYPE(o)) || !__Pyx_PyObject_GC_IsFinalized(o))) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_array) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_array___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->mode); + Py_CLEAR(p->_format); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} +static PyObject *__pyx_sq_item_array(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_array(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_array___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_tp_getattro_array(PyObject *o, PyObject *n) { + PyObject *v = __Pyx_PyObject_GenericGetAttr(o, n); + if (!v && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + v = __pyx_array___getattr__(o, n); + } + return v; +} + +static PyObject *__pyx_getprop___pyx_array_memview(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(o); +} + +static PyMethodDef __pyx_methods_array[] = { + {"__getattr__", (PyCFunction)__pyx_array___getattr__, METH_O|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_array[] = { + {(char *)"memview", __pyx_getprop___pyx_array_memview, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_array_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_array}, + {Py_sq_length, (void *)__pyx_array___len__}, + {Py_sq_item, (void *)__pyx_sq_item_array}, + {Py_mp_length, (void *)__pyx_array___len__}, + {Py_mp_subscript, (void *)__pyx_array___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_array}, + {Py_tp_getattro, (void *)__pyx_tp_getattro_array}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_array_getbuffer}, + #endif + {Py_tp_methods, (void *)__pyx_methods_array}, + {Py_tp_getset, (void *)__pyx_getsets_array}, + {Py_tp_new, (void *)__pyx_tp_new_array}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_array_spec = { + "fairseq.data.data_utils_fast.array", + sizeof(struct __pyx_array_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_array_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_array = { + __pyx_array___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_array, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_array = { + __pyx_array___len__, /*mp_length*/ + __pyx_array___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_array, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_array = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.data_utils_fast.""array", /*tp_name*/ + sizeof(struct __pyx_array_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_array, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_array, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_array, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + __pyx_tp_getattro_array, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_array, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + 0, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_array, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_array, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_array, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { + struct __pyx_MemviewEnum_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_MemviewEnum_obj *)o); + p->name = Py_None; Py_INCREF(Py_None); + return o; +} + +static void __pyx_tp_dealloc_Enum(PyObject *o) { + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_Enum) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + Py_CLEAR(p->name); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_Enum(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + if (p->name) { + e = (*v)(p->name, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_Enum(PyObject *o) { + PyObject* tmp; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + tmp = ((PyObject*)p->name); + p->name = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + return 0; +} + +static PyObject *__pyx_specialmethod___pyx_MemviewEnum___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_MemviewEnum___repr__(self); +} + +static PyMethodDef __pyx_methods_Enum[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_MemviewEnum___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_MemviewEnum_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_Enum}, + {Py_tp_repr, (void *)__pyx_MemviewEnum___repr__}, + {Py_tp_traverse, (void *)__pyx_tp_traverse_Enum}, + {Py_tp_clear, (void *)__pyx_tp_clear_Enum}, + {Py_tp_methods, (void *)__pyx_methods_Enum}, + {Py_tp_init, (void *)__pyx_MemviewEnum___init__}, + {Py_tp_new, (void *)__pyx_tp_new_Enum}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_MemviewEnum_spec = { + "fairseq.data.data_utils_fast.Enum", + sizeof(struct __pyx_MemviewEnum_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_MemviewEnum_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_MemviewEnum = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.data_utils_fast.""Enum", /*tp_name*/ + sizeof(struct __pyx_MemviewEnum_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_Enum, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_MemviewEnum___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_Enum, /*tp_traverse*/ + __pyx_tp_clear_Enum, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_Enum, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + __pyx_MemviewEnum___init__, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_Enum, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct_memoryview __pyx_vtable_memoryview; + +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryview_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_memoryview_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_memoryview; + p->obj = Py_None; Py_INCREF(Py_None); + p->_size = Py_None; Py_INCREF(Py_None); + p->_array_interface = Py_None; Py_INCREF(Py_None); + p->view.obj = NULL; + if (unlikely(__pyx_memoryview___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_memoryview(PyObject *o) { + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_memoryview) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryview___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->obj); + Py_CLEAR(p->_size); + Py_CLEAR(p->_array_interface); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_memoryview(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + if (p->obj) { + e = (*v)(p->obj, a); if (e) return e; + } + if (p->_size) { + e = (*v)(p->_size, a); if (e) return e; + } + if (p->_array_interface) { + e = (*v)(p->_array_interface, a); if (e) return e; + } + if (p->view.obj) { + e = (*v)(p->view.obj, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_memoryview(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + tmp = ((PyObject*)p->obj); + p->obj = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_size); + p->_size = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_array_interface); + p->_array_interface = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + Py_CLEAR(p->view.obj); + return 0; +} +static PyObject *__pyx_sq_item_memoryview(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_memoryview(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_memoryview___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_getprop___pyx_memoryview_T(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_base(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_shape(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_strides(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_suboffsets(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_ndim(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_itemsize(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_nbytes(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_size(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(o); +} + +static PyObject *__pyx_specialmethod___pyx_memoryview___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_memoryview___repr__(self); +} + +static PyMethodDef __pyx_methods_memoryview[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_memoryview___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"is_c_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_c_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"is_f_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_f_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy_fortran", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy_fortran, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_memoryview[] = { + {(char *)"T", __pyx_getprop___pyx_memoryview_T, 0, (char *)0, 0}, + {(char *)"base", __pyx_getprop___pyx_memoryview_base, 0, (char *)0, 0}, + {(char *)"shape", __pyx_getprop___pyx_memoryview_shape, 0, (char *)0, 0}, + {(char *)"strides", __pyx_getprop___pyx_memoryview_strides, 0, (char *)0, 0}, + {(char *)"suboffsets", __pyx_getprop___pyx_memoryview_suboffsets, 0, (char *)0, 0}, + {(char *)"ndim", __pyx_getprop___pyx_memoryview_ndim, 0, (char *)0, 0}, + {(char *)"itemsize", __pyx_getprop___pyx_memoryview_itemsize, 0, (char *)0, 0}, + {(char *)"nbytes", __pyx_getprop___pyx_memoryview_nbytes, 0, (char *)0, 0}, + {(char *)"size", __pyx_getprop___pyx_memoryview_size, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_memoryview_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_memoryview}, + {Py_tp_repr, (void *)__pyx_memoryview___repr__}, + {Py_sq_length, (void *)__pyx_memoryview___len__}, + {Py_sq_item, (void *)__pyx_sq_item_memoryview}, + {Py_mp_length, (void *)__pyx_memoryview___len__}, + {Py_mp_subscript, (void *)__pyx_memoryview___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_memoryview}, + {Py_tp_str, (void *)__pyx_memoryview___str__}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_memoryview_getbuffer}, + #endif + {Py_tp_traverse, (void *)__pyx_tp_traverse_memoryview}, + {Py_tp_clear, (void *)__pyx_tp_clear_memoryview}, + {Py_tp_methods, (void *)__pyx_methods_memoryview}, + {Py_tp_getset, (void *)__pyx_getsets_memoryview}, + {Py_tp_new, (void *)__pyx_tp_new_memoryview}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryview_spec = { + "fairseq.data.data_utils_fast.memoryview", + sizeof(struct __pyx_memoryview_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_memoryview_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_memoryview = { + __pyx_memoryview___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_memoryview, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_memoryview = { + __pyx_memoryview___len__, /*mp_length*/ + __pyx_memoryview___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_memoryview, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_memoryview = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.data_utils_fast.""memoryview", /*tp_name*/ + sizeof(struct __pyx_memoryview_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_memoryview, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_memoryview___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_memoryview, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_memoryview, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + __pyx_memoryview___str__, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_memoryview, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_memoryview, /*tp_traverse*/ + __pyx_tp_clear_memoryview, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_memoryview, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_memoryview, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_memoryview, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct__memoryviewslice __pyx_vtable__memoryviewslice; + +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryviewslice_obj *p; + PyObject *o = __pyx_tp_new_memoryview(t, a, k); + if (unlikely(!o)) return 0; + p = ((struct __pyx_memoryviewslice_obj *)o); + p->__pyx_base.__pyx_vtab = (struct __pyx_vtabstruct_memoryview*)__pyx_vtabptr__memoryviewslice; + new((void*)&(p->from_slice)) __Pyx_memviewslice(); + p->from_object = Py_None; Py_INCREF(Py_None); + p->from_slice.memview = NULL; + return o; +} + +static void __pyx_tp_dealloc__memoryviewslice(PyObject *o) { + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc__memoryviewslice) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryviewslice___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + __Pyx_call_destructor(p->from_slice); + Py_CLEAR(p->from_object); + PyObject_GC_Track(o); + __pyx_tp_dealloc_memoryview(o); +} + +static int __pyx_tp_traverse__memoryviewslice(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + e = __pyx_tp_traverse_memoryview(o, v, a); if (e) return e; + if (p->from_object) { + e = (*v)(p->from_object, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear__memoryviewslice(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + __pyx_tp_clear_memoryview(o); + tmp = ((PyObject*)p->from_object); + p->from_object = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + __PYX_XCLEAR_MEMVIEW(&p->from_slice, 1); + return 0; +} + +static PyMethodDef __pyx_methods__memoryviewslice[] = { + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_memoryviewslice_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc__memoryviewslice}, + {Py_tp_doc, (void *)PyDoc_STR("Internal class for passing memoryview slices to Python")}, + {Py_tp_traverse, (void *)__pyx_tp_traverse__memoryviewslice}, + {Py_tp_clear, (void *)__pyx_tp_clear__memoryviewslice}, + {Py_tp_methods, (void *)__pyx_methods__memoryviewslice}, + {Py_tp_new, (void *)__pyx_tp_new__memoryviewslice}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryviewslice_spec = { + "fairseq.data.data_utils_fast._memoryviewslice", + sizeof(struct __pyx_memoryviewslice_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_memoryviewslice_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_memoryviewslice = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.data_utils_fast.""_memoryviewslice", /*tp_name*/ + sizeof(struct __pyx_memoryviewslice_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc__memoryviewslice, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___repr__, /*tp_repr*/ + #else + 0, /*tp_repr*/ + #endif + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___str__, /*tp_str*/ + #else + 0, /*tp_str*/ + #endif + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + PyDoc_STR("Internal class for passing memoryview slices to Python"), /*tp_doc*/ + __pyx_tp_traverse__memoryviewslice, /*tp_traverse*/ + __pyx_tp_clear__memoryviewslice, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods__memoryviewslice, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new__memoryviewslice, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyMethodDef __pyx_methods[] = { + {0, 0, 0, 0} +}; +#ifndef CYTHON_SMALL_CODE +#if defined(__clang__) + #define CYTHON_SMALL_CODE +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define CYTHON_SMALL_CODE __attribute__((cold)) +#else + #define CYTHON_SMALL_CODE +#endif +#endif +/* #### Code section: pystring_table ### */ + +static int __Pyx_CreateStringTabAndInitStrings(void) { + __Pyx_StringTabEntry __pyx_string_tab[] = { + {&__pyx_kp_u_, __pyx_k_, sizeof(__pyx_k_), 0, 1, 0, 0}, + {&__pyx_n_s_ASCII, __pyx_k_ASCII, sizeof(__pyx_k_ASCII), 0, 0, 1, 1}, + {&__pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_k_All_dimensions_preceding_dimensi, sizeof(__pyx_k_All_dimensions_preceding_dimensi), 0, 0, 1, 0}, + {&__pyx_n_s_AssertionError, __pyx_k_AssertionError, sizeof(__pyx_k_AssertionError), 0, 0, 1, 1}, + {&__pyx_kp_s_Buffer_view_does_not_expose_stri, __pyx_k_Buffer_view_does_not_expose_stri, sizeof(__pyx_k_Buffer_view_does_not_expose_stri), 0, 0, 1, 0}, + {&__pyx_kp_s_Can_only_create_a_buffer_that_is, __pyx_k_Can_only_create_a_buffer_that_is, sizeof(__pyx_k_Can_only_create_a_buffer_that_is), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_assign_to_read_only_memor, __pyx_k_Cannot_assign_to_read_only_memor, sizeof(__pyx_k_Cannot_assign_to_read_only_memor), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_create_writable_memory_vi, __pyx_k_Cannot_create_writable_memory_vi, sizeof(__pyx_k_Cannot_create_writable_memory_vi), 0, 0, 1, 0}, + {&__pyx_kp_u_Cannot_index_with_type, __pyx_k_Cannot_index_with_type, sizeof(__pyx_k_Cannot_index_with_type), 0, 1, 0, 0}, + {&__pyx_kp_s_Cannot_transpose_memoryview_with, __pyx_k_Cannot_transpose_memoryview_with, sizeof(__pyx_k_Cannot_transpose_memoryview_with), 0, 0, 1, 0}, + {&__pyx_kp_s_Dimension_d_is_not_direct, __pyx_k_Dimension_d_is_not_direct, sizeof(__pyx_k_Dimension_d_is_not_direct), 0, 0, 1, 0}, + {&__pyx_n_s_Ellipsis, __pyx_k_Ellipsis, sizeof(__pyx_k_Ellipsis), 0, 0, 1, 1}, + {&__pyx_kp_s_Empty_shape_tuple_for_cython_arr, __pyx_k_Empty_shape_tuple_for_cython_arr, sizeof(__pyx_k_Empty_shape_tuple_for_cython_arr), 0, 0, 1, 0}, + {&__pyx_n_s_ImportError, __pyx_k_ImportError, sizeof(__pyx_k_ImportError), 0, 0, 1, 1}, + {&__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_k_Incompatible_checksums_0x_x_vs_0, sizeof(__pyx_k_Incompatible_checksums_0x_x_vs_0), 0, 0, 1, 0}, + {&__pyx_n_s_IndexError, __pyx_k_IndexError, sizeof(__pyx_k_IndexError), 0, 0, 1, 1}, + {&__pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_k_Index_out_of_bounds_axis_d, sizeof(__pyx_k_Index_out_of_bounds_axis_d), 0, 0, 1, 0}, + {&__pyx_kp_s_Indirect_dimensions_not_supporte, __pyx_k_Indirect_dimensions_not_supporte, sizeof(__pyx_k_Indirect_dimensions_not_supporte), 0, 0, 1, 0}, + {&__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_k_Invalid_mode_expected_c_or_fortr, sizeof(__pyx_k_Invalid_mode_expected_c_or_fortr), 0, 1, 0, 0}, + {&__pyx_kp_u_Invalid_shape_in_axis, __pyx_k_Invalid_shape_in_axis, sizeof(__pyx_k_Invalid_shape_in_axis), 0, 1, 0, 0}, + {&__pyx_n_s_MemoryError, __pyx_k_MemoryError, sizeof(__pyx_k_MemoryError), 0, 0, 1, 1}, + {&__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_k_MemoryView_of_r_at_0x_x, sizeof(__pyx_k_MemoryView_of_r_at_0x_x), 0, 0, 1, 0}, + {&__pyx_kp_s_MemoryView_of_r_object, __pyx_k_MemoryView_of_r_object, sizeof(__pyx_k_MemoryView_of_r_object), 0, 0, 1, 0}, + {&__pyx_n_b_O, __pyx_k_O, sizeof(__pyx_k_O), 0, 0, 0, 1}, + {&__pyx_kp_u_Out_of_bounds_on_buffer_access_a, __pyx_k_Out_of_bounds_on_buffer_access_a, sizeof(__pyx_k_Out_of_bounds_on_buffer_access_a), 0, 1, 0, 0}, + {&__pyx_n_s_PickleError, __pyx_k_PickleError, sizeof(__pyx_k_PickleError), 0, 0, 1, 1}, + {&__pyx_kp_u_Sentences_lengths_should_not_exc, __pyx_k_Sentences_lengths_should_not_exc, sizeof(__pyx_k_Sentences_lengths_should_not_exc), 0, 1, 0, 0}, + {&__pyx_n_s_Sequence, __pyx_k_Sequence, sizeof(__pyx_k_Sequence), 0, 0, 1, 1}, + {&__pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_k_Step_may_not_be_zero_axis_d, sizeof(__pyx_k_Step_may_not_be_zero_axis_d), 0, 0, 1, 0}, + {&__pyx_n_s_TypeError, __pyx_k_TypeError, sizeof(__pyx_k_TypeError), 0, 0, 1, 1}, + {&__pyx_kp_s_Unable_to_convert_item_to_object, __pyx_k_Unable_to_convert_item_to_object, sizeof(__pyx_k_Unable_to_convert_item_to_object), 0, 0, 1, 0}, + {&__pyx_n_s_ValueError, __pyx_k_ValueError, sizeof(__pyx_k_ValueError), 0, 0, 1, 1}, + {&__pyx_n_s_View_MemoryView, __pyx_k_View_MemoryView, sizeof(__pyx_k_View_MemoryView), 0, 0, 1, 1}, + {&__pyx_kp_u__2, __pyx_k__2, sizeof(__pyx_k__2), 0, 1, 0, 0}, + {&__pyx_n_s__28, __pyx_k__28, sizeof(__pyx_k__28), 0, 0, 1, 1}, + {&__pyx_n_s__3, __pyx_k__3, sizeof(__pyx_k__3), 0, 0, 1, 1}, + {&__pyx_kp_u__6, __pyx_k__6, sizeof(__pyx_k__6), 0, 1, 0, 0}, + {&__pyx_kp_u__7, __pyx_k__7, sizeof(__pyx_k__7), 0, 1, 0, 0}, + {&__pyx_n_s_abc, __pyx_k_abc, sizeof(__pyx_k_abc), 0, 0, 1, 1}, + {&__pyx_n_s_allocate_buffer, __pyx_k_allocate_buffer, sizeof(__pyx_k_allocate_buffer), 0, 0, 1, 1}, + {&__pyx_kp_u_and, __pyx_k_and, sizeof(__pyx_k_and), 0, 1, 0, 0}, + {&__pyx_n_s_asyncio_coroutines, __pyx_k_asyncio_coroutines, sizeof(__pyx_k_asyncio_coroutines), 0, 0, 1, 1}, + {&__pyx_n_s_base, __pyx_k_base, sizeof(__pyx_k_base), 0, 0, 1, 1}, + {&__pyx_n_s_batch_by_size_fn, __pyx_k_batch_by_size_fn, sizeof(__pyx_k_batch_by_size_fn), 0, 0, 1, 1}, + {&__pyx_n_s_batch_by_size_vec, __pyx_k_batch_by_size_vec, sizeof(__pyx_k_batch_by_size_vec), 0, 0, 1, 1}, + {&__pyx_n_s_batch_fixed_shapes_fast, __pyx_k_batch_fixed_shapes_fast, sizeof(__pyx_k_batch_fixed_shapes_fast), 0, 0, 1, 1}, + {&__pyx_n_s_bsz_mult, __pyx_k_bsz_mult, sizeof(__pyx_k_bsz_mult), 0, 0, 1, 1}, + {&__pyx_n_s_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 0, 1, 1}, + {&__pyx_n_u_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 1, 0, 1}, + {&__pyx_n_s_class, __pyx_k_class, sizeof(__pyx_k_class), 0, 0, 1, 1}, + {&__pyx_n_s_class_getitem, __pyx_k_class_getitem, sizeof(__pyx_k_class_getitem), 0, 0, 1, 1}, + {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1}, + {&__pyx_n_s_collections, __pyx_k_collections, sizeof(__pyx_k_collections), 0, 0, 1, 1}, + {&__pyx_kp_s_collections_abc, __pyx_k_collections_abc, sizeof(__pyx_k_collections_abc), 0, 0, 1, 0}, + {&__pyx_kp_s_contiguous_and_direct, __pyx_k_contiguous_and_direct, sizeof(__pyx_k_contiguous_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_contiguous_and_indirect, __pyx_k_contiguous_and_indirect, sizeof(__pyx_k_contiguous_and_indirect), 0, 0, 1, 0}, + {&__pyx_n_s_count, __pyx_k_count, sizeof(__pyx_k_count), 0, 0, 1, 1}, + {&__pyx_n_s_dict, __pyx_k_dict, sizeof(__pyx_k_dict), 0, 0, 1, 1}, + {&__pyx_kp_u_disable, __pyx_k_disable, sizeof(__pyx_k_disable), 0, 1, 0, 0}, + {&__pyx_n_s_dtype, __pyx_k_dtype, sizeof(__pyx_k_dtype), 0, 0, 1, 1}, + {&__pyx_n_s_dtype_is_object, __pyx_k_dtype_is_object, sizeof(__pyx_k_dtype_is_object), 0, 0, 1, 1}, + {&__pyx_kp_u_enable, __pyx_k_enable, sizeof(__pyx_k_enable), 0, 1, 0, 0}, + {&__pyx_n_s_encode, __pyx_k_encode, sizeof(__pyx_k_encode), 0, 0, 1, 1}, + {&__pyx_n_s_enumerate, __pyx_k_enumerate, sizeof(__pyx_k_enumerate), 0, 0, 1, 1}, + {&__pyx_n_s_error, __pyx_k_error, sizeof(__pyx_k_error), 0, 0, 1, 1}, + {&__pyx_n_s_fairseq_data_data_utils_fast, __pyx_k_fairseq_data_data_utils_fast, sizeof(__pyx_k_fairseq_data_data_utils_fast), 0, 0, 1, 1}, + {&__pyx_kp_s_fairseq_data_data_utils_fast_pyx, __pyx_k_fairseq_data_data_utils_fast_pyx, sizeof(__pyx_k_fairseq_data_data_utils_fast_pyx), 0, 0, 1, 0}, + {&__pyx_n_s_fixed_shapes_sorted, __pyx_k_fixed_shapes_sorted, sizeof(__pyx_k_fixed_shapes_sorted), 0, 0, 1, 1}, + {&__pyx_n_s_flags, __pyx_k_flags, sizeof(__pyx_k_flags), 0, 0, 1, 1}, + {&__pyx_n_s_format, __pyx_k_format, sizeof(__pyx_k_format), 0, 0, 1, 1}, + {&__pyx_n_s_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 0, 1, 1}, + {&__pyx_n_u_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 1, 0, 1}, + {&__pyx_kp_u_gc, __pyx_k_gc, sizeof(__pyx_k_gc), 0, 1, 0, 0}, + {&__pyx_n_s_getstate, __pyx_k_getstate, sizeof(__pyx_k_getstate), 0, 0, 1, 1}, + {&__pyx_kp_u_got, __pyx_k_got, sizeof(__pyx_k_got), 0, 1, 0, 0}, + {&__pyx_kp_u_got_differing_extents_in_dimensi, __pyx_k_got_differing_extents_in_dimensi, sizeof(__pyx_k_got_differing_extents_in_dimensi), 0, 1, 0, 0}, + {&__pyx_n_s_id, __pyx_k_id, sizeof(__pyx_k_id), 0, 0, 1, 1}, + {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1}, + {&__pyx_n_s_index, __pyx_k_index, sizeof(__pyx_k_index), 0, 0, 1, 1}, + {&__pyx_n_s_indices, __pyx_k_indices, sizeof(__pyx_k_indices), 0, 0, 1, 1}, + {&__pyx_n_s_initializing, __pyx_k_initializing, sizeof(__pyx_k_initializing), 0, 0, 1, 1}, + {&__pyx_n_s_int32, __pyx_k_int32, sizeof(__pyx_k_int32), 0, 0, 1, 1}, + {&__pyx_n_s_int64, __pyx_k_int64, sizeof(__pyx_k_int64), 0, 0, 1, 1}, + {&__pyx_n_s_is_coroutine, __pyx_k_is_coroutine, sizeof(__pyx_k_is_coroutine), 0, 0, 1, 1}, + {&__pyx_kp_u_isenabled, __pyx_k_isenabled, sizeof(__pyx_k_isenabled), 0, 1, 0, 0}, + {&__pyx_n_s_itemsize, __pyx_k_itemsize, sizeof(__pyx_k_itemsize), 0, 0, 1, 1}, + {&__pyx_kp_s_itemsize_0_for_cython_array, __pyx_k_itemsize_0_for_cython_array, sizeof(__pyx_k_itemsize_0_for_cython_array), 0, 0, 1, 0}, + {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1}, + {&__pyx_n_s_max, __pyx_k_max, sizeof(__pyx_k_max), 0, 0, 1, 1}, + {&__pyx_n_s_max_sentences, __pyx_k_max_sentences, sizeof(__pyx_k_max_sentences), 0, 0, 1, 1}, + {&__pyx_n_s_max_tokens, __pyx_k_max_tokens, sizeof(__pyx_k_max_tokens), 0, 0, 1, 1}, + {&__pyx_n_s_memview, __pyx_k_memview, sizeof(__pyx_k_memview), 0, 0, 1, 1}, + {&__pyx_n_s_mode, __pyx_k_mode, sizeof(__pyx_k_mode), 0, 0, 1, 1}, + {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1}, + {&__pyx_n_s_name_2, __pyx_k_name_2, sizeof(__pyx_k_name_2), 0, 0, 1, 1}, + {&__pyx_n_s_ndim, __pyx_k_ndim, sizeof(__pyx_k_ndim), 0, 0, 1, 1}, + {&__pyx_n_s_new, __pyx_k_new, sizeof(__pyx_k_new), 0, 0, 1, 1}, + {&__pyx_kp_s_no_default___reduce___due_to_non, __pyx_k_no_default___reduce___due_to_non, sizeof(__pyx_k_no_default___reduce___due_to_non), 0, 0, 1, 0}, + {&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1}, + {&__pyx_n_s_num_tokens_fn, __pyx_k_num_tokens_fn, sizeof(__pyx_k_num_tokens_fn), 0, 0, 1, 1}, + {&__pyx_n_s_num_tokens_vec, __pyx_k_num_tokens_vec, sizeof(__pyx_k_num_tokens_vec), 0, 0, 1, 1}, + {&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1}, + {&__pyx_kp_u_numpy_core_multiarray_failed_to, __pyx_k_numpy_core_multiarray_failed_to, sizeof(__pyx_k_numpy_core_multiarray_failed_to), 0, 1, 0, 0}, + {&__pyx_kp_u_numpy_core_umath_failed_to_impor, __pyx_k_numpy_core_umath_failed_to_impor, sizeof(__pyx_k_numpy_core_umath_failed_to_impor), 0, 1, 0, 0}, + {&__pyx_n_s_obj, __pyx_k_obj, sizeof(__pyx_k_obj), 0, 0, 1, 1}, + {&__pyx_n_s_pack, __pyx_k_pack, sizeof(__pyx_k_pack), 0, 0, 1, 1}, + {&__pyx_n_s_pickle, __pyx_k_pickle, sizeof(__pyx_k_pickle), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_PickleError, __pyx_k_pyx_PickleError, sizeof(__pyx_k_pyx_PickleError), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_checksum, __pyx_k_pyx_checksum, sizeof(__pyx_k_pyx_checksum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_result, __pyx_k_pyx_result, sizeof(__pyx_k_pyx_result), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_state, __pyx_k_pyx_state, sizeof(__pyx_k_pyx_state), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_type, __pyx_k_pyx_type, sizeof(__pyx_k_pyx_type), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_unpickle_Enum, __pyx_k_pyx_unpickle_Enum, sizeof(__pyx_k_pyx_unpickle_Enum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_vtable, __pyx_k_pyx_vtable, sizeof(__pyx_k_pyx_vtable), 0, 0, 1, 1}, + {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, + {&__pyx_n_s_reduce, __pyx_k_reduce, sizeof(__pyx_k_reduce), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_cython, __pyx_k_reduce_cython, sizeof(__pyx_k_reduce_cython), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_ex, __pyx_k_reduce_ex, sizeof(__pyx_k_reduce_ex), 0, 0, 1, 1}, + {&__pyx_n_s_register, __pyx_k_register, sizeof(__pyx_k_register), 0, 0, 1, 1}, + {&__pyx_n_s_setstate, __pyx_k_setstate, sizeof(__pyx_k_setstate), 0, 0, 1, 1}, + {&__pyx_n_s_setstate_cython, __pyx_k_setstate_cython, sizeof(__pyx_k_setstate_cython), 0, 0, 1, 1}, + {&__pyx_n_s_shape, __pyx_k_shape, sizeof(__pyx_k_shape), 0, 0, 1, 1}, + {&__pyx_n_s_size, __pyx_k_size, sizeof(__pyx_k_size), 0, 0, 1, 1}, + {&__pyx_n_s_spec, __pyx_k_spec, sizeof(__pyx_k_spec), 0, 0, 1, 1}, + {&__pyx_n_s_split, __pyx_k_split, sizeof(__pyx_k_split), 0, 0, 1, 1}, + {&__pyx_n_s_start, __pyx_k_start, sizeof(__pyx_k_start), 0, 0, 1, 1}, + {&__pyx_n_s_step, __pyx_k_step, sizeof(__pyx_k_step), 0, 0, 1, 1}, + {&__pyx_n_s_stop, __pyx_k_stop, sizeof(__pyx_k_stop), 0, 0, 1, 1}, + {&__pyx_kp_s_strided_and_direct, __pyx_k_strided_and_direct, sizeof(__pyx_k_strided_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_direct_or_indirect, __pyx_k_strided_and_direct_or_indirect, sizeof(__pyx_k_strided_and_direct_or_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_indirect, __pyx_k_strided_and_indirect, sizeof(__pyx_k_strided_and_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_stringsource, __pyx_k_stringsource, sizeof(__pyx_k_stringsource), 0, 0, 1, 0}, + {&__pyx_n_s_struct, __pyx_k_struct, sizeof(__pyx_k_struct), 0, 0, 1, 1}, + {&__pyx_n_s_sys, __pyx_k_sys, sizeof(__pyx_k_sys), 0, 0, 1, 1}, + {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, + {&__pyx_kp_s_unable_to_allocate_array_data, __pyx_k_unable_to_allocate_array_data, sizeof(__pyx_k_unable_to_allocate_array_data), 0, 0, 1, 0}, + {&__pyx_kp_s_unable_to_allocate_shape_and_str, __pyx_k_unable_to_allocate_shape_and_str, sizeof(__pyx_k_unable_to_allocate_shape_and_str), 0, 0, 1, 0}, + {&__pyx_n_s_unpack, __pyx_k_unpack, sizeof(__pyx_k_unpack), 0, 0, 1, 1}, + {&__pyx_n_s_update, __pyx_k_update, sizeof(__pyx_k_update), 0, 0, 1, 1}, + {&__pyx_n_s_version_info, __pyx_k_version_info, sizeof(__pyx_k_version_info), 0, 0, 1, 1}, + {&__pyx_n_s_zeros, __pyx_k_zeros, sizeof(__pyx_k_zeros), 0, 0, 1, 1}, + {0, 0, 0, 0, 0, 0, 0} + }; + return __Pyx_InitStrings(__pyx_string_tab); +} +/* #### Code section: cached_builtins ### */ +static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) { + __pyx_builtin_AssertionError = __Pyx_GetBuiltinName(__pyx_n_s_AssertionError); if (!__pyx_builtin_AssertionError) __PYX_ERR(0, 30, __pyx_L1_error) + __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 55, __pyx_L1_error) + __pyx_builtin___import__ = __Pyx_GetBuiltinName(__pyx_n_s_import); if (!__pyx_builtin___import__) __PYX_ERR(1, 100, __pyx_L1_error) + __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(1, 141, __pyx_L1_error) + __pyx_builtin_MemoryError = __Pyx_GetBuiltinName(__pyx_n_s_MemoryError); if (!__pyx_builtin_MemoryError) __PYX_ERR(1, 156, __pyx_L1_error) + __pyx_builtin_enumerate = __Pyx_GetBuiltinName(__pyx_n_s_enumerate); if (!__pyx_builtin_enumerate) __PYX_ERR(1, 159, __pyx_L1_error) + __pyx_builtin_TypeError = __Pyx_GetBuiltinName(__pyx_n_s_TypeError); if (!__pyx_builtin_TypeError) __PYX_ERR(1, 2, __pyx_L1_error) + __pyx_builtin_Ellipsis = __Pyx_GetBuiltinName(__pyx_n_s_Ellipsis); if (!__pyx_builtin_Ellipsis) __PYX_ERR(1, 408, __pyx_L1_error) + __pyx_builtin_id = __Pyx_GetBuiltinName(__pyx_n_s_id); if (!__pyx_builtin_id) __PYX_ERR(1, 618, __pyx_L1_error) + __pyx_builtin_IndexError = __Pyx_GetBuiltinName(__pyx_n_s_IndexError); if (!__pyx_builtin_IndexError) __PYX_ERR(1, 914, __pyx_L1_error) + __pyx_builtin_ImportError = __Pyx_GetBuiltinName(__pyx_n_s_ImportError); if (!__pyx_builtin_ImportError) __PYX_ERR(2, 984, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: cached_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0); + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __pyx_tuple__4 = PyTuple_New(1); if (unlikely(!__pyx_tuple__4)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__4); + __Pyx_INCREF(__pyx_int_neg_1); + __Pyx_GIVEREF(__pyx_int_neg_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_tuple__4, 0, __pyx_int_neg_1)) __PYX_ERR(1, 582, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_tuple__4); + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_slice__5 = PySlice_New(Py_None, Py_None, Py_None); if (unlikely(!__pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_tuple__8 = PyTuple_Pack(3, __pyx_int_136983863, __pyx_int_112105877, __pyx_int_184977713); if (unlikely(!__pyx_tuple__8)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__8); + __Pyx_GIVEREF(__pyx_tuple__8); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":984 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_tuple__9 = PyTuple_Pack(1, __pyx_kp_u_numpy_core_multiarray_failed_to); if (unlikely(!__pyx_tuple__9)) __PYX_ERR(2, 984, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__9); + __Pyx_GIVEREF(__pyx_tuple__9); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":990 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_tuple__10 = PyTuple_Pack(1, __pyx_kp_u_numpy_core_umath_failed_to_impor); if (unlikely(!__pyx_tuple__10)) __PYX_ERR(2, 990, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__10); + __Pyx_GIVEREF(__pyx_tuple__10); + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_tuple__11 = PyTuple_Pack(1, __pyx_n_s_sys); if (unlikely(!__pyx_tuple__11)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__11); + __Pyx_GIVEREF(__pyx_tuple__11); + __pyx_tuple__12 = PyTuple_Pack(2, __pyx_int_3, __pyx_int_3); if (unlikely(!__pyx_tuple__12)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__12); + __Pyx_GIVEREF(__pyx_tuple__12); + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_tuple__13 = PyTuple_Pack(1, __pyx_kp_s_collections_abc); if (unlikely(!__pyx_tuple__13)) __PYX_ERR(1, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__13); + __Pyx_GIVEREF(__pyx_tuple__13); + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + __pyx_tuple__14 = PyTuple_Pack(1, __pyx_n_s_collections); if (unlikely(!__pyx_tuple__14)) __PYX_ERR(1, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__14); + __Pyx_GIVEREF(__pyx_tuple__14); + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_tuple__15 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct_or_indirect); if (unlikely(!__pyx_tuple__15)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__15); + __Pyx_GIVEREF(__pyx_tuple__15); + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_tuple__16 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct); if (unlikely(!__pyx_tuple__16)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__16); + __Pyx_GIVEREF(__pyx_tuple__16); + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__17 = PyTuple_Pack(1, __pyx_kp_s_strided_and_indirect); if (unlikely(!__pyx_tuple__17)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__17); + __Pyx_GIVEREF(__pyx_tuple__17); + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_tuple__18 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_direct); if (unlikely(!__pyx_tuple__18)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__18); + __Pyx_GIVEREF(__pyx_tuple__18); + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__19 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_indirect); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__19); + __Pyx_GIVEREF(__pyx_tuple__19); + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_tuple__20 = PyTuple_Pack(5, __pyx_n_s_pyx_type, __pyx_n_s_pyx_checksum, __pyx_n_s_pyx_state, __pyx_n_s_pyx_PickleError, __pyx_n_s_pyx_result); if (unlikely(!__pyx_tuple__20)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__20); + __Pyx_GIVEREF(__pyx_tuple__20); + __pyx_codeobj__21 = (PyObject*)__Pyx_PyCode_New(3, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__20, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_Enum, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__21)) __PYX_ERR(1, 1, __pyx_L1_error) + + /* "fairseq/data/data_utils_fast.pyx":20 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_vec( # <<<<<<<<<<<<<< + * np.ndarray[int64_t, ndim=1] indices, + * np.ndarray[int64_t, ndim=1] num_tokens_vec, + */ + __pyx_tuple__22 = PyTuple_Pack(5, __pyx_n_s_indices, __pyx_n_s_num_tokens_vec, __pyx_n_s_max_tokens, __pyx_n_s_max_sentences, __pyx_n_s_bsz_mult); if (unlikely(!__pyx_tuple__22)) __PYX_ERR(0, 20, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__22); + __Pyx_GIVEREF(__pyx_tuple__22); + __pyx_codeobj__23 = (PyObject*)__Pyx_PyCode_New(5, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__22, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_fairseq_data_data_utils_fast_pyx, __pyx_n_s_batch_by_size_vec, 20, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__23)) __PYX_ERR(0, 20, __pyx_L1_error) + + /* "fairseq/data/data_utils_fast.pyx":108 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_fn( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + __pyx_tuple__24 = PyTuple_Pack(5, __pyx_n_s_indices, __pyx_n_s_num_tokens_fn, __pyx_n_s_max_tokens, __pyx_n_s_max_sentences, __pyx_n_s_bsz_mult); if (unlikely(!__pyx_tuple__24)) __PYX_ERR(0, 108, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__24); + __Pyx_GIVEREF(__pyx_tuple__24); + __pyx_codeobj__25 = (PyObject*)__Pyx_PyCode_New(5, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__24, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_fairseq_data_data_utils_fast_pyx, __pyx_n_s_batch_by_size_fn, 108, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__25)) __PYX_ERR(0, 108, __pyx_L1_error) + + /* "fairseq/data/data_utils_fast.pyx":140 + * + * @cython.cdivision(True) + * cpdef list batch_fixed_shapes_fast( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + __pyx_tuple__26 = PyTuple_Pack(3, __pyx_n_s_indices, __pyx_n_s_num_tokens_fn, __pyx_n_s_fixed_shapes_sorted); if (unlikely(!__pyx_tuple__26)) __PYX_ERR(0, 140, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__26); + __Pyx_GIVEREF(__pyx_tuple__26); + __pyx_codeobj__27 = (PyObject*)__Pyx_PyCode_New(3, 0, 0, 3, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__26, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_fairseq_data_data_utils_fast_pyx, __pyx_n_s_batch_fixed_shapes_fast, 140, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__27)) __PYX_ERR(0, 140, __pyx_L1_error) + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_RefNannyFinishContext(); + return -1; +} +/* #### Code section: init_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitConstants(void) { + if (__Pyx_CreateStringTabAndInitStrings() < 0) __PYX_ERR(0, 1, __pyx_L1_error); + __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_3 = PyInt_FromLong(3); if (unlikely(!__pyx_int_3)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_112105877 = PyInt_FromLong(112105877L); if (unlikely(!__pyx_int_112105877)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_136983863 = PyInt_FromLong(136983863L); if (unlikely(!__pyx_int_136983863)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_184977713 = PyInt_FromLong(184977713L); if (unlikely(!__pyx_int_184977713)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_neg_1 = PyInt_FromLong(-1); if (unlikely(!__pyx_int_neg_1)) __PYX_ERR(0, 1, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_globals ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) { + /* AssertionsEnabled.init */ + if (likely(__Pyx_init_assertions_enabled() == 0)); else + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + /* NumpyImportArray.init */ + /* + * Cython has automatically inserted a call to _import_array since + * you didn't include one when you cimported numpy. To disable this + * add the line + * numpy._import_array + */ +#ifdef NPY_FEATURE_VERSION +#ifndef NO_IMPORT_ARRAY +if (unlikely(_import_array() == -1)) { + PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import " + "(auto-generated because you didn't call 'numpy.import_array()' after cimporting numpy; " + "use 'numpy._import_array' to disable if you are certain you don't need it)."); +} +#endif +#endif + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_module ### */ + +static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_import_code(void); /*proto*/ + +static int __Pyx_modinit_global_init_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0); + /*--- Global init code ---*/ + __pyx_collections_abc_Sequence = Py_None; Py_INCREF(Py_None); + generic = Py_None; Py_INCREF(Py_None); + strided = Py_None; Py_INCREF(Py_None); + indirect = Py_None; Py_INCREF(Py_None); + contiguous = Py_None; Py_INCREF(Py_None); + indirect_contiguous = Py_None; Py_INCREF(Py_None); + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_variable_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0); + /*--- Variable export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0); + /*--- Function export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_type_init_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0); + /*--- Type init code ---*/ + __pyx_vtabptr_array = &__pyx_vtable_array; + __pyx_vtable_array.get_memview = (PyObject *(*)(struct __pyx_array_obj *))__pyx_array_get_memview; + #if CYTHON_USE_TYPE_SPECS + __pyx_array_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_array_spec, NULL); if (unlikely(!__pyx_array_type)) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_array_type->tp_as_buffer = &__pyx_tp_as_buffer_array; + if (!__pyx_array_type->tp_as_buffer->bf_releasebuffer && __pyx_array_type->tp_base->tp_as_buffer && __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_array_type->tp_as_buffer->bf_releasebuffer = __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_array_spec, __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #else + __pyx_array_type = &__pyx_type___pyx_array; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_array_type->tp_print = 0; + #endif + if (__Pyx_SetVtable(__pyx_array_type, __pyx_vtabptr_array) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if CYTHON_USE_TYPE_SPECS + __pyx_MemviewEnum_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_MemviewEnum_spec, NULL); if (unlikely(!__pyx_MemviewEnum_type)) __PYX_ERR(1, 302, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_MemviewEnum_spec, __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #else + __pyx_MemviewEnum_type = &__pyx_type___pyx_MemviewEnum; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_MemviewEnum_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_MemviewEnum_type->tp_dictoffset && __pyx_MemviewEnum_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_MemviewEnum_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + __pyx_vtabptr_memoryview = &__pyx_vtable_memoryview; + __pyx_vtable_memoryview.get_item_pointer = (char *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_get_item_pointer; + __pyx_vtable_memoryview.is_slice = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_is_slice; + __pyx_vtable_memoryview.setitem_slice_assignment = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_slice_assignment; + __pyx_vtable_memoryview.setitem_slice_assign_scalar = (PyObject *(*)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_setitem_slice_assign_scalar; + __pyx_vtable_memoryview.setitem_indexed = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_indexed; + __pyx_vtable_memoryview.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryview_convert_item_to_object; + __pyx_vtable_memoryview.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryview_assign_item_from_object; + __pyx_vtable_memoryview._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryview__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_memoryview_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryview_spec, NULL); if (unlikely(!__pyx_memoryview_type)) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryview_type->tp_as_buffer = &__pyx_tp_as_buffer_memoryview; + if (!__pyx_memoryview_type->tp_as_buffer->bf_releasebuffer && __pyx_memoryview_type->tp_base->tp_as_buffer && __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_memoryview_type->tp_as_buffer->bf_releasebuffer = __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryview_spec, __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #else + __pyx_memoryview_type = &__pyx_type___pyx_memoryview; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryview_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryview_type->tp_dictoffset && __pyx_memoryview_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryview_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryview_type, __pyx_vtabptr_memoryview) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + __pyx_vtabptr__memoryviewslice = &__pyx_vtable__memoryviewslice; + __pyx_vtable__memoryviewslice.__pyx_base = *__pyx_vtabptr_memoryview; + __pyx_vtable__memoryviewslice.__pyx_base.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryviewslice_convert_item_to_object; + __pyx_vtable__memoryviewslice.__pyx_base.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryviewslice_assign_item_from_object; + __pyx_vtable__memoryviewslice.__pyx_base._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryviewslice__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_t_1 = PyTuple_Pack(1, (PyObject *)__pyx_memoryview_type); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 952, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_memoryviewslice_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryviewslice_spec, __pyx_t_1); + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_memoryviewslice_type)) __PYX_ERR(1, 952, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryviewslice_spec, __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #else + __pyx_memoryviewslice_type = &__pyx_type___pyx_memoryviewslice; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryviewslice_type->tp_base = __pyx_memoryview_type; + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryviewslice_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryviewslice_type->tp_dictoffset && __pyx_memoryviewslice_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryviewslice_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryviewslice_type, __pyx_vtabptr__memoryviewslice) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_type_import_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0); + /*--- Type import code ---*/ + __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_7cpython_4type_type = __Pyx_ImportType_3_0_8(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "type", + #if defined(PYPY_VERSION_NUM) && PYPY_VERSION_NUM < 0x050B0000 + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyTypeObject), + #elif CYTHON_COMPILING_IN_LIMITED_API + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyTypeObject), + #else + sizeof(PyHeapTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyHeapTypeObject), + #endif + __Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_7cpython_4type_type) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = PyImport_ImportModule("numpy"); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 202, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_5numpy_dtype = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "dtype", sizeof(PyArray_Descr), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArray_Descr),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_dtype) __PYX_ERR(2, 202, __pyx_L1_error) + __pyx_ptype_5numpy_flatiter = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "flatiter", sizeof(PyArrayIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_flatiter) __PYX_ERR(2, 225, __pyx_L1_error) + __pyx_ptype_5numpy_broadcast = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "broadcast", sizeof(PyArrayMultiIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayMultiIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_broadcast) __PYX_ERR(2, 229, __pyx_L1_error) + __pyx_ptype_5numpy_ndarray = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "ndarray", sizeof(PyArrayObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_ndarray) __PYX_ERR(2, 238, __pyx_L1_error) + __pyx_ptype_5numpy_generic = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "generic", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_generic) __PYX_ERR(2, 809, __pyx_L1_error) + __pyx_ptype_5numpy_number = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "number", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_number) __PYX_ERR(2, 811, __pyx_L1_error) + __pyx_ptype_5numpy_integer = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "integer", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_integer) __PYX_ERR(2, 813, __pyx_L1_error) + __pyx_ptype_5numpy_signedinteger = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "signedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_signedinteger) __PYX_ERR(2, 815, __pyx_L1_error) + __pyx_ptype_5numpy_unsignedinteger = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "unsignedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_unsignedinteger) __PYX_ERR(2, 817, __pyx_L1_error) + __pyx_ptype_5numpy_inexact = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "inexact", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_inexact) __PYX_ERR(2, 819, __pyx_L1_error) + __pyx_ptype_5numpy_floating = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "floating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_floating) __PYX_ERR(2, 821, __pyx_L1_error) + __pyx_ptype_5numpy_complexfloating = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "complexfloating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_complexfloating) __PYX_ERR(2, 823, __pyx_L1_error) + __pyx_ptype_5numpy_flexible = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "flexible", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_flexible) __PYX_ERR(2, 825, __pyx_L1_error) + __pyx_ptype_5numpy_character = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "character", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_character) __PYX_ERR(2, 827, __pyx_L1_error) + __pyx_ptype_5numpy_ufunc = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "ufunc", sizeof(PyUFuncObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyUFuncObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_ufunc) __PYX_ERR(2, 866, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_variable_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0); + /*--- Variable import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0); + /*--- Function import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + + +#if PY_MAJOR_VERSION >= 3 +#if CYTHON_PEP489_MULTI_PHASE_INIT +static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/ +static int __pyx_pymod_exec_data_utils_fast(PyObject* module); /*proto*/ +static PyModuleDef_Slot __pyx_moduledef_slots[] = { + {Py_mod_create, (void*)__pyx_pymod_create}, + {Py_mod_exec, (void*)__pyx_pymod_exec_data_utils_fast}, + {0, NULL} +}; +#endif + +#ifdef __cplusplus +namespace { + struct PyModuleDef __pyx_moduledef = + #else + static struct PyModuleDef __pyx_moduledef = + #endif + { + PyModuleDef_HEAD_INIT, + "data_utils_fast", + 0, /* m_doc */ + #if CYTHON_PEP489_MULTI_PHASE_INIT + 0, /* m_size */ + #elif CYTHON_USE_MODULE_STATE + sizeof(__pyx_mstate), /* m_size */ + #else + -1, /* m_size */ + #endif + __pyx_methods /* m_methods */, + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_moduledef_slots, /* m_slots */ + #else + NULL, /* m_reload */ + #endif + #if CYTHON_USE_MODULE_STATE + __pyx_m_traverse, /* m_traverse */ + __pyx_m_clear, /* m_clear */ + NULL /* m_free */ + #else + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL /* m_free */ + #endif + }; + #ifdef __cplusplus +} /* anonymous namespace */ +#endif +#endif + +#ifndef CYTHON_NO_PYINIT_EXPORT +#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC +#elif PY_MAJOR_VERSION < 3 +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" void +#else +#define __Pyx_PyMODINIT_FUNC void +#endif +#else +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" PyObject * +#else +#define __Pyx_PyMODINIT_FUNC PyObject * +#endif +#endif + + +#if PY_MAJOR_VERSION < 3 +__Pyx_PyMODINIT_FUNC initdata_utils_fast(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC initdata_utils_fast(void) +#else +__Pyx_PyMODINIT_FUNC PyInit_data_utils_fast(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC PyInit_data_utils_fast(void) +#if CYTHON_PEP489_MULTI_PHASE_INIT +{ + return PyModuleDef_Init(&__pyx_moduledef); +} +static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) { + #if PY_VERSION_HEX >= 0x030700A1 + static PY_INT64_T main_interpreter_id = -1; + PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp); + if (main_interpreter_id == -1) { + main_interpreter_id = current_id; + return (unlikely(current_id == -1)) ? -1 : 0; + } else if (unlikely(main_interpreter_id != current_id)) + #else + static PyInterpreterState *main_interpreter = NULL; + PyInterpreterState *current_interpreter = PyThreadState_Get()->interp; + if (!main_interpreter) { + main_interpreter = current_interpreter; + } else if (unlikely(main_interpreter != current_interpreter)) + #endif + { + PyErr_SetString( + PyExc_ImportError, + "Interpreter change detected - this module can only be loaded into one interpreter per process."); + return -1; + } + return 0; +} +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *module, const char* from_name, const char* to_name, int allow_none) +#else +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none) +#endif +{ + PyObject *value = PyObject_GetAttrString(spec, from_name); + int result = 0; + if (likely(value)) { + if (allow_none || value != Py_None) { +#if CYTHON_COMPILING_IN_LIMITED_API + result = PyModule_AddObject(module, to_name, value); +#else + result = PyDict_SetItemString(moddict, to_name, value); +#endif + } + Py_DECREF(value); + } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } else { + result = -1; + } + return result; +} +static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def) { + PyObject *module = NULL, *moddict, *modname; + CYTHON_UNUSED_VAR(def); + if (__Pyx_check_single_interpreter()) + return NULL; + if (__pyx_m) + return __Pyx_NewRef(__pyx_m); + modname = PyObject_GetAttrString(spec, "name"); + if (unlikely(!modname)) goto bad; + module = PyModule_NewObject(modname); + Py_DECREF(modname); + if (unlikely(!module)) goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + moddict = module; +#else + moddict = PyModule_GetDict(module); + if (unlikely(!moddict)) goto bad; +#endif + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad; + return module; +bad: + Py_XDECREF(module); + return NULL; +} + + +static CYTHON_SMALL_CODE int __pyx_pymod_exec_data_utils_fast(PyObject *__pyx_pyinit_module) +#endif +#endif +{ + int stringtab_initialized = 0; + #if CYTHON_USE_MODULE_STATE + int pystate_addmodule_run = 0; + #endif + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + static PyThread_type_lock __pyx_t_8[8]; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannyDeclarations + #if CYTHON_PEP489_MULTI_PHASE_INIT + if (__pyx_m) { + if (__pyx_m == __pyx_pyinit_module) return 0; + PyErr_SetString(PyExc_RuntimeError, "Module 'data_utils_fast' has already been imported. Re-initialisation is not supported."); + return -1; + } + #elif PY_MAJOR_VERSION >= 3 + if (__pyx_m) return __Pyx_NewRef(__pyx_m); + #endif + /*--- Module creation code ---*/ + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_m = __pyx_pyinit_module; + Py_INCREF(__pyx_m); + #else + #if PY_MAJOR_VERSION < 3 + __pyx_m = Py_InitModule4("data_utils_fast", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #elif CYTHON_USE_MODULE_STATE + __pyx_t_1 = PyModule_Create(&__pyx_moduledef); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) + { + int add_module_result = PyState_AddModule(__pyx_t_1, &__pyx_moduledef); + __pyx_t_1 = 0; /* transfer ownership from __pyx_t_1 to "data_utils_fast" pseudovariable */ + if (unlikely((add_module_result < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + pystate_addmodule_run = 1; + } + #else + __pyx_m = PyModule_Create(&__pyx_moduledef); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #endif + CYTHON_UNUSED_VAR(__pyx_t_1); + __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error) + Py_INCREF(__pyx_d); + __pyx_b = __Pyx_PyImport_AddModuleRef(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_cython_runtime = __Pyx_PyImport_AddModuleRef((const char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error) + if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if CYTHON_REFNANNY +__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny"); +if (!__Pyx_RefNanny) { + PyErr_Clear(); + __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny"); + if (!__Pyx_RefNanny) + Py_FatalError("failed to import 'refnanny' module"); +} +#endif + __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit_data_utils_fast(void)", 0); + if (__Pyx_check_binary_version(__PYX_LIMITED_VERSION_HEX, __Pyx_get_runtime_version(), CYTHON_COMPILING_IN_LIMITED_API) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pxy_PyFrame_Initialize_Offsets + __Pxy_PyFrame_Initialize_Offsets(); + #endif + __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pyx_CyFunction_USED + if (__pyx_CyFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_FusedFunction_USED + if (__pyx_FusedFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Coroutine_USED + if (__pyx_Coroutine_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Generator_USED + if (__pyx_Generator_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_AsyncGen_USED + if (__pyx_AsyncGen_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_StopAsyncIteration_USED + if (__pyx_StopAsyncIteration_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + /*--- Library function declarations ---*/ + /*--- Threads initialization code ---*/ + #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 && defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS + PyEval_InitThreads(); + #endif + /*--- Initialize various global constants etc. ---*/ + if (__Pyx_InitConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + stringtab_initialized = 1; + if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT) + if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + if (__pyx_module_is_main_fairseq__data__data_utils_fast) { + if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name_2, __pyx_n_s_main) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + } + #if PY_MAJOR_VERSION >= 3 + { + PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error) + if (!PyDict_GetItemString(modules, "fairseq.data.data_utils_fast")) { + if (unlikely((PyDict_SetItemString(modules, "fairseq.data.data_utils_fast", __pyx_m) < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + } + } + #endif + /*--- Builtin init code ---*/ + if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Constants init code ---*/ + if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Global type/function init code ---*/ + (void)__Pyx_modinit_global_init_code(); + (void)__Pyx_modinit_variable_export_code(); + (void)__Pyx_modinit_function_export_code(); + if (unlikely((__Pyx_modinit_type_init_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + if (unlikely((__Pyx_modinit_type_import_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + (void)__Pyx_modinit_variable_import_code(); + (void)__Pyx_modinit_function_import_code(); + /*--- Execution code ---*/ + #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED) + if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__11, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_version_info); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = PyObject_RichCompare(__pyx_t_5, __pyx_tuple__12, Py_GE); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (__pyx_t_6) { + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__13, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_abc); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__14, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + __pyx_t_5 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L7_try_end; + __pyx_L2_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + + /* "View.MemoryView":104 + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + * except: # <<<<<<<<<<<<<< + * + * __pyx_collections_abc_Sequence = None + */ + /*except:*/ { + __Pyx_AddTraceback("View.MemoryView", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_4, &__pyx_t_7) < 0) __PYX_ERR(1, 104, __pyx_L4_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_7); + + /* "View.MemoryView":106 + * except: + * + * __pyx_collections_abc_Sequence = None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF(Py_None); + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, Py_None); + __Pyx_GIVEREF(Py_None); + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + goto __pyx_L3_exception_handled; + } + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + __pyx_L4_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L3_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L7_try_end:; + } + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":242 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":243 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L16_try_end; + __pyx_L11_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":244 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L12_exception_handled; + } + __pyx_L12_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L16_try_end:; + } + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__15, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(generic); + __Pyx_DECREF_SET(generic, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__16, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(strided); + __Pyx_DECREF_SET(strided, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__17, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect); + __Pyx_DECREF_SET(indirect, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__18, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(contiguous); + __Pyx_DECREF_SET(contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__19, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect_contiguous); + __Pyx_DECREF_SET(indirect_contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":323 + * + * + * cdef int __pyx_memoryview_thread_locks_used = 0 # <<<<<<<<<<<<<< + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ + * PyThread_allocate_lock(), + */ + __pyx_memoryview_thread_locks_used = 0; + + /* "View.MemoryView":324 + * + * cdef int __pyx_memoryview_thread_locks_used = 0 + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ # <<<<<<<<<<<<<< + * PyThread_allocate_lock(), + * PyThread_allocate_lock(), + */ + __pyx_t_8[0] = PyThread_allocate_lock(); + __pyx_t_8[1] = PyThread_allocate_lock(); + __pyx_t_8[2] = PyThread_allocate_lock(); + __pyx_t_8[3] = PyThread_allocate_lock(); + __pyx_t_8[4] = PyThread_allocate_lock(); + __pyx_t_8[5] = PyThread_allocate_lock(); + __pyx_t_8[6] = PyThread_allocate_lock(); + __pyx_t_8[7] = PyThread_allocate_lock(); + memcpy(&(__pyx_memoryview_thread_locks[0]), __pyx_t_8, sizeof(__pyx_memoryview_thread_locks[0]) * (8)); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":983 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":984 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L22_try_end; + __pyx_L17_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":985 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L18_exception_handled; + } + __pyx_L18_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L22_try_end:; + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_collections_abc_Sequence); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 989, __pyx_L23_error) + if (__pyx_t_6) { + + /* "View.MemoryView":993 + * + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence.register(array) + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_4 = __Pyx_PyObject_CallOneArg(__pyx_t_7, ((PyObject *)__pyx_memoryviewslice_type)); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":994 + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) # <<<<<<<<<<<<<< + * except: + * pass # ignore failure, it's a minor issue + */ + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_7 = __Pyx_PyObject_CallOneArg(__pyx_t_4, ((PyObject *)__pyx_array_type)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L28_try_end; + __pyx_L23_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":995 + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) + * except: # <<<<<<<<<<<<<< + * pass # ignore failure, it's a minor issue + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L24_exception_handled; + } + __pyx_L24_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L28_try_end:; + } + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_t_7 = PyCFunction_NewEx(&__pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum, NULL, __pyx_n_s_View_MemoryView); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_Enum, __pyx_t_7) < 0) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/data_utils_fast.pyx":7 + * # LICENSE file in the root directory of this source tree. + * + * import numpy as np # <<<<<<<<<<<<<< + * + * cimport cython + */ + __pyx_t_7 = __Pyx_ImportDottedModule(__pyx_n_s_numpy, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_7) < 0) __PYX_ERR(0, 7, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/data_utils_fast.pyx":20 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_vec( # <<<<<<<<<<<<<< + * np.ndarray[int64_t, ndim=1] indices, + * np.ndarray[int64_t, ndim=1] num_tokens_vec, + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_15data_utils_fast_1batch_by_size_vec, 0, __pyx_n_s_batch_by_size_vec, NULL, __pyx_n_s_fairseq_data_data_utils_fast, __pyx_d, ((PyObject *)__pyx_codeobj__23)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 20, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_batch_by_size_vec, __pyx_t_7) < 0) __PYX_ERR(0, 20, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/data_utils_fast.pyx":108 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef list batch_by_size_fn( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_15data_utils_fast_3batch_by_size_fn, 0, __pyx_n_s_batch_by_size_fn, NULL, __pyx_n_s_fairseq_data_data_utils_fast, __pyx_d, ((PyObject *)__pyx_codeobj__25)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 108, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_batch_by_size_fn, __pyx_t_7) < 0) __PYX_ERR(0, 108, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/data_utils_fast.pyx":140 + * + * @cython.cdivision(True) + * cpdef list batch_fixed_shapes_fast( # <<<<<<<<<<<<<< + * np.ndarray[DTYPE_t, ndim=1] indices, + * num_tokens_fn, + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_15data_utils_fast_5batch_fixed_shapes_fast, 0, __pyx_n_s_batch_fixed_shapes_fast, NULL, __pyx_n_s_fairseq_data_data_utils_fast, __pyx_d, ((PyObject *)__pyx_codeobj__27)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 140, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_batch_fixed_shapes_fast, __pyx_t_7) < 0) __PYX_ERR(0, 140, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/data_utils_fast.pyx":1 + * # cython: language_level=3 # <<<<<<<<<<<<<< + * # Copyright (c) Facebook, Inc. and its affiliates. + * # + */ + __pyx_t_7 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_7) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /*--- Wrapped vars code ---*/ + + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_7); + if (__pyx_m) { + if (__pyx_d && stringtab_initialized) { + __Pyx_AddTraceback("init fairseq.data.data_utils_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + } + #if !CYTHON_USE_MODULE_STATE + Py_CLEAR(__pyx_m); + #else + Py_DECREF(__pyx_m); + if (pystate_addmodule_run) { + PyObject *tp, *value, *tb; + PyErr_Fetch(&tp, &value, &tb); + PyState_RemoveModule(&__pyx_moduledef); + PyErr_Restore(tp, value, tb); + } + #endif + } else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ImportError, "init fairseq.data.data_utils_fast"); + } + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + #if CYTHON_PEP489_MULTI_PHASE_INIT + return (__pyx_m != NULL) ? 0 : -1; + #elif PY_MAJOR_VERSION >= 3 + return __pyx_m; + #else + return; + #endif +} +/* #### Code section: cleanup_globals ### */ +/* #### Code section: cleanup_module ### */ +/* #### Code section: main_method ### */ +/* #### Code section: utility_code_pragmas ### */ +#ifdef _MSC_VER +#pragma warning( push ) +/* Warning 4127: conditional expression is constant + * Cython uses constant conditional expressions to allow in inline functions to be optimized at + * compile-time, so this warning is not useful + */ +#pragma warning( disable : 4127 ) +#endif + + + +/* #### Code section: utility_code_def ### */ + +/* --- Runtime support code --- */ +/* Refnanny */ +#if CYTHON_REFNANNY +static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) { + PyObject *m = NULL, *p = NULL; + void *r = NULL; + m = PyImport_ImportModule(modname); + if (!m) goto end; + p = PyObject_GetAttrString(m, "RefNannyAPI"); + if (!p) goto end; + r = PyLong_AsVoidPtr(p); +end: + Py_XDECREF(p); + Py_XDECREF(m); + return (__Pyx_RefNannyAPIStruct *)r; +} +#endif + +/* PyErrExceptionMatches */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; i= 0x030C00A6 + PyObject *current_exception = tstate->current_exception; + if (unlikely(!current_exception)) return 0; + exc_type = (PyObject*) Py_TYPE(current_exception); + if (exc_type == err) return 1; +#else + exc_type = tstate->curexc_type; + if (exc_type == err) return 1; + if (unlikely(!exc_type)) return 0; +#endif + #if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(exc_type); + #endif + if (unlikely(PyTuple_Check(err))) { + result = __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err); + } else { + result = __Pyx_PyErr_GivenExceptionMatches(exc_type, err); + } + #if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(exc_type); + #endif + return result; +} +#endif + +/* PyErrFetchRestore */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject *tmp_value; + assert(type == NULL || (value != NULL && type == (PyObject*) Py_TYPE(value))); + if (value) { + #if CYTHON_COMPILING_IN_CPYTHON + if (unlikely(((PyBaseExceptionObject*) value)->traceback != tb)) + #endif + PyException_SetTraceback(value, tb); + } + tmp_value = tstate->current_exception; + tstate->current_exception = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = type; + tstate->curexc_value = value; + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#endif +} +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject* exc_value; + exc_value = tstate->current_exception; + tstate->current_exception = 0; + *value = exc_value; + *type = NULL; + *tb = NULL; + if (exc_value) { + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + #if CYTHON_COMPILING_IN_CPYTHON + *tb = ((PyBaseExceptionObject*) exc_value)->traceback; + Py_XINCREF(*tb); + #else + *tb = PyException_GetTraceback(exc_value); + #endif + } +#else + *type = tstate->curexc_type; + *value = tstate->curexc_value; + *tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; +#endif +} +#endif + +/* PyObjectGetAttrStr */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) { + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro)) + return tp->tp_getattro(obj, attr_name); +#if PY_MAJOR_VERSION < 3 + if (likely(tp->tp_getattr)) + return tp->tp_getattr(obj, PyString_AS_STRING(attr_name)); +#endif + return PyObject_GetAttr(obj, attr_name); +} +#endif + +/* PyObjectGetAttrStrNoError */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + __Pyx_PyErr_Clear(); +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) { + PyObject *result; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + (void) PyObject_GetOptionalAttr(obj, attr_name, &result); + return result; +#else +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1 + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) { + return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1); + } +#endif + result = __Pyx_PyObject_GetAttrStr(obj, attr_name); + if (unlikely(!result)) { + __Pyx_PyObject_GetAttrStr_ClearAttributeError(); + } + return result; +#endif +} + +/* GetBuiltinName */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name) { + PyObject* result = __Pyx_PyObject_GetAttrStrNoError(__pyx_b, name); + if (unlikely(!result) && !PyErr_Occurred()) { + PyErr_Format(PyExc_NameError, +#if PY_MAJOR_VERSION >= 3 + "name '%U' is not defined", name); +#else + "name '%.200s' is not defined", PyString_AS_STRING(name)); +#endif + } + return result; +} + +/* TupleAndListFromArray */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE void __Pyx_copy_object_array(PyObject *const *CYTHON_RESTRICT src, PyObject** CYTHON_RESTRICT dest, Py_ssize_t length) { + PyObject *v; + Py_ssize_t i; + for (i = 0; i < length; i++) { + v = dest[i] = src[i]; + Py_INCREF(v); + } +} +static CYTHON_INLINE PyObject * +__Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + Py_INCREF(__pyx_empty_tuple); + return __pyx_empty_tuple; + } + res = PyTuple_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyTupleObject*)res)->ob_item, n); + return res; +} +static CYTHON_INLINE PyObject * +__Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + return PyList_New(0); + } + res = PyList_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyListObject*)res)->ob_item, n); + return res; +} +#endif + +/* BytesEquals */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else + if (s1 == s2) { + return (equals == Py_EQ); + } else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) { + const char *ps1, *ps2; + Py_ssize_t length = PyBytes_GET_SIZE(s1); + if (length != PyBytes_GET_SIZE(s2)) + return (equals == Py_NE); + ps1 = PyBytes_AS_STRING(s1); + ps2 = PyBytes_AS_STRING(s2); + if (ps1[0] != ps2[0]) { + return (equals == Py_NE); + } else if (length == 1) { + return (equals == Py_EQ); + } else { + int result; +#if CYTHON_USE_UNICODE_INTERNALS && (PY_VERSION_HEX < 0x030B0000) + Py_hash_t hash1, hash2; + hash1 = ((PyBytesObject*)s1)->ob_shash; + hash2 = ((PyBytesObject*)s2)->ob_shash; + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + return (equals == Py_NE); + } +#endif + result = memcmp(ps1, ps2, (size_t)length); + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) { + return (equals == Py_NE); + } else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) { + return (equals == Py_NE); + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +#endif +} + +/* UnicodeEquals */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else +#if PY_MAJOR_VERSION < 3 + PyObject* owned_ref = NULL; +#endif + int s1_is_unicode, s2_is_unicode; + if (s1 == s2) { + goto return_eq; + } + s1_is_unicode = PyUnicode_CheckExact(s1); + s2_is_unicode = PyUnicode_CheckExact(s2); +#if PY_MAJOR_VERSION < 3 + if ((s1_is_unicode & (!s2_is_unicode)) && PyString_CheckExact(s2)) { + owned_ref = PyUnicode_FromObject(s2); + if (unlikely(!owned_ref)) + return -1; + s2 = owned_ref; + s2_is_unicode = 1; + } else if ((s2_is_unicode & (!s1_is_unicode)) && PyString_CheckExact(s1)) { + owned_ref = PyUnicode_FromObject(s1); + if (unlikely(!owned_ref)) + return -1; + s1 = owned_ref; + s1_is_unicode = 1; + } else if (((!s2_is_unicode) & (!s1_is_unicode))) { + return __Pyx_PyBytes_Equals(s1, s2, equals); + } +#endif + if (s1_is_unicode & s2_is_unicode) { + Py_ssize_t length; + int kind; + void *data1, *data2; + if (unlikely(__Pyx_PyUnicode_READY(s1) < 0) || unlikely(__Pyx_PyUnicode_READY(s2) < 0)) + return -1; + length = __Pyx_PyUnicode_GET_LENGTH(s1); + if (length != __Pyx_PyUnicode_GET_LENGTH(s2)) { + goto return_ne; + } +#if CYTHON_USE_UNICODE_INTERNALS + { + Py_hash_t hash1, hash2; + #if CYTHON_PEP393_ENABLED + hash1 = ((PyASCIIObject*)s1)->hash; + hash2 = ((PyASCIIObject*)s2)->hash; + #else + hash1 = ((PyUnicodeObject*)s1)->hash; + hash2 = ((PyUnicodeObject*)s2)->hash; + #endif + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + goto return_ne; + } + } +#endif + kind = __Pyx_PyUnicode_KIND(s1); + if (kind != __Pyx_PyUnicode_KIND(s2)) { + goto return_ne; + } + data1 = __Pyx_PyUnicode_DATA(s1); + data2 = __Pyx_PyUnicode_DATA(s2); + if (__Pyx_PyUnicode_READ(kind, data1, 0) != __Pyx_PyUnicode_READ(kind, data2, 0)) { + goto return_ne; + } else if (length == 1) { + goto return_eq; + } else { + int result = memcmp(data1, data2, (size_t)(length * kind)); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & s2_is_unicode) { + goto return_ne; + } else if ((s2 == Py_None) & s1_is_unicode) { + goto return_ne; + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +return_eq: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ); +return_ne: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_NE); +#endif +} + +/* fastcall */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s) +{ + Py_ssize_t i, n = PyTuple_GET_SIZE(kwnames); + for (i = 0; i < n; i++) + { + if (s == PyTuple_GET_ITEM(kwnames, i)) return kwvalues[i]; + } + for (i = 0; i < n; i++) + { + int eq = __Pyx_PyUnicode_Equals(s, PyTuple_GET_ITEM(kwnames, i), Py_EQ); + if (unlikely(eq != 0)) { + if (unlikely(eq < 0)) return NULL; + return kwvalues[i]; + } + } + return NULL; +} +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 +CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues) { + Py_ssize_t i, nkwargs = PyTuple_GET_SIZE(kwnames); + PyObject *dict; + dict = PyDict_New(); + if (unlikely(!dict)) + return NULL; + for (i=0; i= 3 + "%s() got multiple values for keyword argument '%U'", func_name, kw_name); + #else + "%s() got multiple values for keyword argument '%s'", func_name, + PyString_AsString(kw_name)); + #endif +} + +/* ParseKeywords */ +static int __Pyx_ParseOptionalKeywords( + PyObject *kwds, + PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, + PyObject *values[], + Py_ssize_t num_pos_args, + const char* function_name) +{ + PyObject *key = 0, *value = 0; + Py_ssize_t pos = 0; + PyObject*** name; + PyObject*** first_kw_arg = argnames + num_pos_args; + int kwds_is_tuple = CYTHON_METH_FASTCALL && likely(PyTuple_Check(kwds)); + while (1) { + Py_XDECREF(key); key = NULL; + Py_XDECREF(value); value = NULL; + if (kwds_is_tuple) { + Py_ssize_t size; +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(kwds); +#else + size = PyTuple_Size(kwds); + if (size < 0) goto bad; +#endif + if (pos >= size) break; +#if CYTHON_AVOID_BORROWED_REFS + key = __Pyx_PySequence_ITEM(kwds, pos); + if (!key) goto bad; +#elif CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kwds, pos); +#else + key = PyTuple_GetItem(kwds, pos); + if (!key) goto bad; +#endif + value = kwvalues[pos]; + pos++; + } + else + { + if (!PyDict_Next(kwds, &pos, &key, &value)) break; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + } + name = first_kw_arg; + while (*name && (**name != key)) name++; + if (*name) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(value); + Py_DECREF(key); +#endif + key = NULL; + value = NULL; + continue; + } +#if !CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + Py_INCREF(value); + name = first_kw_arg; + #if PY_MAJOR_VERSION < 3 + if (likely(PyString_Check(key))) { + while (*name) { + if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key)) + && _PyString_Eq(**name, key)) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + if ((**argname == key) || ( + (CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key)) + && _PyString_Eq(**argname, key))) { + goto arg_passed_twice; + } + argname++; + } + } + } else + #endif + if (likely(PyUnicode_Check(key))) { + while (*name) { + int cmp = ( + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**name, key) + ); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + int cmp = (**argname == key) ? 0 : + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**argname, key); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) goto arg_passed_twice; + argname++; + } + } + } else + goto invalid_keyword_type; + if (kwds2) { + if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad; + } else { + goto invalid_keyword; + } + } + Py_XDECREF(key); + Py_XDECREF(value); + return 0; +arg_passed_twice: + __Pyx_RaiseDoubleKeywordsError(function_name, key); + goto bad; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + goto bad; +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif +bad: + Py_XDECREF(key); + Py_XDECREF(value); + return -1; +} + +/* ArgTypeTest */ +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact) +{ + __Pyx_TypeName type_name; + __Pyx_TypeName obj_type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + else if (exact) { + #if PY_MAJOR_VERSION == 2 + if ((type == &PyBaseString_Type) && likely(__Pyx_PyBaseString_CheckExact(obj))) return 1; + #endif + } + else { + if (likely(__Pyx_TypeCheck(obj, type))) return 1; + } + type_name = __Pyx_PyType_GetName(type); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "Argument '%.200s' has incorrect type (expected " __Pyx_FMT_TYPENAME + ", got " __Pyx_FMT_TYPENAME ")", name, type_name, obj_type_name); + __Pyx_DECREF_TypeName(type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* RaiseException */ +#if PY_MAJOR_VERSION < 3 +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + __Pyx_PyThreadState_declare + CYTHON_UNUSED_VAR(cause); + Py_XINCREF(type); + if (!value || value == Py_None) + value = NULL; + else + Py_INCREF(value); + if (!tb || tb == Py_None) + tb = NULL; + else { + Py_INCREF(tb); + if (!PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto raise_error; + } + } + if (PyType_Check(type)) { +#if CYTHON_COMPILING_IN_PYPY + if (!value) { + Py_INCREF(Py_None); + value = Py_None; + } +#endif + PyErr_NormalizeException(&type, &value, &tb); + } else { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto raise_error; + } + value = type; + type = (PyObject*) Py_TYPE(type); + Py_INCREF(type); + if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto raise_error; + } + } + __Pyx_PyThreadState_assign + __Pyx_ErrRestore(type, value, tb); + return; +raise_error: + Py_XDECREF(value); + Py_XDECREF(type); + Py_XDECREF(tb); + return; +} +#else +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + PyObject* owned_instance = NULL; + if (tb == Py_None) { + tb = 0; + } else if (tb && !PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto bad; + } + if (value == Py_None) + value = 0; + if (PyExceptionInstance_Check(type)) { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto bad; + } + value = type; + type = (PyObject*) Py_TYPE(value); + } else if (PyExceptionClass_Check(type)) { + PyObject *instance_class = NULL; + if (value && PyExceptionInstance_Check(value)) { + instance_class = (PyObject*) Py_TYPE(value); + if (instance_class != type) { + int is_subclass = PyObject_IsSubclass(instance_class, type); + if (!is_subclass) { + instance_class = NULL; + } else if (unlikely(is_subclass == -1)) { + goto bad; + } else { + type = instance_class; + } + } + } + if (!instance_class) { + PyObject *args; + if (!value) + args = PyTuple_New(0); + else if (PyTuple_Check(value)) { + Py_INCREF(value); + args = value; + } else + args = PyTuple_Pack(1, value); + if (!args) + goto bad; + owned_instance = PyObject_Call(type, args, NULL); + Py_DECREF(args); + if (!owned_instance) + goto bad; + value = owned_instance; + if (!PyExceptionInstance_Check(value)) { + PyErr_Format(PyExc_TypeError, + "calling %R should have returned an instance of " + "BaseException, not %R", + type, Py_TYPE(value)); + goto bad; + } + } + } else { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto bad; + } + if (cause) { + PyObject *fixed_cause; + if (cause == Py_None) { + fixed_cause = NULL; + } else if (PyExceptionClass_Check(cause)) { + fixed_cause = PyObject_CallObject(cause, NULL); + if (fixed_cause == NULL) + goto bad; + } else if (PyExceptionInstance_Check(cause)) { + fixed_cause = cause; + Py_INCREF(fixed_cause); + } else { + PyErr_SetString(PyExc_TypeError, + "exception causes must derive from " + "BaseException"); + goto bad; + } + PyException_SetCause(value, fixed_cause); + } + PyErr_SetObject(type, value); + if (tb) { + #if PY_VERSION_HEX >= 0x030C00A6 + PyException_SetTraceback(value, tb); + #elif CYTHON_FAST_THREAD_STATE + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject* tmp_tb = tstate->curexc_traceback; + if (tb != tmp_tb) { + Py_INCREF(tb); + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_tb); + } +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb); + Py_INCREF(tb); + PyErr_Restore(tmp_type, tmp_value, tb); + Py_XDECREF(tmp_tb); +#endif + } +bad: + Py_XDECREF(owned_instance); + return; +} +#endif + +/* PyFunctionFastCall */ +#if CYTHON_FAST_PYCALL && !CYTHON_VECTORCALL +static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na, + PyObject *globals) { + PyFrameObject *f; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject **fastlocals; + Py_ssize_t i; + PyObject *result; + assert(globals != NULL); + /* XXX Perhaps we should create a specialized + PyFrame_New() that doesn't take locals, but does + take builtins without sanity checking them. + */ + assert(tstate != NULL); + f = PyFrame_New(tstate, co, globals, NULL); + if (f == NULL) { + return NULL; + } + fastlocals = __Pyx_PyFrame_GetLocalsplus(f); + for (i = 0; i < na; i++) { + Py_INCREF(*args); + fastlocals[i] = *args++; + } + result = PyEval_EvalFrameEx(f,0); + ++tstate->recursion_depth; + Py_DECREF(f); + --tstate->recursion_depth; + return result; +} +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) { + PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func); + PyObject *globals = PyFunction_GET_GLOBALS(func); + PyObject *argdefs = PyFunction_GET_DEFAULTS(func); + PyObject *closure; +#if PY_MAJOR_VERSION >= 3 + PyObject *kwdefs; +#endif + PyObject *kwtuple, **k; + PyObject **d; + Py_ssize_t nd; + Py_ssize_t nk; + PyObject *result; + assert(kwargs == NULL || PyDict_Check(kwargs)); + nk = kwargs ? PyDict_Size(kwargs) : 0; + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) { + return NULL; + } + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) { + return NULL; + } + #endif + if ( +#if PY_MAJOR_VERSION >= 3 + co->co_kwonlyargcount == 0 && +#endif + likely(kwargs == NULL || nk == 0) && + co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) { + if (argdefs == NULL && co->co_argcount == nargs) { + result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals); + goto done; + } + else if (nargs == 0 && argdefs != NULL + && co->co_argcount == Py_SIZE(argdefs)) { + /* function called with no arguments, but all parameters have + a default value: use default values as arguments .*/ + args = &PyTuple_GET_ITEM(argdefs, 0); + result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals); + goto done; + } + } + if (kwargs != NULL) { + Py_ssize_t pos, i; + kwtuple = PyTuple_New(2 * nk); + if (kwtuple == NULL) { + result = NULL; + goto done; + } + k = &PyTuple_GET_ITEM(kwtuple, 0); + pos = i = 0; + while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) { + Py_INCREF(k[i]); + Py_INCREF(k[i+1]); + i += 2; + } + nk = i / 2; + } + else { + kwtuple = NULL; + k = NULL; + } + closure = PyFunction_GET_CLOSURE(func); +#if PY_MAJOR_VERSION >= 3 + kwdefs = PyFunction_GET_KW_DEFAULTS(func); +#endif + if (argdefs != NULL) { + d = &PyTuple_GET_ITEM(argdefs, 0); + nd = Py_SIZE(argdefs); + } + else { + d = NULL; + nd = 0; + } +#if PY_MAJOR_VERSION >= 3 + result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, kwdefs, closure); +#else + result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, closure); +#endif + Py_XDECREF(kwtuple); +done: + Py_LeaveRecursiveCall(); + return result; +} +#endif + +/* PyObjectCall */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *result; + ternaryfunc call = Py_TYPE(func)->tp_call; + if (unlikely(!call)) + return PyObject_Call(func, arg, kw); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = (*call)(func, arg, kw); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectCallMethO */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) { + PyObject *self, *result; + PyCFunction cfunc; + cfunc = __Pyx_CyOrPyCFunction_GET_FUNCTION(func); + self = __Pyx_CyOrPyCFunction_GET_SELF(func); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = cfunc(self, arg); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectFastCall */ +#if PY_VERSION_HEX < 0x03090000 || CYTHON_COMPILING_IN_LIMITED_API +static PyObject* __Pyx_PyObject_FastCall_fallback(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs) { + PyObject *argstuple; + PyObject *result = 0; + size_t i; + argstuple = PyTuple_New((Py_ssize_t)nargs); + if (unlikely(!argstuple)) return NULL; + for (i = 0; i < nargs; i++) { + Py_INCREF(args[i]); + if (__Pyx_PyTuple_SET_ITEM(argstuple, (Py_ssize_t)i, args[i]) < 0) goto bad; + } + result = __Pyx_PyObject_Call(func, argstuple, kwargs); + bad: + Py_DECREF(argstuple); + return result; +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t _nargs, PyObject *kwargs) { + Py_ssize_t nargs = __Pyx_PyVectorcall_NARGS(_nargs); +#if CYTHON_COMPILING_IN_CPYTHON + if (nargs == 0 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_NOARGS)) + return __Pyx_PyObject_CallMethO(func, NULL); + } + else if (nargs == 1 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_O)) + return __Pyx_PyObject_CallMethO(func, args[0]); + } +#endif + #if PY_VERSION_HEX < 0x030800B1 + #if CYTHON_FAST_PYCCALL + if (PyCFunction_Check(func)) { + if (kwargs) { + return _PyCFunction_FastCallDict(func, args, nargs, kwargs); + } else { + return _PyCFunction_FastCallKeywords(func, args, nargs, NULL); + } + } + #if PY_VERSION_HEX >= 0x030700A1 + if (!kwargs && __Pyx_IS_TYPE(func, &PyMethodDescr_Type)) { + return _PyMethodDescr_FastCallKeywords(func, args, nargs, NULL); + } + #endif + #endif + #if CYTHON_FAST_PYCALL + if (PyFunction_Check(func)) { + return __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs); + } + #endif + #endif + if (kwargs == NULL) { + #if CYTHON_VECTORCALL + #if PY_VERSION_HEX < 0x03090000 + vectorcallfunc f = _PyVectorcall_Function(func); + #else + vectorcallfunc f = PyVectorcall_Function(func); + #endif + if (f) { + return f(func, args, (size_t)nargs, NULL); + } + #elif defined(__Pyx_CyFunction_USED) && CYTHON_BACKPORT_VECTORCALL + if (__Pyx_CyFunction_CheckExact(func)) { + __pyx_vectorcallfunc f = __Pyx_CyFunction_func_vectorcall(func); + if (f) return f(func, args, (size_t)nargs, NULL); + } + #endif + } + if (nargs == 0) { + return __Pyx_PyObject_Call(func, __pyx_empty_tuple, kwargs); + } + #if PY_VERSION_HEX >= 0x03090000 && !CYTHON_COMPILING_IN_LIMITED_API + return PyObject_VectorcallDict(func, args, (size_t)nargs, kwargs); + #else + return __Pyx_PyObject_FastCall_fallback(func, args, (size_t)nargs, kwargs); + #endif +} + +/* RaiseUnexpectedTypeError */ +static int +__Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj) +{ + __Pyx_TypeName obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, "Expected %s, got " __Pyx_FMT_TYPENAME, + expected, obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* CIntToDigits */ +static const char DIGIT_PAIRS_10[2*10*10+1] = { + "00010203040506070809" + "10111213141516171819" + "20212223242526272829" + "30313233343536373839" + "40414243444546474849" + "50515253545556575859" + "60616263646566676869" + "70717273747576777879" + "80818283848586878889" + "90919293949596979899" +}; +static const char DIGIT_PAIRS_8[2*8*8+1] = { + "0001020304050607" + "1011121314151617" + "2021222324252627" + "3031323334353637" + "4041424344454647" + "5051525354555657" + "6061626364656667" + "7071727374757677" +}; +static const char DIGITS_HEX[2*16+1] = { + "0123456789abcdef" + "0123456789ABCDEF" +}; + +/* BuildPyUnicode */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char) { + PyObject *uval; + Py_ssize_t uoffset = ulength - clength; +#if CYTHON_USE_UNICODE_INTERNALS + Py_ssize_t i; +#if CYTHON_PEP393_ENABLED + void *udata; + uval = PyUnicode_New(ulength, 127); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_DATA(uval); +#else + Py_UNICODE *udata; + uval = PyUnicode_FromUnicode(NULL, ulength); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_AS_UNICODE(uval); +#endif + if (uoffset > 0) { + i = 0; + if (prepend_sign) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, 0, '-'); + i++; + } + for (; i < uoffset; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, i, padding_char); + } + } + for (i=0; i < clength; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, uoffset+i, chars[i]); + } +#else + { + PyObject *sign = NULL, *padding = NULL; + uval = NULL; + if (uoffset > 0) { + prepend_sign = !!prepend_sign; + if (uoffset > prepend_sign) { + padding = PyUnicode_FromOrdinal(padding_char); + if (likely(padding) && uoffset > prepend_sign + 1) { + PyObject *tmp; + PyObject *repeat = PyInt_FromSsize_t(uoffset - prepend_sign); + if (unlikely(!repeat)) goto done_or_error; + tmp = PyNumber_Multiply(padding, repeat); + Py_DECREF(repeat); + Py_DECREF(padding); + padding = tmp; + } + if (unlikely(!padding)) goto done_or_error; + } + if (prepend_sign) { + sign = PyUnicode_FromOrdinal('-'); + if (unlikely(!sign)) goto done_or_error; + } + } + uval = PyUnicode_DecodeASCII(chars, clength, NULL); + if (likely(uval) && padding) { + PyObject *tmp = PyNumber_Add(padding, uval); + Py_DECREF(uval); + uval = tmp; + } + if (likely(uval) && sign) { + PyObject *tmp = PyNumber_Add(sign, uval); + Py_DECREF(uval); + uval = tmp; + } +done_or_error: + Py_XDECREF(padding); + Py_XDECREF(sign); + } +#endif + return uval; +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(int)*3+2]; + char *dpos, *end = digits + sizeof(int)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + int remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (int) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (int) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (int) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(Py_ssize_t)*3+2]; + char *dpos, *end = digits + sizeof(Py_ssize_t)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + Py_ssize_t remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const Py_ssize_t neg_one = (Py_ssize_t) -1, const_zero = (Py_ssize_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (Py_ssize_t) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (Py_ssize_t) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (Py_ssize_t) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* JoinPyUnicode */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char) { +#if CYTHON_USE_UNICODE_INTERNALS && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + PyObject *result_uval; + int result_ukind, kind_shift; + Py_ssize_t i, char_pos; + void *result_udata; + CYTHON_MAYBE_UNUSED_VAR(max_char); +#if CYTHON_PEP393_ENABLED + result_uval = PyUnicode_New(result_ulength, max_char); + if (unlikely(!result_uval)) return NULL; + result_ukind = (max_char <= 255) ? PyUnicode_1BYTE_KIND : (max_char <= 65535) ? PyUnicode_2BYTE_KIND : PyUnicode_4BYTE_KIND; + kind_shift = (result_ukind == PyUnicode_4BYTE_KIND) ? 2 : result_ukind - 1; + result_udata = PyUnicode_DATA(result_uval); +#else + result_uval = PyUnicode_FromUnicode(NULL, result_ulength); + if (unlikely(!result_uval)) return NULL; + result_ukind = sizeof(Py_UNICODE); + kind_shift = (result_ukind == 4) ? 2 : result_ukind - 1; + result_udata = PyUnicode_AS_UNICODE(result_uval); +#endif + assert(kind_shift == 2 || kind_shift == 1 || kind_shift == 0); + char_pos = 0; + for (i=0; i < value_count; i++) { + int ukind; + Py_ssize_t ulength; + void *udata; + PyObject *uval = PyTuple_GET_ITEM(value_tuple, i); + if (unlikely(__Pyx_PyUnicode_READY(uval))) + goto bad; + ulength = __Pyx_PyUnicode_GET_LENGTH(uval); + if (unlikely(!ulength)) + continue; + if (unlikely((PY_SSIZE_T_MAX >> kind_shift) - ulength < char_pos)) + goto overflow; + ukind = __Pyx_PyUnicode_KIND(uval); + udata = __Pyx_PyUnicode_DATA(uval); + if (!CYTHON_PEP393_ENABLED || ukind == result_ukind) { + memcpy((char *)result_udata + (char_pos << kind_shift), udata, (size_t) (ulength << kind_shift)); + } else { + #if PY_VERSION_HEX >= 0x030d0000 + if (unlikely(PyUnicode_CopyCharacters(result_uval, char_pos, uval, 0, ulength) < 0)) goto bad; + #elif CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030300F0 || defined(_PyUnicode_FastCopyCharacters) + _PyUnicode_FastCopyCharacters(result_uval, char_pos, uval, 0, ulength); + #else + Py_ssize_t j; + for (j=0; j < ulength; j++) { + Py_UCS4 uchar = __Pyx_PyUnicode_READ(ukind, udata, j); + __Pyx_PyUnicode_WRITE(result_ukind, result_udata, char_pos+j, uchar); + } + #endif + } + char_pos += ulength; + } + return result_uval; +overflow: + PyErr_SetString(PyExc_OverflowError, "join() result is too long for a Python string"); +bad: + Py_DECREF(result_uval); + return NULL; +#else + CYTHON_UNUSED_VAR(max_char); + CYTHON_UNUSED_VAR(result_ulength); + CYTHON_UNUSED_VAR(value_count); + return PyUnicode_Join(__pyx_empty_unicode, value_tuple); +#endif +} + +/* GetAttr */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *o, PyObject *n) { +#if CYTHON_USE_TYPE_SLOTS +#if PY_MAJOR_VERSION >= 3 + if (likely(PyUnicode_Check(n))) +#else + if (likely(PyString_Check(n))) +#endif + return __Pyx_PyObject_GetAttrStr(o, n); +#endif + return PyObject_GetAttr(o, n); +} + +/* GetItemInt */ +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) { + PyObject *r; + if (unlikely(!j)) return NULL; + r = PyObject_GetItem(o, j); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyList_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) { + PyObject *r = PyList_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyTuple_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o); + if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) { + PyObject *r = PyList_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } + else if (PyTuple_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyTuple_GET_SIZE(o); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_subscript) { + PyObject *r, *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return NULL; + r = mm->mp_subscript(o, key); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return NULL; + PyErr_Clear(); + } + } + return sm->sq_item(o, i); + } + } +#else + if (is_list || !PyMapping_Check(o)) { + return PySequence_GetItem(o, i); + } +#endif + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +} + +/* PyObjectCallOneArg */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { + PyObject *args[2] = {NULL, arg}; + return __Pyx_PyObject_FastCall(func, args+1, 1 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* ObjectGetItem */ +#if CYTHON_USE_TYPE_SLOTS +static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) { + PyObject *runerr = NULL; + Py_ssize_t key_value; + key_value = __Pyx_PyIndex_AsSsize_t(index); + if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) { + return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1); + } + if (PyErr_GivenExceptionMatches(runerr, PyExc_OverflowError)) { + __Pyx_TypeName index_type_name = __Pyx_PyType_GetName(Py_TYPE(index)); + PyErr_Clear(); + PyErr_Format(PyExc_IndexError, + "cannot fit '" __Pyx_FMT_TYPENAME "' into an index-sized integer", index_type_name); + __Pyx_DECREF_TypeName(index_type_name); + } + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) { + __Pyx_TypeName obj_type_name; + if (likely(PyType_Check(obj))) { + PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, __pyx_n_s_class_getitem); + if (!meth) { + PyErr_Clear(); + } else { + PyObject *result = __Pyx_PyObject_CallOneArg(meth, key); + Py_DECREF(meth); + return result; + } + } + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' object is not subscriptable", obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) { + PyTypeObject *tp = Py_TYPE(obj); + PyMappingMethods *mm = tp->tp_as_mapping; + PySequenceMethods *sm = tp->tp_as_sequence; + if (likely(mm && mm->mp_subscript)) { + return mm->mp_subscript(obj, key); + } + if (likely(sm && sm->sq_item)) { + return __Pyx_PyObject_GetIndex(obj, key); + } + return __Pyx_PyObject_GetItem_Slow(obj, key); +} +#endif + +/* KeywordStringCheck */ +static int __Pyx_CheckKeywordStrings( + PyObject *kw, + const char* function_name, + int kw_allowed) +{ + PyObject* key = 0; + Py_ssize_t pos = 0; +#if CYTHON_COMPILING_IN_PYPY + if (!kw_allowed && PyDict_Next(kw, &pos, &key, 0)) + goto invalid_keyword; + return 1; +#else + if (CYTHON_METH_FASTCALL && likely(PyTuple_Check(kw))) { + Py_ssize_t kwsize; +#if CYTHON_ASSUME_SAFE_MACROS + kwsize = PyTuple_GET_SIZE(kw); +#else + kwsize = PyTuple_Size(kw); + if (kwsize < 0) return 0; +#endif + if (unlikely(kwsize == 0)) + return 1; + if (!kw_allowed) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, 0); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + goto invalid_keyword; + } +#if PY_VERSION_HEX < 0x03090000 + for (pos = 0; pos < kwsize; pos++) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, pos); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } +#endif + return 1; + } + while (PyDict_Next(kw, &pos, &key, 0)) { + #if PY_MAJOR_VERSION < 3 + if (unlikely(!PyString_Check(key))) + #endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } + if (!kw_allowed && unlikely(key)) + goto invalid_keyword; + return 1; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + return 0; +#endif +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif + return 0; +} + +/* DivInt[Py_ssize_t] */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t a, Py_ssize_t b) { + Py_ssize_t q = a / b; + Py_ssize_t r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* GetAttr3 */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static PyObject *__Pyx_GetAttr3Default(PyObject *d) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + return NULL; + __Pyx_PyErr_Clear(); + Py_INCREF(d); + return d; +} +#endif +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *o, PyObject *n, PyObject *d) { + PyObject *r; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + int res = PyObject_GetOptionalAttr(o, n, &r); + return (res != 0) ? r : __Pyx_NewRef(d); +#else + #if CYTHON_USE_TYPE_SLOTS + if (likely(PyString_Check(n))) { + r = __Pyx_PyObject_GetAttrStrNoError(o, n); + if (unlikely(!r) && likely(!PyErr_Occurred())) { + r = __Pyx_NewRef(d); + } + return r; + } + #endif + r = PyObject_GetAttr(o, n); + return (likely(r)) ? r : __Pyx_GetAttr3Default(d); +#endif +} + +/* PyDictVersioning */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0; +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) { + PyObject **dictptr = NULL; + Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset; + if (offset) { +#if CYTHON_COMPILING_IN_CPYTHON + dictptr = (likely(offset > 0)) ? (PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj); +#else + dictptr = _PyObject_GetDictPtr(obj); +#endif + } + return (dictptr && *dictptr) ? __PYX_GET_DICT_VERSION(*dictptr) : 0; +} +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict))) + return 0; + return obj_dict_version == __Pyx_get_object_dict_version(obj); +} +#endif + +/* GetModuleGlobalName */ +#if CYTHON_USE_DICT_VERSIONS +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value) +#else +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name) +#endif +{ + PyObject *result; +#if !CYTHON_AVOID_BORROWED_REFS +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && PY_VERSION_HEX < 0x030d0000 + result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } else if (unlikely(PyErr_Occurred())) { + return NULL; + } +#elif CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(!__pyx_m)) { + return NULL; + } + result = PyObject_GetAttr(__pyx_m, name); + if (likely(result)) { + return result; + } +#else + result = PyDict_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } +#endif +#else + result = PyObject_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } + PyErr_Clear(); +#endif + return __Pyx_GetBuiltinName(name); +} + +/* RaiseTooManyValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected) { + PyErr_Format(PyExc_ValueError, + "too many values to unpack (expected %" CYTHON_FORMAT_SSIZE_T "d)", expected); +} + +/* RaiseNeedMoreValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index) { + PyErr_Format(PyExc_ValueError, + "need more than %" CYTHON_FORMAT_SSIZE_T "d value%.1s to unpack", + index, (index == 1) ? "" : "s"); +} + +/* RaiseNoneIterError */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); +} + +/* ExtTypeTest */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { + __Pyx_TypeName obj_type_name; + __Pyx_TypeName type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + if (likely(__Pyx_TypeCheck(obj, type))) + return 1; + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + type_name = __Pyx_PyType_GetName(type); + PyErr_Format(PyExc_TypeError, + "Cannot convert " __Pyx_FMT_TYPENAME " to " __Pyx_FMT_TYPENAME, + obj_type_name, type_name); + __Pyx_DECREF_TypeName(obj_type_name); + __Pyx_DECREF_TypeName(type_name); + return 0; +} + +/* GetTopmostException */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * +__Pyx_PyErr_GetTopmostException(PyThreadState *tstate) +{ + _PyErr_StackItem *exc_info = tstate->exc_info; + while ((exc_info->exc_value == NULL || exc_info->exc_value == Py_None) && + exc_info->previous_item != NULL) + { + exc_info = exc_info->previous_item; + } + return exc_info; +} +#endif + +/* SaveResetException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + PyObject *exc_value = exc_info->exc_value; + if (exc_value == NULL || exc_value == Py_None) { + *value = NULL; + *type = NULL; + *tb = NULL; + } else { + *value = exc_value; + Py_INCREF(*value); + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + *tb = PyException_GetTraceback(exc_value); + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + *type = exc_info->exc_type; + *value = exc_info->exc_value; + *tb = exc_info->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #else + *type = tstate->exc_type; + *value = tstate->exc_value; + *tb = tstate->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #endif +} +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + PyObject *tmp_value = exc_info->exc_value; + exc_info->exc_value = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); + #else + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = type; + exc_info->exc_value = value; + exc_info->exc_traceback = tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = type; + tstate->exc_value = value; + tstate->exc_traceback = tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); + #endif +} +#endif + +/* GetException */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) +#endif +{ + PyObject *local_type = NULL, *local_value, *local_tb = NULL; +#if CYTHON_FAST_THREAD_STATE + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if PY_VERSION_HEX >= 0x030C00A6 + local_value = tstate->current_exception; + tstate->current_exception = 0; + if (likely(local_value)) { + local_type = (PyObject*) Py_TYPE(local_value); + Py_INCREF(local_type); + local_tb = PyException_GetTraceback(local_value); + } + #else + local_type = tstate->curexc_type; + local_value = tstate->curexc_value; + local_tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; + #endif +#else + PyErr_Fetch(&local_type, &local_value, &local_tb); +#endif + PyErr_NormalizeException(&local_type, &local_value, &local_tb); +#if CYTHON_FAST_THREAD_STATE && PY_VERSION_HEX >= 0x030C00A6 + if (unlikely(tstate->current_exception)) +#elif CYTHON_FAST_THREAD_STATE + if (unlikely(tstate->curexc_type)) +#else + if (unlikely(PyErr_Occurred())) +#endif + goto bad; + #if PY_MAJOR_VERSION >= 3 + if (local_tb) { + if (unlikely(PyException_SetTraceback(local_value, local_tb) < 0)) + goto bad; + } + #endif + Py_XINCREF(local_tb); + Py_XINCREF(local_type); + Py_XINCREF(local_value); + *type = local_type; + *value = local_value; + *tb = local_tb; +#if CYTHON_FAST_THREAD_STATE + #if CYTHON_USE_EXC_INFO_STACK + { + _PyErr_StackItem *exc_info = tstate->exc_info; + #if PY_VERSION_HEX >= 0x030B00a4 + tmp_value = exc_info->exc_value; + exc_info->exc_value = local_value; + tmp_type = NULL; + tmp_tb = NULL; + Py_XDECREF(local_type); + Py_XDECREF(local_tb); + #else + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = local_type; + exc_info->exc_value = local_value; + exc_info->exc_traceback = local_tb; + #endif + } + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = local_type; + tstate->exc_value = local_value; + tstate->exc_traceback = local_tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#else + PyErr_SetExcInfo(local_type, local_value, local_tb); +#endif + return 0; +bad: + *type = 0; + *value = 0; + *tb = 0; + Py_XDECREF(local_type); + Py_XDECREF(local_value); + Py_XDECREF(local_tb); + return -1; +} + +/* SwapException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_value = exc_info->exc_value; + exc_info->exc_value = *value; + if (tmp_value == NULL || tmp_value == Py_None) { + Py_XDECREF(tmp_value); + tmp_value = NULL; + tmp_type = NULL; + tmp_tb = NULL; + } else { + tmp_type = (PyObject*) Py_TYPE(tmp_value); + Py_INCREF(tmp_type); + #if CYTHON_COMPILING_IN_CPYTHON + tmp_tb = ((PyBaseExceptionObject*) tmp_value)->traceback; + Py_XINCREF(tmp_tb); + #else + tmp_tb = PyException_GetTraceback(tmp_value); + #endif + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = *type; + exc_info->exc_value = *value; + exc_info->exc_traceback = *tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = *type; + tstate->exc_value = *value; + tstate->exc_traceback = *tb; + #endif + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_GetExcInfo(&tmp_type, &tmp_value, &tmp_tb); + PyErr_SetExcInfo(*type, *value, *tb); + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#endif + +/* Import */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) { + PyObject *module = 0; + PyObject *empty_dict = 0; + PyObject *empty_list = 0; + #if PY_MAJOR_VERSION < 3 + PyObject *py_import; + py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import); + if (unlikely(!py_import)) + goto bad; + if (!from_list) { + empty_list = PyList_New(0); + if (unlikely(!empty_list)) + goto bad; + from_list = empty_list; + } + #endif + empty_dict = PyDict_New(); + if (unlikely(!empty_dict)) + goto bad; + { + #if PY_MAJOR_VERSION >= 3 + if (level == -1) { + if (strchr(__Pyx_MODULE_NAME, '.') != NULL) { + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, 1); + if (unlikely(!module)) { + if (unlikely(!PyErr_ExceptionMatches(PyExc_ImportError))) + goto bad; + PyErr_Clear(); + } + } + level = 0; + } + #endif + if (!module) { + #if PY_MAJOR_VERSION < 3 + PyObject *py_level = PyInt_FromLong(level); + if (unlikely(!py_level)) + goto bad; + module = PyObject_CallFunctionObjArgs(py_import, + name, __pyx_d, empty_dict, from_list, py_level, (PyObject *)NULL); + Py_DECREF(py_level); + #else + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, level); + #endif + } + } +bad: + Py_XDECREF(empty_dict); + Py_XDECREF(empty_list); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_import); + #endif + return module; +} + +/* ImportDottedModule */ +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Error(PyObject *name, PyObject *parts_tuple, Py_ssize_t count) { + PyObject *partial_name = NULL, *slice = NULL, *sep = NULL; + if (unlikely(PyErr_Occurred())) { + PyErr_Clear(); + } + if (likely(PyTuple_GET_SIZE(parts_tuple) == count)) { + partial_name = name; + } else { + slice = PySequence_GetSlice(parts_tuple, 0, count); + if (unlikely(!slice)) + goto bad; + sep = PyUnicode_FromStringAndSize(".", 1); + if (unlikely(!sep)) + goto bad; + partial_name = PyUnicode_Join(sep, slice); + } + PyErr_Format( +#if PY_MAJOR_VERSION < 3 + PyExc_ImportError, + "No module named '%s'", PyString_AS_STRING(partial_name)); +#else +#if PY_VERSION_HEX >= 0x030600B1 + PyExc_ModuleNotFoundError, +#else + PyExc_ImportError, +#endif + "No module named '%U'", partial_name); +#endif +bad: + Py_XDECREF(sep); + Py_XDECREF(slice); + Py_XDECREF(partial_name); + return NULL; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Lookup(PyObject *name) { + PyObject *imported_module; +#if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + return NULL; + imported_module = __Pyx_PyDict_GetItemStr(modules, name); + Py_XINCREF(imported_module); +#else + imported_module = PyImport_GetModule(name); +#endif + return imported_module; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple) { + Py_ssize_t i, nparts; + nparts = PyTuple_GET_SIZE(parts_tuple); + for (i=1; i < nparts && module; i++) { + PyObject *part, *submodule; +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + part = PyTuple_GET_ITEM(parts_tuple, i); +#else + part = PySequence_ITEM(parts_tuple, i); +#endif + submodule = __Pyx_PyObject_GetAttrStrNoError(module, part); +#if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(part); +#endif + Py_DECREF(module); + module = submodule; + } + if (unlikely(!module)) { + return __Pyx__ImportDottedModule_Error(name, parts_tuple, i); + } + return module; +} +#endif +static PyObject *__Pyx__ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if PY_MAJOR_VERSION < 3 + PyObject *module, *from_list, *star = __pyx_n_s__3; + CYTHON_UNUSED_VAR(parts_tuple); + from_list = PyList_New(1); + if (unlikely(!from_list)) + return NULL; + Py_INCREF(star); + PyList_SET_ITEM(from_list, 0, star); + module = __Pyx_Import(name, from_list, 0); + Py_DECREF(from_list); + return module; +#else + PyObject *imported_module; + PyObject *module = __Pyx_Import(name, NULL, 0); + if (!parts_tuple || unlikely(!module)) + return module; + imported_module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(imported_module)) { + Py_DECREF(module); + return imported_module; + } + PyErr_Clear(); + return __Pyx_ImportDottedModule_WalkParts(module, name, parts_tuple); +#endif +} +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030400B1 + PyObject *module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(module)) { + PyObject *spec = __Pyx_PyObject_GetAttrStrNoError(module, __pyx_n_s_spec); + if (likely(spec)) { + PyObject *unsafe = __Pyx_PyObject_GetAttrStrNoError(spec, __pyx_n_s_initializing); + if (likely(!unsafe || !__Pyx_PyObject_IsTrue(unsafe))) { + Py_DECREF(spec); + spec = NULL; + } + Py_XDECREF(unsafe); + } + if (likely(!spec)) { + PyErr_Clear(); + return module; + } + Py_DECREF(spec); + Py_DECREF(module); + } else if (PyErr_Occurred()) { + PyErr_Clear(); + } +#endif + return __Pyx__ImportDottedModule(name, parts_tuple); +} + +/* FastTypeChecks */ +#if CYTHON_COMPILING_IN_CPYTHON +static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) { + while (a) { + a = __Pyx_PyType_GetSlot(a, tp_base, PyTypeObject*); + if (a == b) + return 1; + } + return b == &PyBaseObject_Type; +} +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (a == b) return 1; + mro = a->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(a, b); +} +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (cls == a || cls == b) return 1; + mro = cls->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + PyObject *base = PyTuple_GET_ITEM(mro, i); + if (base == (PyObject *)a || base == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(cls, a) || __Pyx_InBases(cls, b); +} +#if PY_MAJOR_VERSION == 2 +static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) { + PyObject *exception, *value, *tb; + int res; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&exception, &value, &tb); + res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0; + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + if (!res) { + res = PyObject_IsSubclass(err, exc_type2); + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + } + __Pyx_ErrRestore(exception, value, tb); + return res; +} +#else +static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) { + if (exc_type1) { + return __Pyx_IsAnySubtype2((PyTypeObject*)err, (PyTypeObject*)exc_type1, (PyTypeObject*)exc_type2); + } else { + return __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2); + } +} +#endif +static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + assert(PyExceptionClass_Check(exc_type)); + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; itp_as_sequence && type->tp_as_sequence->sq_repeat)) { + return type->tp_as_sequence->sq_repeat(seq, mul); + } else +#endif + { + return __Pyx_PySequence_Multiply_Generic(seq, mul); + } +} + +/* SetItemInt */ +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v) { + int r; + if (unlikely(!j)) return -1; + r = PyObject_SetItem(o, j, v); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, int is_list, + CYTHON_NCP_UNUSED int wraparound, CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = (!wraparound) ? i : ((likely(i >= 0)) ? i : i + PyList_GET_SIZE(o)); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o)))) { + PyObject* old = PyList_GET_ITEM(o, n); + Py_INCREF(v); + PyList_SET_ITEM(o, n, v); + Py_DECREF(old); + return 1; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_ass_subscript) { + int r; + PyObject *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return -1; + r = mm->mp_ass_subscript(o, key, v); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_ass_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return -1; + PyErr_Clear(); + } + } + return sm->sq_ass_item(o, i, v); + } + } +#else + if (is_list || !PyMapping_Check(o)) + { + return PySequence_SetItem(o, i, v); + } +#endif + return __Pyx_SetItemInt_Generic(o, PyInt_FromSsize_t(i), v); +} + +/* RaiseUnboundLocalError */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname) { + PyErr_Format(PyExc_UnboundLocalError, "local variable '%s' referenced before assignment", varname); +} + +/* DivInt[long] */ +static CYTHON_INLINE long __Pyx_div_long(long a, long b) { + long q = a / b; + long r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* ImportFrom */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name) { + PyObject* value = __Pyx_PyObject_GetAttrStr(module, name); + if (unlikely(!value) && PyErr_ExceptionMatches(PyExc_AttributeError)) { + const char* module_name_str = 0; + PyObject* module_name = 0; + PyObject* module_dot = 0; + PyObject* full_name = 0; + PyErr_Clear(); + module_name_str = PyModule_GetName(module); + if (unlikely(!module_name_str)) { goto modbad; } + module_name = PyUnicode_FromString(module_name_str); + if (unlikely(!module_name)) { goto modbad; } + module_dot = PyUnicode_Concat(module_name, __pyx_kp_u__2); + if (unlikely(!module_dot)) { goto modbad; } + full_name = PyUnicode_Concat(module_dot, name); + if (unlikely(!full_name)) { goto modbad; } + #if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + { + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + goto modbad; + value = PyObject_GetItem(modules, full_name); + } + #else + value = PyImport_GetModule(full_name); + #endif + modbad: + Py_XDECREF(full_name); + Py_XDECREF(module_dot); + Py_XDECREF(module_name); + } + if (unlikely(!value)) { + PyErr_Format(PyExc_ImportError, + #if PY_MAJOR_VERSION < 3 + "cannot import name %.230s", PyString_AS_STRING(name)); + #else + "cannot import name %S", name); + #endif + } + return value; +} + +/* HasAttr */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { + PyObject *r; + if (unlikely(!__Pyx_PyBaseString_Check(n))) { + PyErr_SetString(PyExc_TypeError, + "hasattr(): attribute name must be string"); + return -1; + } + r = __Pyx_GetAttr(o, n); + if (!r) { + PyErr_Clear(); + return 0; + } else { + Py_DECREF(r); + return 1; + } +} +#endif + +/* IsLittleEndian */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void) +{ + union { + uint32_t u32; + uint8_t u8[4]; + } S; + S.u32 = 0x01020304; + return S.u8[0] == 4; +} + +/* BufferFormatCheck */ +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type) { + stack[0].field = &ctx->root; + stack[0].parent_offset = 0; + ctx->root.type = type; + ctx->root.name = "buffer dtype"; + ctx->root.offset = 0; + ctx->head = stack; + ctx->head->field = &ctx->root; + ctx->fmt_offset = 0; + ctx->head->parent_offset = 0; + ctx->new_packmode = '@'; + ctx->enc_packmode = '@'; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->is_complex = 0; + ctx->is_valid_array = 0; + ctx->struct_alignment = 0; + while (type->typegroup == 'S') { + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = 0; + type = type->fields->type; + } +} +static int __Pyx_BufFmt_ParseNumber(const char** ts) { + int count; + const char* t = *ts; + if (*t < '0' || *t > '9') { + return -1; + } else { + count = *t++ - '0'; + while (*t >= '0' && *t <= '9') { + count *= 10; + count += *t++ - '0'; + } + } + *ts = t; + return count; +} +static int __Pyx_BufFmt_ExpectNumber(const char **ts) { + int number = __Pyx_BufFmt_ParseNumber(ts); + if (number == -1) + PyErr_Format(PyExc_ValueError,\ + "Does not understand character buffer dtype format string ('%c')", **ts); + return number; +} +static void __Pyx_BufFmt_RaiseUnexpectedChar(char ch) { + PyErr_Format(PyExc_ValueError, + "Unexpected format string character: '%c'", ch); +} +static const char* __Pyx_BufFmt_DescribeTypeChar(char ch, int is_complex) { + switch (ch) { + case '?': return "'bool'"; + case 'c': return "'char'"; + case 'b': return "'signed char'"; + case 'B': return "'unsigned char'"; + case 'h': return "'short'"; + case 'H': return "'unsigned short'"; + case 'i': return "'int'"; + case 'I': return "'unsigned int'"; + case 'l': return "'long'"; + case 'L': return "'unsigned long'"; + case 'q': return "'long long'"; + case 'Q': return "'unsigned long long'"; + case 'f': return (is_complex ? "'complex float'" : "'float'"); + case 'd': return (is_complex ? "'complex double'" : "'double'"); + case 'g': return (is_complex ? "'complex long double'" : "'long double'"); + case 'T': return "a struct"; + case 'O': return "Python object"; + case 'P': return "a pointer"; + case 's': case 'p': return "a string"; + case 0: return "end"; + default: return "unparsable format string"; + } +} +static size_t __Pyx_BufFmt_TypeCharToStandardSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return 2; + case 'i': case 'I': case 'l': case 'L': return 4; + case 'q': case 'Q': return 8; + case 'f': return (is_complex ? 8 : 4); + case 'd': return (is_complex ? 16 : 8); + case 'g': { + PyErr_SetString(PyExc_ValueError, "Python does not define a standard format string size for long double ('g').."); + return 0; + } + case 'O': case 'P': return sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static size_t __Pyx_BufFmt_TypeCharToNativeSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(short); + case 'i': case 'I': return sizeof(int); + case 'l': case 'L': return sizeof(long); + #ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(PY_LONG_LONG); + #endif + case 'f': return sizeof(float) * (is_complex ? 2 : 1); + case 'd': return sizeof(double) * (is_complex ? 2 : 1); + case 'g': return sizeof(long double) * (is_complex ? 2 : 1); + case 'O': case 'P': return sizeof(void*); + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +typedef struct { char c; short x; } __Pyx_st_short; +typedef struct { char c; int x; } __Pyx_st_int; +typedef struct { char c; long x; } __Pyx_st_long; +typedef struct { char c; float x; } __Pyx_st_float; +typedef struct { char c; double x; } __Pyx_st_double; +typedef struct { char c; long double x; } __Pyx_st_longdouble; +typedef struct { char c; void *x; } __Pyx_st_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { char c; PY_LONG_LONG x; } __Pyx_st_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToAlignment(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_st_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_st_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_st_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_st_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_st_float) - sizeof(float); + case 'd': return sizeof(__Pyx_st_double) - sizeof(double); + case 'g': return sizeof(__Pyx_st_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_st_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +/* These are for computing the padding at the end of the struct to align + on the first member of the struct. This will probably the same as above, + but we don't have any guarantees. + */ +typedef struct { short x; char c; } __Pyx_pad_short; +typedef struct { int x; char c; } __Pyx_pad_int; +typedef struct { long x; char c; } __Pyx_pad_long; +typedef struct { float x; char c; } __Pyx_pad_float; +typedef struct { double x; char c; } __Pyx_pad_double; +typedef struct { long double x; char c; } __Pyx_pad_longdouble; +typedef struct { void *x; char c; } __Pyx_pad_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { PY_LONG_LONG x; char c; } __Pyx_pad_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToPadding(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_pad_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_pad_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_pad_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_pad_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_pad_float) - sizeof(float); + case 'd': return sizeof(__Pyx_pad_double) - sizeof(double); + case 'g': return sizeof(__Pyx_pad_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_pad_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static char __Pyx_BufFmt_TypeCharToGroup(char ch, int is_complex) { + switch (ch) { + case 'c': + return 'H'; + case 'b': case 'h': case 'i': + case 'l': case 'q': case 's': case 'p': + return 'I'; + case '?': case 'B': case 'H': case 'I': case 'L': case 'Q': + return 'U'; + case 'f': case 'd': case 'g': + return (is_complex ? 'C' : 'R'); + case 'O': + return 'O'; + case 'P': + return 'P'; + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +static void __Pyx_BufFmt_RaiseExpected(__Pyx_BufFmt_Context* ctx) { + if (ctx->head == NULL || ctx->head->field == &ctx->root) { + const char* expected; + const char* quote; + if (ctx->head == NULL) { + expected = "end"; + quote = ""; + } else { + expected = ctx->head->field->type->name; + quote = "'"; + } + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected %s%s%s but got %s", + quote, expected, quote, + __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex)); + } else { + __Pyx_StructField* field = ctx->head->field; + __Pyx_StructField* parent = (ctx->head - 1)->field; + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected '%s' but got %s in '%s.%s'", + field->type->name, __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex), + parent->type->name, field->name); + } +} +static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) { + char group; + size_t size, offset, arraysize = 1; + if (ctx->enc_type == 0) return 0; + if (ctx->head->field->type->arraysize[0]) { + int i, ndim = 0; + if (ctx->enc_type == 's' || ctx->enc_type == 'p') { + ctx->is_valid_array = ctx->head->field->type->ndim == 1; + ndim = 1; + if (ctx->enc_count != ctx->head->field->type->arraysize[0]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %zu", + ctx->head->field->type->arraysize[0], ctx->enc_count); + return -1; + } + } + if (!ctx->is_valid_array) { + PyErr_Format(PyExc_ValueError, "Expected %d dimensions, got %d", + ctx->head->field->type->ndim, ndim); + return -1; + } + for (i = 0; i < ctx->head->field->type->ndim; i++) { + arraysize *= ctx->head->field->type->arraysize[i]; + } + ctx->is_valid_array = 0; + ctx->enc_count = 1; + } + group = __Pyx_BufFmt_TypeCharToGroup(ctx->enc_type, ctx->is_complex); + do { + __Pyx_StructField* field = ctx->head->field; + __Pyx_TypeInfo* type = field->type; + if (ctx->enc_packmode == '@' || ctx->enc_packmode == '^') { + size = __Pyx_BufFmt_TypeCharToNativeSize(ctx->enc_type, ctx->is_complex); + } else { + size = __Pyx_BufFmt_TypeCharToStandardSize(ctx->enc_type, ctx->is_complex); + } + if (ctx->enc_packmode == '@') { + size_t align_at = __Pyx_BufFmt_TypeCharToAlignment(ctx->enc_type, ctx->is_complex); + size_t align_mod_offset; + if (align_at == 0) return -1; + align_mod_offset = ctx->fmt_offset % align_at; + if (align_mod_offset > 0) ctx->fmt_offset += align_at - align_mod_offset; + if (ctx->struct_alignment == 0) + ctx->struct_alignment = __Pyx_BufFmt_TypeCharToPadding(ctx->enc_type, + ctx->is_complex); + } + if (type->size != size || type->typegroup != group) { + if (type->typegroup == 'C' && type->fields != NULL) { + size_t parent_offset = ctx->head->parent_offset + field->offset; + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = parent_offset; + continue; + } + if ((type->typegroup == 'H' || group == 'H') && type->size == size) { + } else { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + } + offset = ctx->head->parent_offset + field->offset; + if (ctx->fmt_offset != offset) { + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch; next field is at offset %" CYTHON_FORMAT_SSIZE_T "d but %" CYTHON_FORMAT_SSIZE_T "d expected", + (Py_ssize_t)ctx->fmt_offset, (Py_ssize_t)offset); + return -1; + } + ctx->fmt_offset += size; + if (arraysize) + ctx->fmt_offset += (arraysize - 1) * size; + --ctx->enc_count; + while (1) { + if (field == &ctx->root) { + ctx->head = NULL; + if (ctx->enc_count != 0) { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + break; + } + ctx->head->field = ++field; + if (field->type == NULL) { + --ctx->head; + field = ctx->head->field; + continue; + } else if (field->type->typegroup == 'S') { + size_t parent_offset = ctx->head->parent_offset + field->offset; + if (field->type->fields->type == NULL) continue; + field = field->type->fields; + ++ctx->head; + ctx->head->field = field; + ctx->head->parent_offset = parent_offset; + break; + } else { + break; + } + } + } while (ctx->enc_count); + ctx->enc_type = 0; + ctx->is_complex = 0; + return 0; +} +static int +__pyx_buffmt_parse_array(__Pyx_BufFmt_Context* ctx, const char** tsp) +{ + const char *ts = *tsp; + int i = 0, number, ndim; + ++ts; + if (ctx->new_count != 1) { + PyErr_SetString(PyExc_ValueError, + "Cannot handle repeated arrays in format string"); + return -1; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return -1; + ndim = ctx->head->field->type->ndim; + while (*ts && *ts != ')') { + switch (*ts) { + case ' ': case '\f': case '\r': case '\n': case '\t': case '\v': continue; + default: break; + } + number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return -1; + if (i < ndim && (size_t) number != ctx->head->field->type->arraysize[i]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %d", + ctx->head->field->type->arraysize[i], number); + return -1; + } + if (*ts != ',' && *ts != ')') { + PyErr_Format(PyExc_ValueError, + "Expected a comma in format string, got '%c'", *ts); + return -1; + } + if (*ts == ',') ts++; + i++; + } + if (i != ndim) { + PyErr_Format(PyExc_ValueError, "Expected %d dimension(s), got %d", + ctx->head->field->type->ndim, i); + return -1; + } + if (!*ts) { + PyErr_SetString(PyExc_ValueError, + "Unexpected end of format string, expected ')'"); + return -1; + } + ctx->is_valid_array = 1; + ctx->new_count = 1; + *tsp = ++ts; + return 0; +} +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts) { + int got_Z = 0; + while (1) { + switch(*ts) { + case 0: + if (ctx->enc_type != 0 && ctx->head == NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + if (ctx->head != NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + return ts; + case ' ': + case '\r': + case '\n': + ++ts; + break; + case '<': + if (!__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Little-endian buffer not supported on big-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '>': + case '!': + if (__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Big-endian buffer not supported on little-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '=': + case '@': + case '^': + ctx->new_packmode = *ts++; + break; + case 'T': + { + const char* ts_after_sub; + size_t i, struct_count = ctx->new_count; + size_t struct_alignment = ctx->struct_alignment; + ctx->new_count = 1; + ++ts; + if (*ts != '{') { + PyErr_SetString(PyExc_ValueError, "Buffer acquisition: Expected '{' after 'T'"); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + ctx->enc_count = 0; + ctx->struct_alignment = 0; + ++ts; + ts_after_sub = ts; + for (i = 0; i != struct_count; ++i) { + ts_after_sub = __Pyx_BufFmt_CheckString(ctx, ts); + if (!ts_after_sub) return NULL; + } + ts = ts_after_sub; + if (struct_alignment) ctx->struct_alignment = struct_alignment; + } + break; + case '}': + { + size_t alignment = ctx->struct_alignment; + ++ts; + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + if (alignment && ctx->fmt_offset % alignment) { + ctx->fmt_offset += alignment - (ctx->fmt_offset % alignment); + } + } + return ts; + case 'x': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->fmt_offset += ctx->new_count; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->enc_packmode = ctx->new_packmode; + ++ts; + break; + case 'Z': + got_Z = 1; + ++ts; + if (*ts != 'f' && *ts != 'd' && *ts != 'g') { + __Pyx_BufFmt_RaiseUnexpectedChar('Z'); + return NULL; + } + CYTHON_FALLTHROUGH; + case '?': case 'c': case 'b': case 'B': case 'h': case 'H': case 'i': case 'I': + case 'l': case 'L': case 'q': case 'Q': + case 'f': case 'd': case 'g': + case 'O': case 'p': + if ((ctx->enc_type == *ts) && (got_Z == ctx->is_complex) && + (ctx->enc_packmode == ctx->new_packmode) && (!ctx->is_valid_array)) { + ctx->enc_count += ctx->new_count; + ctx->new_count = 1; + got_Z = 0; + ++ts; + break; + } + CYTHON_FALLTHROUGH; + case 's': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_count = ctx->new_count; + ctx->enc_packmode = ctx->new_packmode; + ctx->enc_type = *ts; + ctx->is_complex = got_Z; + ++ts; + ctx->new_count = 1; + got_Z = 0; + break; + case ':': + ++ts; + while(*ts != ':') ++ts; + ++ts; + break; + case '(': + if (__pyx_buffmt_parse_array(ctx, &ts) < 0) return NULL; + break; + default: + { + int number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return NULL; + ctx->new_count = (size_t)number; + } + } + } +} + +/* BufferGetAndValidate */ + static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) { + if (unlikely(info->buf == NULL)) return; + if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL; + __Pyx_ReleaseBuffer(info); +} +static void __Pyx_ZeroBuffer(Py_buffer* buf) { + buf->buf = NULL; + buf->obj = NULL; + buf->strides = __Pyx_zeros; + buf->shape = __Pyx_zeros; + buf->suboffsets = __Pyx_minusones; +} +static int __Pyx__GetBufferAndValidate( + Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, + int nd, int cast, __Pyx_BufFmt_StackElem* stack) +{ + buf->buf = NULL; + if (unlikely(__Pyx_GetBuffer(obj, buf, flags) == -1)) { + __Pyx_ZeroBuffer(buf); + return -1; + } + if (unlikely(buf->ndim != nd)) { + PyErr_Format(PyExc_ValueError, + "Buffer has wrong number of dimensions (expected %d, got %d)", + nd, buf->ndim); + goto fail; + } + if (!cast) { + __Pyx_BufFmt_Context ctx; + __Pyx_BufFmt_Init(&ctx, stack, dtype); + if (!__Pyx_BufFmt_CheckString(&ctx, buf->format)) goto fail; + } + if (unlikely((size_t)buf->itemsize != dtype->size)) { + PyErr_Format(PyExc_ValueError, + "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "d byte%s) does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "d byte%s)", + buf->itemsize, (buf->itemsize > 1) ? "s" : "", + dtype->name, (Py_ssize_t)dtype->size, (dtype->size > 1) ? "s" : ""); + goto fail; + } + if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; + return 0; +fail:; + __Pyx_SafeReleaseBuffer(buf); + return -1; +} + +/* BufferIndexError */ + static void __Pyx_RaiseBufferIndexError(int axis) { + PyErr_Format(PyExc_IndexError, + "Out of bounds on buffer access (axis %d)", axis); +} + +/* PyIntCompare */ + static CYTHON_INLINE int __Pyx_PyInt_BoolEqObjC(PyObject *op1, PyObject *op2, long intval, long inplace) { + CYTHON_MAYBE_UNUSED_VAR(intval); + CYTHON_UNUSED_VAR(inplace); + if (op1 == op2) { + return 1; + } + #if PY_MAJOR_VERSION < 3 + if (likely(PyInt_CheckExact(op1))) { + const long b = intval; + long a = PyInt_AS_LONG(op1); + return (a == b); + } + #endif + #if CYTHON_USE_PYLONG_INTERNALS + if (likely(PyLong_CheckExact(op1))) { + int unequal; + unsigned long uintval; + Py_ssize_t size = __Pyx_PyLong_DigitCount(op1); + const digit* digits = __Pyx_PyLong_Digits(op1); + if (intval == 0) { + return (__Pyx_PyLong_IsZero(op1) == 1); + } else if (intval < 0) { + if (__Pyx_PyLong_IsNonNeg(op1)) + return 0; + intval = -intval; + } else { + if (__Pyx_PyLong_IsNeg(op1)) + return 0; + } + uintval = (unsigned long) intval; +#if PyLong_SHIFT * 4 < SIZEOF_LONG*8 + if (uintval >> (PyLong_SHIFT * 4)) { + unequal = (size != 5) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) + | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[4] != ((uintval >> (4 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); + } else +#endif +#if PyLong_SHIFT * 3 < SIZEOF_LONG*8 + if (uintval >> (PyLong_SHIFT * 3)) { + unequal = (size != 4) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) + | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[3] != ((uintval >> (3 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); + } else +#endif +#if PyLong_SHIFT * 2 < SIZEOF_LONG*8 + if (uintval >> (PyLong_SHIFT * 2)) { + unequal = (size != 3) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) + | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)) | (digits[2] != ((uintval >> (2 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); + } else +#endif +#if PyLong_SHIFT * 1 < SIZEOF_LONG*8 + if (uintval >> (PyLong_SHIFT * 1)) { + unequal = (size != 2) || (digits[0] != (uintval & (unsigned long) PyLong_MASK)) + | (digits[1] != ((uintval >> (1 * PyLong_SHIFT)) & (unsigned long) PyLong_MASK)); + } else +#endif + unequal = (size != 1) || (((unsigned long) digits[0]) != (uintval & (unsigned long) PyLong_MASK)); + return (unequal == 0); + } + #endif + if (PyFloat_CheckExact(op1)) { + const long b = intval; +#if CYTHON_COMPILING_IN_LIMITED_API + double a = __pyx_PyFloat_AsDouble(op1); +#else + double a = PyFloat_AS_DOUBLE(op1); +#endif + return ((double)a == (double)b); + } + return __Pyx_PyObject_IsTrueAndDecref( + PyObject_RichCompare(op1, op2, Py_EQ)); +} + +/* PyObject_GenericGetAttrNoDict */ + #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject *__Pyx_RaiseGenericGetAttributeError(PyTypeObject *tp, PyObject *attr_name) { + __Pyx_TypeName type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, attr_name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(attr_name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name) { + PyObject *descr; + PyTypeObject *tp = Py_TYPE(obj); + if (unlikely(!PyString_Check(attr_name))) { + return PyObject_GenericGetAttr(obj, attr_name); + } + assert(!tp->tp_dictoffset); + descr = _PyType_Lookup(tp, attr_name); + if (unlikely(!descr)) { + return __Pyx_RaiseGenericGetAttributeError(tp, attr_name); + } + Py_INCREF(descr); + #if PY_MAJOR_VERSION < 3 + if (likely(PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_HAVE_CLASS))) + #endif + { + descrgetfunc f = Py_TYPE(descr)->tp_descr_get; + if (unlikely(f)) { + PyObject *res = f(descr, obj, (PyObject *)tp); + Py_DECREF(descr); + return res; + } + } + return descr; +} +#endif + +/* PyObject_GenericGetAttr */ + #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name) { + if (unlikely(Py_TYPE(obj)->tp_dictoffset)) { + return PyObject_GenericGetAttr(obj, attr_name); + } + return __Pyx_PyObject_GenericGetAttrNoDict(obj, attr_name); +} +#endif + +/* FixUpExtensionType */ + #if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type) { +#if PY_VERSION_HEX > 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + CYTHON_UNUSED_VAR(spec); + CYTHON_UNUSED_VAR(type); +#else + const PyType_Slot *slot = spec->slots; + while (slot && slot->slot && slot->slot != Py_tp_members) + slot++; + if (slot && slot->slot == Py_tp_members) { + int changed = 0; +#if !(PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON) + const +#endif + PyMemberDef *memb = (PyMemberDef*) slot->pfunc; + while (memb && memb->name) { + if (memb->name[0] == '_' && memb->name[1] == '_') { +#if PY_VERSION_HEX < 0x030900b1 + if (strcmp(memb->name, "__weaklistoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_weaklistoffset = memb->offset; + changed = 1; + } + else if (strcmp(memb->name, "__dictoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_dictoffset = memb->offset; + changed = 1; + } +#if CYTHON_METH_FASTCALL + else if (strcmp(memb->name, "__vectorcalloffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); +#if PY_VERSION_HEX >= 0x030800b4 + type->tp_vectorcall_offset = memb->offset; +#else + type->tp_print = (printfunc) memb->offset; +#endif + changed = 1; + } +#endif +#else + if ((0)); +#endif +#if PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON + else if (strcmp(memb->name, "__module__") == 0) { + PyObject *descr; + assert(memb->type == T_OBJECT); + assert(memb->flags == 0 || memb->flags == READONLY); + descr = PyDescr_NewMember(type, memb); + if (unlikely(!descr)) + return -1; + if (unlikely(PyDict_SetItem(type->tp_dict, PyDescr_NAME(descr), descr) < 0)) { + Py_DECREF(descr); + return -1; + } + Py_DECREF(descr); + changed = 1; + } +#endif + } + memb++; + } + if (changed) + PyType_Modified(type); + } +#endif + return 0; +} +#endif + +/* PyObjectCallNoArg */ + static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { + PyObject *arg[2] = {NULL, NULL}; + return __Pyx_PyObject_FastCall(func, arg + 1, 0 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* PyObjectGetMethod */ + static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method) { + PyObject *attr; +#if CYTHON_UNPACK_METHODS && CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_PYTYPE_LOOKUP + __Pyx_TypeName type_name; + PyTypeObject *tp = Py_TYPE(obj); + PyObject *descr; + descrgetfunc f = NULL; + PyObject **dictptr, *dict; + int meth_found = 0; + assert (*method == NULL); + if (unlikely(tp->tp_getattro != PyObject_GenericGetAttr)) { + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; + } + if (unlikely(tp->tp_dict == NULL) && unlikely(PyType_Ready(tp) < 0)) { + return 0; + } + descr = _PyType_Lookup(tp, name); + if (likely(descr != NULL)) { + Py_INCREF(descr); +#if defined(Py_TPFLAGS_METHOD_DESCRIPTOR) && Py_TPFLAGS_METHOD_DESCRIPTOR + if (__Pyx_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR)) +#elif PY_MAJOR_VERSION >= 3 + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type))) + #endif +#else + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr))) + #endif +#endif + { + meth_found = 1; + } else { + f = Py_TYPE(descr)->tp_descr_get; + if (f != NULL && PyDescr_IsData(descr)) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + } + } + dictptr = _PyObject_GetDictPtr(obj); + if (dictptr != NULL && (dict = *dictptr) != NULL) { + Py_INCREF(dict); + attr = __Pyx_PyDict_GetItemStr(dict, name); + if (attr != NULL) { + Py_INCREF(attr); + Py_DECREF(dict); + Py_XDECREF(descr); + goto try_unpack; + } + Py_DECREF(dict); + } + if (meth_found) { + *method = descr; + return 1; + } + if (f != NULL) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + if (likely(descr != NULL)) { + *method = descr; + return 0; + } + type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return 0; +#else + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; +#endif +try_unpack: +#if CYTHON_UNPACK_METHODS + if (likely(attr) && PyMethod_Check(attr) && likely(PyMethod_GET_SELF(attr) == obj)) { + PyObject *function = PyMethod_GET_FUNCTION(attr); + Py_INCREF(function); + Py_DECREF(attr); + *method = function; + return 1; + } +#endif + *method = attr; + return 0; +} + +/* PyObjectCallMethod0 */ + static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name) { + PyObject *method = NULL, *result = NULL; + int is_method = __Pyx_PyObject_GetMethod(obj, method_name, &method); + if (likely(is_method)) { + result = __Pyx_PyObject_CallOneArg(method, obj); + Py_DECREF(method); + return result; + } + if (unlikely(!method)) goto bad; + result = __Pyx_PyObject_CallNoArg(method); + Py_DECREF(method); +bad: + return result; +} + +/* ValidateBasesTuple */ + #if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases) { + Py_ssize_t i, n; +#if CYTHON_ASSUME_SAFE_MACROS + n = PyTuple_GET_SIZE(bases); +#else + n = PyTuple_Size(bases); + if (n < 0) return -1; +#endif + for (i = 1; i < n; i++) + { +#if CYTHON_AVOID_BORROWED_REFS + PyObject *b0 = PySequence_GetItem(bases, i); + if (!b0) return -1; +#elif CYTHON_ASSUME_SAFE_MACROS + PyObject *b0 = PyTuple_GET_ITEM(bases, i); +#else + PyObject *b0 = PyTuple_GetItem(bases, i); + if (!b0) return -1; +#endif + PyTypeObject *b; +#if PY_MAJOR_VERSION < 3 + if (PyClass_Check(b0)) + { + PyErr_Format(PyExc_TypeError, "base class '%.200s' is an old-style class", + PyString_AS_STRING(((PyClassObject*)b0)->cl_name)); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } +#endif + b = (PyTypeObject*) b0; + if (!__Pyx_PyType_HasFeature(b, Py_TPFLAGS_HEAPTYPE)) + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "base class '" __Pyx_FMT_TYPENAME "' is not a heap type", b_name); + __Pyx_DECREF_TypeName(b_name); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + if (dictoffset == 0) + { + Py_ssize_t b_dictoffset = 0; +#if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + b_dictoffset = b->tp_dictoffset; +#else + PyObject *py_b_dictoffset = PyObject_GetAttrString((PyObject*)b, "__dictoffset__"); + if (!py_b_dictoffset) goto dictoffset_return; + b_dictoffset = PyLong_AsSsize_t(py_b_dictoffset); + Py_DECREF(py_b_dictoffset); + if (b_dictoffset == -1 && PyErr_Occurred()) goto dictoffset_return; +#endif + if (b_dictoffset) { + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "extension type '%.200s' has no __dict__ slot, " + "but base type '" __Pyx_FMT_TYPENAME "' has: " + "either add 'cdef dict __dict__' to the extension type " + "or add '__slots__ = [...]' to the base type", + type_name, b_name); + __Pyx_DECREF_TypeName(b_name); + } +#if !(CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY) + dictoffset_return: +#endif +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + } +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + } + return 0; +} +#endif + +/* PyType_Ready */ + static int __Pyx_PyType_Ready(PyTypeObject *t) { +#if CYTHON_USE_TYPE_SPECS || !(CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API) || defined(PYSTON_MAJOR_VERSION) + (void)__Pyx_PyObject_CallMethod0; +#if CYTHON_USE_TYPE_SPECS + (void)__Pyx_validate_bases_tuple; +#endif + return PyType_Ready(t); +#else + int r; + PyObject *bases = __Pyx_PyType_GetSlot(t, tp_bases, PyObject*); + if (bases && unlikely(__Pyx_validate_bases_tuple(t->tp_name, t->tp_dictoffset, bases) == -1)) + return -1; +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + { + int gc_was_enabled; + #if PY_VERSION_HEX >= 0x030A00b1 + gc_was_enabled = PyGC_Disable(); + (void)__Pyx_PyObject_CallMethod0; + #else + PyObject *ret, *py_status; + PyObject *gc = NULL; + #if PY_VERSION_HEX >= 0x030700a1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM+0 >= 0x07030400) + gc = PyImport_GetModule(__pyx_kp_u_gc); + #endif + if (unlikely(!gc)) gc = PyImport_Import(__pyx_kp_u_gc); + if (unlikely(!gc)) return -1; + py_status = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_isenabled); + if (unlikely(!py_status)) { + Py_DECREF(gc); + return -1; + } + gc_was_enabled = __Pyx_PyObject_IsTrue(py_status); + Py_DECREF(py_status); + if (gc_was_enabled > 0) { + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_disable); + if (unlikely(!ret)) { + Py_DECREF(gc); + return -1; + } + Py_DECREF(ret); + } else if (unlikely(gc_was_enabled == -1)) { + Py_DECREF(gc); + return -1; + } + #endif + t->tp_flags |= Py_TPFLAGS_HEAPTYPE; +#if PY_VERSION_HEX >= 0x030A0000 + t->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE; +#endif +#else + (void)__Pyx_PyObject_CallMethod0; +#endif + r = PyType_Ready(t); +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + t->tp_flags &= ~Py_TPFLAGS_HEAPTYPE; + #if PY_VERSION_HEX >= 0x030A00b1 + if (gc_was_enabled) + PyGC_Enable(); + #else + if (gc_was_enabled) { + PyObject *tp, *v, *tb; + PyErr_Fetch(&tp, &v, &tb); + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_enable); + if (likely(ret || r == -1)) { + Py_XDECREF(ret); + PyErr_Restore(tp, v, tb); + } else { + Py_XDECREF(tp); + Py_XDECREF(v); + Py_XDECREF(tb); + r = -1; + } + } + Py_DECREF(gc); + #endif + } +#endif + return r; +#endif +} + +/* SetVTable */ + static int __Pyx_SetVtable(PyTypeObject *type, void *vtable) { + PyObject *ob = PyCapsule_New(vtable, 0, 0); + if (unlikely(!ob)) + goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(PyObject_SetAttr((PyObject *) type, __pyx_n_s_pyx_vtable, ob) < 0)) +#else + if (unlikely(PyDict_SetItem(type->tp_dict, __pyx_n_s_pyx_vtable, ob) < 0)) +#endif + goto bad; + Py_DECREF(ob); + return 0; +bad: + Py_XDECREF(ob); + return -1; +} + +/* GetVTable */ + static void* __Pyx_GetVtable(PyTypeObject *type) { + void* ptr; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *ob = PyObject_GetAttr((PyObject *)type, __pyx_n_s_pyx_vtable); +#else + PyObject *ob = PyObject_GetItem(type->tp_dict, __pyx_n_s_pyx_vtable); +#endif + if (!ob) + goto bad; + ptr = PyCapsule_GetPointer(ob, 0); + if (!ptr && !PyErr_Occurred()) + PyErr_SetString(PyExc_RuntimeError, "invalid vtable found for imported type"); + Py_DECREF(ob); + return ptr; +bad: + Py_XDECREF(ob); + return NULL; +} + +/* MergeVTables */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type) { + int i; + void** base_vtables; + __Pyx_TypeName tp_base_name; + __Pyx_TypeName base_name; + void* unknown = (void*)-1; + PyObject* bases = type->tp_bases; + int base_depth = 0; + { + PyTypeObject* base = type->tp_base; + while (base) { + base_depth += 1; + base = base->tp_base; + } + } + base_vtables = (void**) malloc(sizeof(void*) * (size_t)(base_depth + 1)); + base_vtables[0] = unknown; + for (i = 1; i < PyTuple_GET_SIZE(bases); i++) { + void* base_vtable = __Pyx_GetVtable(((PyTypeObject*)PyTuple_GET_ITEM(bases, i))); + if (base_vtable != NULL) { + int j; + PyTypeObject* base = type->tp_base; + for (j = 0; j < base_depth; j++) { + if (base_vtables[j] == unknown) { + base_vtables[j] = __Pyx_GetVtable(base); + base_vtables[j + 1] = unknown; + } + if (base_vtables[j] == base_vtable) { + break; + } else if (base_vtables[j] == NULL) { + goto bad; + } + base = base->tp_base; + } + } + } + PyErr_Clear(); + free(base_vtables); + return 0; +bad: + tp_base_name = __Pyx_PyType_GetName(type->tp_base); + base_name = __Pyx_PyType_GetName((PyTypeObject*)PyTuple_GET_ITEM(bases, i)); + PyErr_Format(PyExc_TypeError, + "multiple bases have vtable conflict: '" __Pyx_FMT_TYPENAME "' and '" __Pyx_FMT_TYPENAME "'", tp_base_name, base_name); + __Pyx_DECREF_TypeName(tp_base_name); + __Pyx_DECREF_TypeName(base_name); + free(base_vtables); + return -1; +} +#endif + +/* SetupReduce */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce_is_named(PyObject* meth, PyObject* name) { + int ret; + PyObject *name_attr; + name_attr = __Pyx_PyObject_GetAttrStrNoError(meth, __pyx_n_s_name_2); + if (likely(name_attr)) { + ret = PyObject_RichCompareBool(name_attr, name, Py_EQ); + } else { + ret = -1; + } + if (unlikely(ret < 0)) { + PyErr_Clear(); + ret = 0; + } + Py_XDECREF(name_attr); + return ret; +} +static int __Pyx_setup_reduce(PyObject* type_obj) { + int ret = 0; + PyObject *object_reduce = NULL; + PyObject *object_getstate = NULL; + PyObject *object_reduce_ex = NULL; + PyObject *reduce = NULL; + PyObject *reduce_ex = NULL; + PyObject *reduce_cython = NULL; + PyObject *setstate = NULL; + PyObject *setstate_cython = NULL; + PyObject *getstate = NULL; +#if CYTHON_USE_PYTYPE_LOOKUP + getstate = _PyType_Lookup((PyTypeObject*)type_obj, __pyx_n_s_getstate); +#else + getstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_getstate); + if (!getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (getstate) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_getstate = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_getstate); +#else + object_getstate = __Pyx_PyObject_GetAttrStrNoError((PyObject*)&PyBaseObject_Type, __pyx_n_s_getstate); + if (!object_getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (object_getstate != getstate) { + goto __PYX_GOOD; + } + } +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce_ex = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#else + object_reduce_ex = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#endif + reduce_ex = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce_ex); if (unlikely(!reduce_ex)) goto __PYX_BAD; + if (reduce_ex == object_reduce_ex) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#else + object_reduce = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#endif + reduce = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce); if (unlikely(!reduce)) goto __PYX_BAD; + if (reduce == object_reduce || __Pyx_setup_reduce_is_named(reduce, __pyx_n_s_reduce_cython)) { + reduce_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_reduce_cython); + if (likely(reduce_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce, reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (reduce == object_reduce || PyErr_Occurred()) { + goto __PYX_BAD; + } + setstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate); + if (!setstate) PyErr_Clear(); + if (!setstate || __Pyx_setup_reduce_is_named(setstate, __pyx_n_s_setstate_cython)) { + setstate_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate_cython); + if (likely(setstate_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate, setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (!setstate || PyErr_Occurred()) { + goto __PYX_BAD; + } + } + PyType_Modified((PyTypeObject*)type_obj); + } + } + goto __PYX_GOOD; +__PYX_BAD: + if (!PyErr_Occurred()) { + __Pyx_TypeName type_obj_name = + __Pyx_PyType_GetName((PyTypeObject*)type_obj); + PyErr_Format(PyExc_RuntimeError, + "Unable to initialize pickling for " __Pyx_FMT_TYPENAME, type_obj_name); + __Pyx_DECREF_TypeName(type_obj_name); + } + ret = -1; +__PYX_GOOD: +#if !CYTHON_USE_PYTYPE_LOOKUP + Py_XDECREF(object_reduce); + Py_XDECREF(object_reduce_ex); + Py_XDECREF(object_getstate); + Py_XDECREF(getstate); +#endif + Py_XDECREF(reduce); + Py_XDECREF(reduce_ex); + Py_XDECREF(reduce_cython); + Py_XDECREF(setstate); + Py_XDECREF(setstate_cython); + return ret; +} +#endif + +/* TypeImport */ + #ifndef __PYX_HAVE_RT_ImportType_3_0_8 +#define __PYX_HAVE_RT_ImportType_3_0_8 +static PyTypeObject *__Pyx_ImportType_3_0_8(PyObject *module, const char *module_name, const char *class_name, + size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_8 check_size) +{ + PyObject *result = 0; + char warning[200]; + Py_ssize_t basicsize; + Py_ssize_t itemsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + PyObject *py_itemsize; +#endif + result = PyObject_GetAttrString(module, class_name); + if (!result) + goto bad; + if (!PyType_Check(result)) { + PyErr_Format(PyExc_TypeError, + "%.200s.%.200s is not a type object", + module_name, class_name); + goto bad; + } +#if !CYTHON_COMPILING_IN_LIMITED_API + basicsize = ((PyTypeObject *)result)->tp_basicsize; + itemsize = ((PyTypeObject *)result)->tp_itemsize; +#else + py_basicsize = PyObject_GetAttrString(result, "__basicsize__"); + if (!py_basicsize) + goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (basicsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; + py_itemsize = PyObject_GetAttrString(result, "__itemsize__"); + if (!py_itemsize) + goto bad; + itemsize = PyLong_AsSsize_t(py_itemsize); + Py_DECREF(py_itemsize); + py_itemsize = 0; + if (itemsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; +#endif + if (itemsize) { + if (size % alignment) { + alignment = size % alignment; + } + if (itemsize < (Py_ssize_t)alignment) + itemsize = (Py_ssize_t)alignment; + } + if ((size_t)(basicsize + itemsize) < size) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize+itemsize); + goto bad; + } + if (check_size == __Pyx_ImportType_CheckSize_Error_3_0_8 && + ((size_t)basicsize > size || (size_t)(basicsize + itemsize) < size)) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd-%zd from PyObject", + module_name, class_name, size, basicsize, basicsize+itemsize); + goto bad; + } + else if (check_size == __Pyx_ImportType_CheckSize_Warn_3_0_8 && (size_t)basicsize > size) { + PyOS_snprintf(warning, sizeof(warning), + "%s.%s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize); + if (PyErr_WarnEx(NULL, warning, 0) < 0) goto bad; + } + return (PyTypeObject *)result; +bad: + Py_XDECREF(result); + return NULL; +} +#endif + +/* FetchSharedCythonModule */ + static PyObject *__Pyx_FetchSharedCythonABIModule(void) { + return __Pyx_PyImport_AddModuleRef((char*) __PYX_ABI_MODULE_NAME); +} + +/* FetchCommonType */ + static int __Pyx_VerifyCachedType(PyObject *cached_type, + const char *name, + Py_ssize_t basicsize, + Py_ssize_t expected_basicsize) { + if (!PyType_Check(cached_type)) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s is not a type object", name); + return -1; + } + if (basicsize != expected_basicsize) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s has the wrong size, try recompiling", + name); + return -1; + } + return 0; +} +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type) { + PyObject* abi_module; + const char* object_name; + PyTypeObject *cached_type = NULL; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + object_name = strrchr(type->tp_name, '.'); + object_name = object_name ? object_name+1 : type->tp_name; + cached_type = (PyTypeObject*) PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + if (__Pyx_VerifyCachedType( + (PyObject *)cached_type, + object_name, + cached_type->tp_basicsize, + type->tp_basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + if (PyType_Ready(type) < 0) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, (PyObject *)type) < 0) + goto bad; + Py_INCREF(type); + cached_type = type; +done: + Py_DECREF(abi_module); + return cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#else +static PyTypeObject *__Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases) { + PyObject *abi_module, *cached_type = NULL; + const char* object_name = strrchr(spec->name, '.'); + object_name = object_name ? object_name+1 : spec->name; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + cached_type = PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + Py_ssize_t basicsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + py_basicsize = PyObject_GetAttrString(cached_type, "__basicsize__"); + if (unlikely(!py_basicsize)) goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (unlikely(basicsize == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad; +#else + basicsize = likely(PyType_Check(cached_type)) ? ((PyTypeObject*) cached_type)->tp_basicsize : -1; +#endif + if (__Pyx_VerifyCachedType( + cached_type, + object_name, + basicsize, + spec->basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + CYTHON_UNUSED_VAR(module); + cached_type = __Pyx_PyType_FromModuleAndSpec(abi_module, spec, bases); + if (unlikely(!cached_type)) goto bad; + if (unlikely(__Pyx_fix_up_extension_type_from_spec(spec, (PyTypeObject *) cached_type) < 0)) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, cached_type) < 0) goto bad; +done: + Py_DECREF(abi_module); + assert(cached_type == NULL || PyType_Check(cached_type)); + return (PyTypeObject *) cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#endif + +/* PyVectorcallFastCallDict */ + #if CYTHON_METH_FASTCALL +static PyObject *__Pyx_PyVectorcall_FastCallDict_kw(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + PyObject *res = NULL; + PyObject *kwnames; + PyObject **newargs; + PyObject **kwvalues; + Py_ssize_t i, pos; + size_t j; + PyObject *key, *value; + unsigned long keys_are_strings; + Py_ssize_t nkw = PyDict_GET_SIZE(kw); + newargs = (PyObject **)PyMem_Malloc((nargs + (size_t)nkw) * sizeof(args[0])); + if (unlikely(newargs == NULL)) { + PyErr_NoMemory(); + return NULL; + } + for (j = 0; j < nargs; j++) newargs[j] = args[j]; + kwnames = PyTuple_New(nkw); + if (unlikely(kwnames == NULL)) { + PyMem_Free(newargs); + return NULL; + } + kwvalues = newargs + nargs; + pos = i = 0; + keys_are_strings = Py_TPFLAGS_UNICODE_SUBCLASS; + while (PyDict_Next(kw, &pos, &key, &value)) { + keys_are_strings &= Py_TYPE(key)->tp_flags; + Py_INCREF(key); + Py_INCREF(value); + PyTuple_SET_ITEM(kwnames, i, key); + kwvalues[i] = value; + i++; + } + if (unlikely(!keys_are_strings)) { + PyErr_SetString(PyExc_TypeError, "keywords must be strings"); + goto cleanup; + } + res = vc(func, newargs, nargs, kwnames); +cleanup: + Py_DECREF(kwnames); + for (i = 0; i < nkw; i++) + Py_DECREF(kwvalues[i]); + PyMem_Free(newargs); + return res; +} +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + if (likely(kw == NULL) || PyDict_GET_SIZE(kw) == 0) { + return vc(func, args, nargs, NULL); + } + return __Pyx_PyVectorcall_FastCallDict_kw(func, vc, args, nargs, kw); +} +#endif + +/* CythonFunctionShared */ + #if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + if (__Pyx_CyFunction_Check(func)) { + return PyCFunction_GetFunction(((__pyx_CyFunctionObject*)func)->func) == (PyCFunction) cfunc; + } else if (PyCFunction_Check(func)) { + return PyCFunction_GetFunction(func) == (PyCFunction) cfunc; + } + return 0; +} +#else +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + return __Pyx_CyOrPyCFunction_Check(func) && __Pyx_CyOrPyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +} +#endif +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj) { +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + __Pyx_Py_XDECREF_SET( + __Pyx_CyFunction_GetClassObj(f), + ((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#else + __Pyx_Py_XDECREF_SET( + ((PyCMethodObject *) (f))->mm_class, + (PyTypeObject*)((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#endif +} +static PyObject * +__Pyx_CyFunction_get_doc(__pyx_CyFunctionObject *op, void *closure) +{ + CYTHON_UNUSED_VAR(closure); + if (unlikely(op->func_doc == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_doc = PyObject_GetAttrString(op->func, "__doc__"); + if (unlikely(!op->func_doc)) return NULL; +#else + if (((PyCFunctionObject*)op)->m_ml->ml_doc) { +#if PY_MAJOR_VERSION >= 3 + op->func_doc = PyUnicode_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#else + op->func_doc = PyString_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#endif + if (unlikely(op->func_doc == NULL)) + return NULL; + } else { + Py_INCREF(Py_None); + return Py_None; + } +#endif + } + Py_INCREF(op->func_doc); + return op->func_doc; +} +static int +__Pyx_CyFunction_set_doc(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (value == NULL) { + value = Py_None; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_doc, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_name(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_name == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_name = PyObject_GetAttrString(op->func, "__name__"); +#elif PY_MAJOR_VERSION >= 3 + op->func_name = PyUnicode_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#else + op->func_name = PyString_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#endif + if (unlikely(op->func_name == NULL)) + return NULL; + } + Py_INCREF(op->func_name); + return op->func_name; +} +static int +__Pyx_CyFunction_set_name(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__name__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_name, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_qualname(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_qualname); + return op->func_qualname; +} +static int +__Pyx_CyFunction_set_qualname(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__qualname__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_qualname, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_dict(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_dict == NULL)) { + op->func_dict = PyDict_New(); + if (unlikely(op->func_dict == NULL)) + return NULL; + } + Py_INCREF(op->func_dict); + return op->func_dict; +} +static int +__Pyx_CyFunction_set_dict(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(value == NULL)) { + PyErr_SetString(PyExc_TypeError, + "function's dictionary may not be deleted"); + return -1; + } + if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "setting function's dictionary to a non-dict"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_dict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_globals(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_globals); + return op->func_globals; +} +static PyObject * +__Pyx_CyFunction_get_closure(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(op); + CYTHON_UNUSED_VAR(context); + Py_INCREF(Py_None); + return Py_None; +} +static PyObject * +__Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op, void *context) +{ + PyObject* result = (op->func_code) ? op->func_code : Py_None; + CYTHON_UNUSED_VAR(context); + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_init_defaults(__pyx_CyFunctionObject *op) { + int result = 0; + PyObject *res = op->defaults_getter((PyObject *) op); + if (unlikely(!res)) + return -1; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + op->defaults_tuple = PyTuple_GET_ITEM(res, 0); + Py_INCREF(op->defaults_tuple); + op->defaults_kwdict = PyTuple_GET_ITEM(res, 1); + Py_INCREF(op->defaults_kwdict); + #else + op->defaults_tuple = __Pyx_PySequence_ITEM(res, 0); + if (unlikely(!op->defaults_tuple)) result = -1; + else { + op->defaults_kwdict = __Pyx_PySequence_ITEM(res, 1); + if (unlikely(!op->defaults_kwdict)) result = -1; + } + #endif + Py_DECREF(res); + return result; +} +static int +__Pyx_CyFunction_set_defaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyTuple_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__defaults__ must be set to a tuple object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__defaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_tuple, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_tuple; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_tuple; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_kwdefaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__kwdefaults__ must be set to a dict object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__kwdefaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_kwdict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_kwdefaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_kwdict; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_kwdict; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_annotations(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value || value == Py_None) { + value = NULL; + } else if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__annotations__ must be set to a dict object"); + return -1; + } + Py_XINCREF(value); + __Pyx_Py_XDECREF_SET(op->func_annotations, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_annotations(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->func_annotations; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + result = PyDict_New(); + if (unlikely(!result)) return NULL; + op->func_annotations = result; + } + Py_INCREF(result); + return result; +} +static PyObject * +__Pyx_CyFunction_get_is_coroutine(__pyx_CyFunctionObject *op, void *context) { + int is_coroutine; + CYTHON_UNUSED_VAR(context); + if (op->func_is_coroutine) { + return __Pyx_NewRef(op->func_is_coroutine); + } + is_coroutine = op->flags & __Pyx_CYFUNCTION_COROUTINE; +#if PY_VERSION_HEX >= 0x03050000 + if (is_coroutine) { + PyObject *module, *fromlist, *marker = __pyx_n_s_is_coroutine; + fromlist = PyList_New(1); + if (unlikely(!fromlist)) return NULL; + Py_INCREF(marker); +#if CYTHON_ASSUME_SAFE_MACROS + PyList_SET_ITEM(fromlist, 0, marker); +#else + if (unlikely(PyList_SetItem(fromlist, 0, marker) < 0)) { + Py_DECREF(marker); + Py_DECREF(fromlist); + return NULL; + } +#endif + module = PyImport_ImportModuleLevelObject(__pyx_n_s_asyncio_coroutines, NULL, NULL, fromlist, 0); + Py_DECREF(fromlist); + if (unlikely(!module)) goto ignore; + op->func_is_coroutine = __Pyx_PyObject_GetAttrStr(module, marker); + Py_DECREF(module); + if (likely(op->func_is_coroutine)) { + return __Pyx_NewRef(op->func_is_coroutine); + } +ignore: + PyErr_Clear(); + } +#endif + op->func_is_coroutine = __Pyx_PyBool_FromLong(is_coroutine); + return __Pyx_NewRef(op->func_is_coroutine); +} +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject * +__Pyx_CyFunction_get_module(__pyx_CyFunctionObject *op, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_GetAttrString(op->func, "__module__"); +} +static int +__Pyx_CyFunction_set_module(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_SetAttrString(op->func, "__module__", value); +} +#endif +static PyGetSetDef __pyx_CyFunction_getsets[] = { + {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "func_name", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__name__", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__qualname__", (getter)__Pyx_CyFunction_get_qualname, (setter)__Pyx_CyFunction_set_qualname, 0, 0}, + {(char *) "func_dict", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "__dict__", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "func_globals", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "__globals__", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "func_closure", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__kwdefaults__", (getter)__Pyx_CyFunction_get_kwdefaults, (setter)__Pyx_CyFunction_set_kwdefaults, 0, 0}, + {(char *) "__annotations__", (getter)__Pyx_CyFunction_get_annotations, (setter)__Pyx_CyFunction_set_annotations, 0, 0}, + {(char *) "_is_coroutine", (getter)__Pyx_CyFunction_get_is_coroutine, 0, 0, 0}, +#if CYTHON_COMPILING_IN_LIMITED_API + {"__module__", (getter)__Pyx_CyFunction_get_module, (setter)__Pyx_CyFunction_set_module, 0, 0}, +#endif + {0, 0, 0, 0, 0} +}; +static PyMemberDef __pyx_CyFunction_members[] = { +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__module__", T_OBJECT, offsetof(PyCFunctionObject, m_module), 0, 0}, +#endif +#if CYTHON_USE_TYPE_SPECS + {(char *) "__dictoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_dict), READONLY, 0}, +#if CYTHON_METH_FASTCALL +#if CYTHON_BACKPORT_VECTORCALL + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_vectorcall), READONLY, 0}, +#else +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(PyCFunctionObject, vectorcall), READONLY, 0}, +#endif +#endif +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_weakreflist), READONLY, 0}, +#else + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(PyCFunctionObject, m_weakreflist), READONLY, 0}, +#endif +#endif + {0, 0, 0, 0, 0} +}; +static PyObject * +__Pyx_CyFunction_reduce(__pyx_CyFunctionObject *m, PyObject *args) +{ + CYTHON_UNUSED_VAR(args); +#if PY_MAJOR_VERSION >= 3 + Py_INCREF(m->func_qualname); + return m->func_qualname; +#else + return PyString_FromString(((PyCFunctionObject*)m)->m_ml->ml_name); +#endif +} +static PyMethodDef __pyx_CyFunction_methods[] = { + {"__reduce__", (PyCFunction)__Pyx_CyFunction_reduce, METH_VARARGS, 0}, + {0, 0, 0, 0} +}; +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func_weakreflist) +#else +#define __Pyx_CyFunction_weakreflist(cyfunc) (((PyCFunctionObject*)cyfunc)->m_weakreflist) +#endif +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject *op, PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { +#if !CYTHON_COMPILING_IN_LIMITED_API + PyCFunctionObject *cf = (PyCFunctionObject*) op; +#endif + if (unlikely(op == NULL)) + return NULL; +#if CYTHON_COMPILING_IN_LIMITED_API + op->func = PyCFunction_NewEx(ml, (PyObject*)op, module); + if (unlikely(!op->func)) return NULL; +#endif + op->flags = flags; + __Pyx_CyFunction_weakreflist(op) = NULL; +#if !CYTHON_COMPILING_IN_LIMITED_API + cf->m_ml = ml; + cf->m_self = (PyObject *) op; +#endif + Py_XINCREF(closure); + op->func_closure = closure; +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_XINCREF(module); + cf->m_module = module; +#endif + op->func_dict = NULL; + op->func_name = NULL; + Py_INCREF(qualname); + op->func_qualname = qualname; + op->func_doc = NULL; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + op->func_classobj = NULL; +#else + ((PyCMethodObject*)op)->mm_class = NULL; +#endif + op->func_globals = globals; + Py_INCREF(op->func_globals); + Py_XINCREF(code); + op->func_code = code; + op->defaults_pyobjects = 0; + op->defaults_size = 0; + op->defaults = NULL; + op->defaults_tuple = NULL; + op->defaults_kwdict = NULL; + op->defaults_getter = NULL; + op->func_annotations = NULL; + op->func_is_coroutine = NULL; +#if CYTHON_METH_FASTCALL + switch (ml->ml_flags & (METH_VARARGS | METH_FASTCALL | METH_NOARGS | METH_O | METH_KEYWORDS | METH_METHOD)) { + case METH_NOARGS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_NOARGS; + break; + case METH_O: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_O; + break; + case METH_METHOD | METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD; + break; + case METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS; + break; + case METH_VARARGS | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = NULL; + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + Py_DECREF(op); + return NULL; + } +#endif + return (PyObject *) op; +} +static int +__Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) +{ + Py_CLEAR(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_CLEAR(m->func); +#else + Py_CLEAR(((PyCFunctionObject*)m)->m_module); +#endif + Py_CLEAR(m->func_dict); + Py_CLEAR(m->func_name); + Py_CLEAR(m->func_qualname); + Py_CLEAR(m->func_doc); + Py_CLEAR(m->func_globals); + Py_CLEAR(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API +#if PY_VERSION_HEX < 0x030900B1 + Py_CLEAR(__Pyx_CyFunction_GetClassObj(m)); +#else + { + PyObject *cls = (PyObject*) ((PyCMethodObject *) (m))->mm_class; + ((PyCMethodObject *) (m))->mm_class = NULL; + Py_XDECREF(cls); + } +#endif +#endif + Py_CLEAR(m->defaults_tuple); + Py_CLEAR(m->defaults_kwdict); + Py_CLEAR(m->func_annotations); + Py_CLEAR(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_XDECREF(pydefaults[i]); + PyObject_Free(m->defaults); + m->defaults = NULL; + } + return 0; +} +static void __Pyx__CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + if (__Pyx_CyFunction_weakreflist(m) != NULL) + PyObject_ClearWeakRefs((PyObject *) m); + __Pyx_CyFunction_clear(m); + __Pyx_PyHeapTypeObject_GC_Del(m); +} +static void __Pyx_CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + PyObject_GC_UnTrack(m); + __Pyx__CyFunction_dealloc(m); +} +static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, void *arg) +{ + Py_VISIT(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(m->func); +#else + Py_VISIT(((PyCFunctionObject*)m)->m_module); +#endif + Py_VISIT(m->func_dict); + Py_VISIT(m->func_name); + Py_VISIT(m->func_qualname); + Py_VISIT(m->func_doc); + Py_VISIT(m->func_globals); + Py_VISIT(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(__Pyx_CyFunction_GetClassObj(m)); +#endif + Py_VISIT(m->defaults_tuple); + Py_VISIT(m->defaults_kwdict); + Py_VISIT(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_VISIT(pydefaults[i]); + } + return 0; +} +static PyObject* +__Pyx_CyFunction_repr(__pyx_CyFunctionObject *op) +{ +#if PY_MAJOR_VERSION >= 3 + return PyUnicode_FromFormat("", + op->func_qualname, (void *)op); +#else + return PyString_FromFormat("", + PyString_AsString(op->func_qualname), (void *)op); +#endif +} +static PyObject * __Pyx_CyFunction_CallMethod(PyObject *func, PyObject *self, PyObject *arg, PyObject *kw) { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *f = ((__pyx_CyFunctionObject*)func)->func; + PyObject *py_name = NULL; + PyCFunction meth; + int flags; + meth = PyCFunction_GetFunction(f); + if (unlikely(!meth)) return NULL; + flags = PyCFunction_GetFlags(f); + if (unlikely(flags < 0)) return NULL; +#else + PyCFunctionObject* f = (PyCFunctionObject*)func; + PyCFunction meth = f->m_ml->ml_meth; + int flags = f->m_ml->ml_flags; +#endif + Py_ssize_t size; + switch (flags & (METH_VARARGS | METH_KEYWORDS | METH_NOARGS | METH_O)) { + case METH_VARARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) + return (*meth)(self, arg); + break; + case METH_VARARGS | METH_KEYWORDS: + return (*(PyCFunctionWithKeywords)(void*)meth)(self, arg, kw); + case METH_NOARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 0)) + return (*meth)(self, NULL); +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + case METH_O: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 1)) { + PyObject *result, *arg0; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + arg0 = PyTuple_GET_ITEM(arg, 0); + #else + arg0 = __Pyx_PySequence_ITEM(arg, 0); if (unlikely(!arg0)) return NULL; + #endif + result = (*meth)(self, arg0); + #if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(arg0); + #endif + return result; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + return NULL; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, "%.200S() takes no keyword arguments", + py_name); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, "%.200s() takes no keyword arguments", + f->m_ml->ml_name); +#endif + return NULL; +} +static CYTHON_INLINE PyObject *__Pyx_CyFunction_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *self, *result; +#if CYTHON_COMPILING_IN_LIMITED_API + self = PyCFunction_GetSelf(((__pyx_CyFunctionObject*)func)->func); + if (unlikely(!self) && PyErr_Occurred()) return NULL; +#else + self = ((PyCFunctionObject*)func)->m_self; +#endif + result = __Pyx_CyFunction_CallMethod(func, self, arg, kw); + return result; +} +static PyObject *__Pyx_CyFunction_CallAsMethod(PyObject *func, PyObject *args, PyObject *kw) { + PyObject *result; + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *) func; +#if CYTHON_METH_FASTCALL + __pyx_vectorcallfunc vc = __Pyx_CyFunction_func_vectorcall(cyfunc); + if (vc) { +#if CYTHON_ASSUME_SAFE_MACROS + return __Pyx_PyVectorcall_FastCallDict(func, vc, &PyTuple_GET_ITEM(args, 0), (size_t)PyTuple_GET_SIZE(args), kw); +#else + (void) &__Pyx_PyVectorcall_FastCallDict; + return PyVectorcall_Call(func, args, kw); +#endif + } +#endif + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + Py_ssize_t argc; + PyObject *new_args; + PyObject *self; +#if CYTHON_ASSUME_SAFE_MACROS + argc = PyTuple_GET_SIZE(args); +#else + argc = PyTuple_Size(args); + if (unlikely(!argc) < 0) return NULL; +#endif + new_args = PyTuple_GetSlice(args, 1, argc); + if (unlikely(!new_args)) + return NULL; + self = PyTuple_GetItem(args, 0); + if (unlikely(!self)) { + Py_DECREF(new_args); +#if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_TypeError, + "unbound method %.200S() needs an argument", + cyfunc->func_qualname); +#else + PyErr_SetString(PyExc_TypeError, + "unbound method needs an argument"); +#endif + return NULL; + } + result = __Pyx_CyFunction_CallMethod(func, self, new_args, kw); + Py_DECREF(new_args); + } else { + result = __Pyx_CyFunction_Call(func, args, kw); + } + return result; +} +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE int __Pyx_CyFunction_Vectorcall_CheckArgs(__pyx_CyFunctionObject *cyfunc, Py_ssize_t nargs, PyObject *kwnames) +{ + int ret = 0; + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + if (unlikely(nargs < 1)) { + PyErr_Format(PyExc_TypeError, "%.200s() needs an argument", + ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + ret = 1; + } + if (unlikely(kwnames) && unlikely(PyTuple_GET_SIZE(kwnames))) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no keyword arguments", ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + return ret; +} +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 0)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, NULL); +} +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 1)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, args[0]); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((_PyCFunctionFastWithKeywords)(void(*)(void))def->ml_meth)(self, args, nargs, kwnames); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; + PyTypeObject *cls = (PyTypeObject *) __Pyx_CyFunction_GetClassObj(cyfunc); +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((__Pyx_PyCMethod)(void(*)(void))def->ml_meth)(self, cls, args, (size_t)nargs, kwnames); +} +#endif +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_CyFunctionType_slots[] = { + {Py_tp_dealloc, (void *)__Pyx_CyFunction_dealloc}, + {Py_tp_repr, (void *)__Pyx_CyFunction_repr}, + {Py_tp_call, (void *)__Pyx_CyFunction_CallAsMethod}, + {Py_tp_traverse, (void *)__Pyx_CyFunction_traverse}, + {Py_tp_clear, (void *)__Pyx_CyFunction_clear}, + {Py_tp_methods, (void *)__pyx_CyFunction_methods}, + {Py_tp_members, (void *)__pyx_CyFunction_members}, + {Py_tp_getset, (void *)__pyx_CyFunction_getsets}, + {Py_tp_descr_get, (void *)__Pyx_PyMethod_New}, + {0, 0}, +}; +static PyType_Spec __pyx_CyFunctionType_spec = { + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if (defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL) + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + __pyx_CyFunctionType_slots +}; +#else +static PyTypeObject __pyx_CyFunctionType_type = { + PyVarObject_HEAD_INIT(0, 0) + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, + (destructor) __Pyx_CyFunction_dealloc, +#if !CYTHON_METH_FASTCALL + 0, +#elif CYTHON_BACKPORT_VECTORCALL + (printfunc)offsetof(__pyx_CyFunctionObject, func_vectorcall), +#else + offsetof(PyCFunctionObject, vectorcall), +#endif + 0, + 0, +#if PY_MAJOR_VERSION < 3 + 0, +#else + 0, +#endif + (reprfunc) __Pyx_CyFunction_repr, + 0, + 0, + 0, + 0, + __Pyx_CyFunction_CallAsMethod, + 0, + 0, + 0, + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + 0, + (traverseproc) __Pyx_CyFunction_traverse, + (inquiry) __Pyx_CyFunction_clear, + 0, +#if PY_VERSION_HEX < 0x030500A0 + offsetof(__pyx_CyFunctionObject, func_weakreflist), +#else + offsetof(PyCFunctionObject, m_weakreflist), +#endif + 0, + 0, + __pyx_CyFunction_methods, + __pyx_CyFunction_members, + __pyx_CyFunction_getsets, + 0, + 0, + __Pyx_PyMethod_New, + 0, + offsetof(__pyx_CyFunctionObject, func_dict), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, +#if PY_VERSION_HEX >= 0x030400a1 + 0, +#endif +#if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, +#endif +#if __PYX_NEED_TP_PRINT_SLOT + 0, +#endif +#if PY_VERSION_HEX >= 0x030C0000 + 0, +#endif +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, +#endif +}; +#endif +static int __pyx_CyFunction_init(PyObject *module) { +#if CYTHON_USE_TYPE_SPECS + __pyx_CyFunctionType = __Pyx_FetchCommonTypeFromSpec(module, &__pyx_CyFunctionType_spec, NULL); +#else + CYTHON_UNUSED_VAR(module); + __pyx_CyFunctionType = __Pyx_FetchCommonType(&__pyx_CyFunctionType_type); +#endif + if (unlikely(__pyx_CyFunctionType == NULL)) { + return -1; + } + return 0; +} +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults = PyObject_Malloc(size); + if (unlikely(!m->defaults)) + return PyErr_NoMemory(); + memset(m->defaults, 0, size); + m->defaults_pyobjects = pyobjects; + m->defaults_size = size; + return m->defaults; +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_tuple = tuple; + Py_INCREF(tuple); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_kwdict = dict; + Py_INCREF(dict); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->func_annotations = dict; + Py_INCREF(dict); +} + +/* CythonFunction */ + static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { + PyObject *op = __Pyx_CyFunction_Init( + PyObject_GC_New(__pyx_CyFunctionObject, __pyx_CyFunctionType), + ml, flags, qualname, closure, module, globals, code + ); + if (likely(op)) { + PyObject_GC_Track(op); + } + return op; +} + +/* CLineInTraceback */ + #ifndef CYTHON_CLINE_IN_TRACEBACK +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line) { + PyObject *use_cline; + PyObject *ptype, *pvalue, *ptraceback; +#if CYTHON_COMPILING_IN_CPYTHON + PyObject **cython_runtime_dict; +#endif + CYTHON_MAYBE_UNUSED_VAR(tstate); + if (unlikely(!__pyx_cython_runtime)) { + return c_line; + } + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); +#if CYTHON_COMPILING_IN_CPYTHON + cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime); + if (likely(cython_runtime_dict)) { + __PYX_PY_DICT_LOOKUP_IF_MODIFIED( + use_cline, *cython_runtime_dict, + __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback)) + } else +#endif + { + PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStrNoError(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback); + if (use_cline_obj) { + use_cline = PyObject_Not(use_cline_obj) ? Py_False : Py_True; + Py_DECREF(use_cline_obj); + } else { + PyErr_Clear(); + use_cline = NULL; + } + } + if (!use_cline) { + c_line = 0; + (void) PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False); + } + else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) { + c_line = 0; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + return c_line; +} +#endif + +/* CodeObjectCache */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) { + int start = 0, mid = 0, end = count - 1; + if (end >= 0 && code_line > entries[end].code_line) { + return count; + } + while (start < end) { + mid = start + (end - start) / 2; + if (code_line < entries[mid].code_line) { + end = mid; + } else if (code_line > entries[mid].code_line) { + start = mid + 1; + } else { + return mid; + } + } + if (code_line <= entries[mid].code_line) { + return mid; + } else { + return mid + 1; + } +} +static PyCodeObject *__pyx_find_code_object(int code_line) { + PyCodeObject* code_object; + int pos; + if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) { + return NULL; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) { + return NULL; + } + code_object = __pyx_code_cache.entries[pos].code_object; + Py_INCREF(code_object); + return code_object; +} +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) { + int pos, i; + __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries; + if (unlikely(!code_line)) { + return; + } + if (unlikely(!entries)) { + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry)); + if (likely(entries)) { + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = 64; + __pyx_code_cache.count = 1; + entries[0].code_line = code_line; + entries[0].code_object = code_object; + Py_INCREF(code_object); + } + return; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) { + PyCodeObject* tmp = entries[pos].code_object; + entries[pos].code_object = code_object; + Py_DECREF(tmp); + return; + } + if (__pyx_code_cache.count == __pyx_code_cache.max_count) { + int new_max = __pyx_code_cache.max_count + 64; + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc( + __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry)); + if (unlikely(!entries)) { + return; + } + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = new_max; + } + for (i=__pyx_code_cache.count; i>pos; i--) { + entries[i] = entries[i-1]; + } + entries[pos].code_line = code_line; + entries[pos].code_object = code_object; + __pyx_code_cache.count++; + Py_INCREF(code_object); +} +#endif + +/* AddTraceback */ + #include "compile.h" +#include "frameobject.h" +#include "traceback.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyCode_Replace_For_AddTraceback(PyObject *code, PyObject *scratch_dict, + PyObject *firstlineno, PyObject *name) { + PyObject *replace = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_firstlineno", firstlineno))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_name", name))) return NULL; + replace = PyObject_GetAttrString(code, "replace"); + if (likely(replace)) { + PyObject *result; + result = PyObject_Call(replace, __pyx_empty_tuple, scratch_dict); + Py_DECREF(replace); + return result; + } + PyErr_Clear(); + #if __PYX_LIMITED_VERSION_HEX < 0x030780000 + { + PyObject *compiled = NULL, *result = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "code", code))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "type", (PyObject*)(&PyType_Type)))) return NULL; + compiled = Py_CompileString( + "out = type(code)(\n" + " code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize,\n" + " code.co_flags, code.co_code, code.co_consts, code.co_names,\n" + " code.co_varnames, code.co_filename, co_name, co_firstlineno,\n" + " code.co_lnotab)\n", "", Py_file_input); + if (!compiled) return NULL; + result = PyEval_EvalCode(compiled, scratch_dict, scratch_dict); + Py_DECREF(compiled); + if (!result) PyErr_Print(); + Py_DECREF(result); + result = PyDict_GetItemString(scratch_dict, "out"); + if (result) Py_INCREF(result); + return result; + } + #else + return NULL; + #endif +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyObject *code_object = NULL, *py_py_line = NULL, *py_funcname = NULL, *dict = NULL; + PyObject *replace = NULL, *getframe = NULL, *frame = NULL; + PyObject *exc_type, *exc_value, *exc_traceback; + int success = 0; + if (c_line) { + (void) __pyx_cfilenm; + (void) __Pyx_CLineForTraceback(__Pyx_PyThreadState_Current, c_line); + } + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + code_object = Py_CompileString("_getframe()", filename, Py_eval_input); + if (unlikely(!code_object)) goto bad; + py_py_line = PyLong_FromLong(py_line); + if (unlikely(!py_py_line)) goto bad; + py_funcname = PyUnicode_FromString(funcname); + if (unlikely(!py_funcname)) goto bad; + dict = PyDict_New(); + if (unlikely(!dict)) goto bad; + { + PyObject *old_code_object = code_object; + code_object = __Pyx_PyCode_Replace_For_AddTraceback(code_object, dict, py_py_line, py_funcname); + Py_DECREF(old_code_object); + } + if (unlikely(!code_object)) goto bad; + getframe = PySys_GetObject("_getframe"); + if (unlikely(!getframe)) goto bad; + if (unlikely(PyDict_SetItemString(dict, "_getframe", getframe))) goto bad; + frame = PyEval_EvalCode(code_object, dict, dict); + if (unlikely(!frame) || frame == Py_None) goto bad; + success = 1; + bad: + PyErr_Restore(exc_type, exc_value, exc_traceback); + Py_XDECREF(code_object); + Py_XDECREF(py_py_line); + Py_XDECREF(py_funcname); + Py_XDECREF(dict); + Py_XDECREF(replace); + if (success) { + PyTraceBack_Here( + (struct _frame*)frame); + } + Py_XDECREF(frame); +} +#else +static PyCodeObject* __Pyx_CreateCodeObjectForTraceback( + const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = NULL; + PyObject *py_funcname = NULL; + #if PY_MAJOR_VERSION < 3 + PyObject *py_srcfile = NULL; + py_srcfile = PyString_FromString(filename); + if (!py_srcfile) goto bad; + #endif + if (c_line) { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + #else + py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + funcname = PyUnicode_AsUTF8(py_funcname); + if (!funcname) goto bad; + #endif + } + else { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromString(funcname); + if (!py_funcname) goto bad; + #endif + } + #if PY_MAJOR_VERSION < 3 + py_code = __Pyx_PyCode_New( + 0, + 0, + 0, + 0, + 0, + 0, + __pyx_empty_bytes, /*PyObject *code,*/ + __pyx_empty_tuple, /*PyObject *consts,*/ + __pyx_empty_tuple, /*PyObject *names,*/ + __pyx_empty_tuple, /*PyObject *varnames,*/ + __pyx_empty_tuple, /*PyObject *freevars,*/ + __pyx_empty_tuple, /*PyObject *cellvars,*/ + py_srcfile, /*PyObject *filename,*/ + py_funcname, /*PyObject *name,*/ + py_line, + __pyx_empty_bytes /*PyObject *lnotab*/ + ); + Py_DECREF(py_srcfile); + #else + py_code = PyCode_NewEmpty(filename, funcname, py_line); + #endif + Py_XDECREF(py_funcname); + return py_code; +bad: + Py_XDECREF(py_funcname); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_srcfile); + #endif + return NULL; +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = 0; + PyFrameObject *py_frame = 0; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject *ptype, *pvalue, *ptraceback; + if (c_line) { + c_line = __Pyx_CLineForTraceback(tstate, c_line); + } + py_code = __pyx_find_code_object(c_line ? -c_line : py_line); + if (!py_code) { + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); + py_code = __Pyx_CreateCodeObjectForTraceback( + funcname, c_line, py_line, filename); + if (!py_code) { + /* If the code object creation fails, then we should clear the + fetched exception references and propagate the new exception */ + Py_XDECREF(ptype); + Py_XDECREF(pvalue); + Py_XDECREF(ptraceback); + goto bad; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + __pyx_insert_code_object(c_line ? -c_line : py_line, py_code); + } + py_frame = PyFrame_New( + tstate, /*PyThreadState *tstate,*/ + py_code, /*PyCodeObject *code,*/ + __pyx_d, /*PyObject *globals,*/ + 0 /*PyObject *locals*/ + ); + if (!py_frame) goto bad; + __Pyx_PyFrame_SetLineNumber(py_frame, py_line); + PyTraceBack_Here(py_frame); +bad: + Py_XDECREF(py_code); + Py_XDECREF(py_frame); +} +#endif + +#if PY_MAJOR_VERSION < 3 +static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) { + __Pyx_TypeName obj_type_name; + if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_array_type)) return __pyx_array_getbuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_memoryview_type)) return __pyx_memoryview_getbuffer(obj, view, flags); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' does not have the buffer interface", + obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return -1; +} +static void __Pyx_ReleaseBuffer(Py_buffer *view) { + PyObject *obj = view->obj; + if (!obj) return; + if (PyObject_CheckBuffer(obj)) { + PyBuffer_Release(view); + return; + } + if ((0)) {} + view->obj = NULL; + Py_DECREF(obj); +} +#endif + + + /* MemviewSliceIsContig */ + static int +__pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim) +{ + int i, index, step, start; + Py_ssize_t itemsize = mvs.memview->view.itemsize; + if (order == 'F') { + step = 1; + start = 0; + } else { + step = -1; + start = ndim - 1; + } + for (i = 0; i < ndim; i++) { + index = start + step * i; + if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize) + return 0; + itemsize *= mvs.shape[index]; + } + return 1; +} + +/* OverlappingSlices */ + static void +__pyx_get_array_memory_extents(__Pyx_memviewslice *slice, + void **out_start, void **out_end, + int ndim, size_t itemsize) +{ + char *start, *end; + int i; + start = end = slice->data; + for (i = 0; i < ndim; i++) { + Py_ssize_t stride = slice->strides[i]; + Py_ssize_t extent = slice->shape[i]; + if (extent == 0) { + *out_start = *out_end = start; + return; + } else { + if (stride > 0) + end += stride * (extent - 1); + else + start += stride * (extent - 1); + } + } + *out_start = start; + *out_end = end + itemsize; +} +static int +__pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize) +{ + void *start1, *end1, *start2, *end2; + __pyx_get_array_memory_extents(slice1, &start1, &end1, ndim, itemsize); + __pyx_get_array_memory_extents(slice2, &start2, &end2, ndim, itemsize); + return (start1 < end2) && (start2 < end1); +} + +/* CIntFromPyVerify */ + #define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0) +#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1) +#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\ + {\ + func_type value = func_value;\ + if (sizeof(target_type) < sizeof(func_type)) {\ + if (unlikely(value != (func_type) (target_type) value)) {\ + func_type zero = 0;\ + if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\ + return (target_type) -1;\ + if (is_unsigned && unlikely(value < zero))\ + goto raise_neg_overflow;\ + else\ + goto raise_overflow;\ + }\ + }\ + return (target_type) value;\ + } + +/* TypeInfoCompare */ + static int +__pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b) +{ + int i; + if (!a || !b) + return 0; + if (a == b) + return 1; + if (a->size != b->size || a->typegroup != b->typegroup || + a->is_unsigned != b->is_unsigned || a->ndim != b->ndim) { + if (a->typegroup == 'H' || b->typegroup == 'H') { + return a->size == b->size; + } else { + return 0; + } + } + if (a->ndim) { + for (i = 0; i < a->ndim; i++) + if (a->arraysize[i] != b->arraysize[i]) + return 0; + } + if (a->typegroup == 'S') { + if (a->flags != b->flags) + return 0; + if (a->fields || b->fields) { + if (!(a->fields && b->fields)) + return 0; + for (i = 0; a->fields[i].type && b->fields[i].type; i++) { + __Pyx_StructField *field_a = a->fields + i; + __Pyx_StructField *field_b = b->fields + i; + if (field_a->offset != field_b->offset || + !__pyx_typeinfo_cmp(field_a->type, field_b->type)) + return 0; + } + return !a->fields[i].type && !b->fields[i].type; + } + } + return 1; +} + +/* MemviewSliceValidateAndInit */ + static int +__pyx_check_strides(Py_buffer *buf, int dim, int ndim, int spec) +{ + if (buf->shape[dim] <= 1) + return 1; + if (buf->strides) { + if (spec & __Pyx_MEMVIEW_CONTIG) { + if (spec & (__Pyx_MEMVIEW_PTR|__Pyx_MEMVIEW_FULL)) { + if (unlikely(buf->strides[dim] != sizeof(void *))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly contiguous " + "in dimension %d.", dim); + goto fail; + } + } else if (unlikely(buf->strides[dim] != buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_FOLLOW) { + Py_ssize_t stride = buf->strides[dim]; + if (stride < 0) + stride = -stride; + if (unlikely(stride < buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + } else { + if (unlikely(spec & __Pyx_MEMVIEW_CONTIG && dim != ndim - 1)) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not contiguous in " + "dimension %d", dim); + goto fail; + } else if (unlikely(spec & (__Pyx_MEMVIEW_PTR))) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not indirect in " + "dimension %d", dim); + goto fail; + } else if (unlikely(buf->suboffsets)) { + PyErr_SetString(PyExc_ValueError, + "Buffer exposes suboffsets but no strides"); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_check_suboffsets(Py_buffer *buf, int dim, int ndim, int spec) +{ + CYTHON_UNUSED_VAR(ndim); + if (spec & __Pyx_MEMVIEW_DIRECT) { + if (unlikely(buf->suboffsets && buf->suboffsets[dim] >= 0)) { + PyErr_Format(PyExc_ValueError, + "Buffer not compatible with direct access " + "in dimension %d.", dim); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_PTR) { + if (unlikely(!buf->suboffsets || (buf->suboffsets[dim] < 0))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly accessible " + "in dimension %d.", dim); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_verify_contig(Py_buffer *buf, int ndim, int c_or_f_flag) +{ + int i; + if (c_or_f_flag & __Pyx_IS_F_CONTIG) { + Py_ssize_t stride = 1; + for (i = 0; i < ndim; i++) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not fortran contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } else if (c_or_f_flag & __Pyx_IS_C_CONTIG) { + Py_ssize_t stride = 1; + for (i = ndim - 1; i >- 1; i--) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not C contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } + return 1; +fail: + return 0; +} +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj) +{ + struct __pyx_memoryview_obj *memview, *new_memview; + __Pyx_RefNannyDeclarations + Py_buffer *buf; + int i, spec = 0, retval = -1; + __Pyx_BufFmt_Context ctx; + int from_memoryview = __pyx_memoryview_check(original_obj); + __Pyx_RefNannySetupContext("ValidateAndInit_memviewslice", 0); + if (from_memoryview && __pyx_typeinfo_cmp(dtype, ((struct __pyx_memoryview_obj *) + original_obj)->typeinfo)) { + memview = (struct __pyx_memoryview_obj *) original_obj; + new_memview = NULL; + } else { + memview = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + original_obj, buf_flags, 0, dtype); + new_memview = memview; + if (unlikely(!memview)) + goto fail; + } + buf = &memview->view; + if (unlikely(buf->ndim != ndim)) { + PyErr_Format(PyExc_ValueError, + "Buffer has wrong number of dimensions (expected %d, got %d)", + ndim, buf->ndim); + goto fail; + } + if (new_memview) { + __Pyx_BufFmt_Init(&ctx, stack, dtype); + if (unlikely(!__Pyx_BufFmt_CheckString(&ctx, buf->format))) goto fail; + } + if (unlikely((unsigned) buf->itemsize != dtype->size)) { + PyErr_Format(PyExc_ValueError, + "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "u byte%s) " + "does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "u byte%s)", + buf->itemsize, + (buf->itemsize > 1) ? "s" : "", + dtype->name, + dtype->size, + (dtype->size > 1) ? "s" : ""); + goto fail; + } + if (buf->len > 0) { + for (i = 0; i < ndim; i++) { + spec = axes_specs[i]; + if (unlikely(!__pyx_check_strides(buf, i, ndim, spec))) + goto fail; + if (unlikely(!__pyx_check_suboffsets(buf, i, ndim, spec))) + goto fail; + } + if (unlikely(buf->strides && !__pyx_verify_contig(buf, ndim, c_or_f_flag))) + goto fail; + } + if (unlikely(__Pyx_init_memviewslice(memview, ndim, memviewslice, + new_memview != NULL) == -1)) { + goto fail; + } + retval = 0; + goto no_fail; +fail: + Py_XDECREF(new_memview); + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int32_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 1, + &__Pyx_TypeInfo_nn_int32_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn_int64_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 1, + &__Pyx_TypeInfo_nn_int64_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 1, + &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 2, + &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* MemviewDtypeToObject */ + static CYTHON_INLINE PyObject *__pyx_memview_get_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(const char *itemp) { + return (PyObject *) __Pyx_PyInt_From_int64_t(*(__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) itemp); +} +static CYTHON_INLINE int __pyx_memview_set_nn___pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t(const char *itemp, PyObject *obj) { + __pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t value = __Pyx_PyInt_As_int64_t(obj); + if (unlikely((value == ((int64_t)-1)) && PyErr_Occurred())) + return 0; + *(__pyx_t_7fairseq_4data_15data_utils_fast_DTYPE_t *) itemp = value; + return 1; +} + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return ::std::complex< float >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return x + y*(__pyx_t_float_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + __pyx_t_float_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabsf(b.real) >= fabsf(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + float r = b.imag / b.real; + float s = (float)(1.0) / (b.real + b.imag * r); + return __pyx_t_float_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + float r = b.real / b.imag; + float s = (float)(1.0) / (b.imag + b.real * r); + return __pyx_t_float_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + float denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_float_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrtf(z.real*z.real + z.imag*z.imag); + #else + return hypotf(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + float r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + float denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_float(a, a); + case 3: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, a); + case 4: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = powf(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2f(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_float(a); + theta = atan2f(a.imag, a.real); + } + lnr = logf(r); + z_r = expf(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cosf(z_theta); + z.imag = z_r * sinf(z_theta); + return z; + } + #endif +#endif + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return ::std::complex< double >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return x + y*(__pyx_t_double_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + __pyx_t_double_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabs(b.real) >= fabs(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + double r = b.imag / b.real; + double s = (double)(1.0) / (b.real + b.imag * r); + return __pyx_t_double_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + double r = b.real / b.imag; + double s = (double)(1.0) / (b.imag + b.real * r); + return __pyx_t_double_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + double denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_double_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrt(z.real*z.real + z.imag*z.imag); + #else + return hypot(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + double r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + double denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_double(a, a); + case 3: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, a); + case 4: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = pow(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_double(a); + theta = atan2(a.imag, a.real); + } + lnr = log(r); + z_r = exp(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cos(z_theta); + z.imag = z_r * sin(z_theta); + return z; + } + #endif +#endif + +/* MemviewSliceCopyTemplate */ + static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object) +{ + __Pyx_RefNannyDeclarations + int i; + __Pyx_memviewslice new_mvs = { 0, 0, { 0 }, { 0 }, { 0 } }; + struct __pyx_memoryview_obj *from_memview = from_mvs->memview; + Py_buffer *buf = &from_memview->view; + PyObject *shape_tuple = NULL; + PyObject *temp_int = NULL; + struct __pyx_array_obj *array_obj = NULL; + struct __pyx_memoryview_obj *memview_obj = NULL; + __Pyx_RefNannySetupContext("__pyx_memoryview_copy_new_contig", 0); + for (i = 0; i < ndim; i++) { + if (unlikely(from_mvs->suboffsets[i] >= 0)) { + PyErr_Format(PyExc_ValueError, "Cannot copy memoryview slice with " + "indirect dimensions (axis %d)", i); + goto fail; + } + } + shape_tuple = PyTuple_New(ndim); + if (unlikely(!shape_tuple)) { + goto fail; + } + __Pyx_GOTREF(shape_tuple); + for(i = 0; i < ndim; i++) { + temp_int = PyInt_FromSsize_t(from_mvs->shape[i]); + if(unlikely(!temp_int)) { + goto fail; + } else { + PyTuple_SET_ITEM(shape_tuple, i, temp_int); + temp_int = NULL; + } + } + array_obj = __pyx_array_new(shape_tuple, sizeof_dtype, buf->format, (char *) mode, NULL); + if (unlikely(!array_obj)) { + goto fail; + } + __Pyx_GOTREF(array_obj); + memview_obj = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + (PyObject *) array_obj, contig_flag, + dtype_is_object, + from_mvs->memview->typeinfo); + if (unlikely(!memview_obj)) + goto fail; + if (unlikely(__Pyx_init_memviewslice(memview_obj, ndim, &new_mvs, 1) < 0)) + goto fail; + if (unlikely(__pyx_memoryview_copy_contents(*from_mvs, new_mvs, ndim, ndim, + dtype_is_object) < 0)) + goto fail; + goto no_fail; +fail: + __Pyx_XDECREF(new_mvs.memview); + new_mvs.memview = NULL; + new_mvs.data = NULL; +no_fail: + __Pyx_XDECREF(shape_tuple); + __Pyx_XDECREF(temp_int); + __Pyx_XDECREF(array_obj); + __Pyx_RefNannyFinishContext(); + return new_mvs; +} + +/* MemviewSliceInit */ + static int +__Pyx_init_memviewslice(struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference) +{ + __Pyx_RefNannyDeclarations + int i, retval=-1; + Py_buffer *buf = &memview->view; + __Pyx_RefNannySetupContext("init_memviewslice", 0); + if (unlikely(memviewslice->memview || memviewslice->data)) { + PyErr_SetString(PyExc_ValueError, + "memviewslice is already initialized!"); + goto fail; + } + if (buf->strides) { + for (i = 0; i < ndim; i++) { + memviewslice->strides[i] = buf->strides[i]; + } + } else { + Py_ssize_t stride = buf->itemsize; + for (i = ndim - 1; i >= 0; i--) { + memviewslice->strides[i] = stride; + stride *= buf->shape[i]; + } + } + for (i = 0; i < ndim; i++) { + memviewslice->shape[i] = buf->shape[i]; + if (buf->suboffsets) { + memviewslice->suboffsets[i] = buf->suboffsets[i]; + } else { + memviewslice->suboffsets[i] = -1; + } + } + memviewslice->memview = memview; + memviewslice->data = (char *)buf->buf; + if (__pyx_add_acquisition_count(memview) == 0 && !memview_is_new_reference) { + Py_INCREF(memview); + } + retval = 0; + goto no_fail; +fail: + memviewslice->memview = 0; + memviewslice->data = 0; + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} +#ifndef Py_NO_RETURN +#define Py_NO_RETURN +#endif +static void __pyx_fatalerror(const char *fmt, ...) Py_NO_RETURN { + va_list vargs; + char msg[200]; +#if PY_VERSION_HEX >= 0x030A0000 || defined(HAVE_STDARG_PROTOTYPES) + va_start(vargs, fmt); +#else + va_start(vargs); +#endif + vsnprintf(msg, 200, fmt, vargs); + va_end(vargs); + Py_FatalError(msg); +} +static CYTHON_INLINE int +__pyx_add_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)++; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE int +__pyx_sub_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)--; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE void +__Pyx_INC_MEMVIEW(__Pyx_memviewslice *memslice, int have_gil, int lineno) +{ + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + return; + } + old_acquisition_count = __pyx_add_acquisition_count(memview); + if (unlikely(old_acquisition_count <= 0)) { + if (likely(old_acquisition_count == 0)) { + if (have_gil) { + Py_INCREF((PyObject *) memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_INCREF((PyObject *) memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count+1, lineno); + } + } +} +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *memslice, + int have_gil, int lineno) { + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + memslice->memview = NULL; + return; + } + old_acquisition_count = __pyx_sub_acquisition_count(memview); + memslice->data = NULL; + if (likely(old_acquisition_count > 1)) { + memslice->memview = NULL; + } else if (likely(old_acquisition_count == 1)) { + if (have_gil) { + Py_CLEAR(memslice->memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_CLEAR(memslice->memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count-1, lineno); + } +} + +/* CIntFromPy */ + static CYTHON_INLINE int64_t __Pyx_PyInt_As_int64_t(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int64_t neg_one = (int64_t) -1, const_zero = (int64_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int64_t) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int64_t, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int64_t) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int64_t, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int64_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 2 * PyLong_SHIFT)) { + return (int64_t) (((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int64_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 3 * PyLong_SHIFT)) { + return (int64_t) (((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int64_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 4 * PyLong_SHIFT)) { + return (int64_t) (((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int64_t) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int64_t) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int64_t) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int64_t, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int64_t) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int64_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + return (int64_t) ((((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int64_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + return (int64_t) ((((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 4 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int64_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 4 * PyLong_SHIFT)) { + return (int64_t) ((((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int64_t) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int64_t) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int64_t val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (int64_t) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (int64_t) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int64_t) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (int64_t) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (int64_t) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int64_t) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((int64_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int64_t) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int64_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((int64_t) 1) << (sizeof(int64_t) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (int64_t) -1; + } + } else { + int64_t val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int64_t) -1; + val = __Pyx_PyInt_As_int64_t(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int64_t"); + return (int64_t) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int64_t"); + return (int64_t) -1; +} + +/* CIntFromPy */ + static CYTHON_INLINE int32_t __Pyx_PyInt_As_int32_t(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int32_t neg_one = (int32_t) -1, const_zero = (int32_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int32_t) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int32_t, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int32_t) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int32_t, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int32_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) >= 2 * PyLong_SHIFT)) { + return (int32_t) (((((int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int32_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) >= 3 * PyLong_SHIFT)) { + return (int32_t) (((((((int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int32_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) >= 4 * PyLong_SHIFT)) { + return (int32_t) (((((((((int32_t)digits[3]) << PyLong_SHIFT) | (int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int32_t) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int32_t) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int32_t, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int32_t) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int32_t, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int32_t, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int32_t) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 2 * PyLong_SHIFT)) { + return (int32_t) (((int32_t)-1)*(((((int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int32_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 2 * PyLong_SHIFT)) { + return (int32_t) ((((((int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int32_t) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 3 * PyLong_SHIFT)) { + return (int32_t) (((int32_t)-1)*(((((((int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int32_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 3 * PyLong_SHIFT)) { + return (int32_t) ((((((((int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int32_t) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 4 * PyLong_SHIFT)) { + return (int32_t) (((int32_t)-1)*(((((((((int32_t)digits[3]) << PyLong_SHIFT) | (int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int32_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int32_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int32_t) - 1 > 4 * PyLong_SHIFT)) { + return (int32_t) ((((((((((int32_t)digits[3]) << PyLong_SHIFT) | (int32_t)digits[2]) << PyLong_SHIFT) | (int32_t)digits[1]) << PyLong_SHIFT) | (int32_t)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int32_t) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int32_t, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int32_t) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int32_t, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int32_t val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (int32_t) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (int32_t) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int32_t) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (int32_t) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (int32_t) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int32_t) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((int32_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int32_t) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int32_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((int32_t) 1) << (sizeof(int32_t) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (int32_t) -1; + } + } else { + int32_t val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int32_t) -1; + val = __Pyx_PyInt_As_int32_t(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int32_t"); + return (int32_t) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int32_t"); + return (int32_t) -1; +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int64_t(int64_t value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int64_t neg_one = (int64_t) -1, const_zero = (int64_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int64_t) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int64_t) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int64_t) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int64_t) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int64_t) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(int64_t), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int64_t)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int32_t(int32_t value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int32_t neg_one = (int32_t) -1, const_zero = (int32_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int32_t) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int32_t) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int32_t) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int32_t) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int32_t) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(int32_t), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int32_t)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) { + return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) { + return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) { + return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (int) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (int) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (int) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (int) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((int) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (int) -1; + } + } else { + int val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int) -1; + val = __Pyx_PyInt_As_int(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int"); + return (int) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int"); + return (int) -1; +} + +/* CIntFromPy */ + static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(long) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (long) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) { + return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) { + return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) { + return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(long) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(long) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + long val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (long) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (long) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (long) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (long) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((long) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((long) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (long) -1; + } + } else { + long val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (long) -1; + val = __Pyx_PyInt_As_long(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to long"); + return (long) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to long"); + return (long) -1; +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(int), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(long) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(long) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(long) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(long), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(long)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const char neg_one = (char) -1, const_zero = (char) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(char) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(char, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (char) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 2 * PyLong_SHIFT)) { + return (char) (((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 3 * PyLong_SHIFT)) { + return (char) (((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 4 * PyLong_SHIFT)) { + return (char) (((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (char) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(char) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(char) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) ((((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) ((((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) ((((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(char) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + char val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (char) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (char) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (char) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (char) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (char) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(char) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((char) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(char) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((char) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((char) 1) << (sizeof(char) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (char) -1; + } + } else { + char val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (char) -1; + val = __Pyx_PyInt_As_char(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to char"); + return (char) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to char"); + return (char) -1; +} + +/* FormatTypeName */ + #if CYTHON_COMPILING_IN_LIMITED_API +static __Pyx_TypeName +__Pyx_PyType_GetName(PyTypeObject* tp) +{ + PyObject *name = __Pyx_PyObject_GetAttrStr((PyObject *)tp, + __pyx_n_s_name_2); + if (unlikely(name == NULL) || unlikely(!PyUnicode_Check(name))) { + PyErr_Clear(); + Py_XDECREF(name); + name = __Pyx_NewRef(__pyx_n_s__28); + } + return name; +} +#endif + +/* CheckBinaryVersion */ + static unsigned long __Pyx_get_runtime_version(void) { +#if __PYX_LIMITED_VERSION_HEX >= 0x030B00A4 + return Py_Version & ~0xFFUL; +#else + const char* rt_version = Py_GetVersion(); + unsigned long version = 0; + unsigned long factor = 0x01000000UL; + unsigned int digit = 0; + int i = 0; + while (factor) { + while ('0' <= rt_version[i] && rt_version[i] <= '9') { + digit = digit * 10 + (unsigned int) (rt_version[i] - '0'); + ++i; + } + version += factor * digit; + if (rt_version[i] != '.') + break; + digit = 0; + factor >>= 8; + ++i; + } + return version; +#endif +} +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer) { + const unsigned long MAJOR_MINOR = 0xFFFF0000UL; + if ((rt_version & MAJOR_MINOR) == (ct_version & MAJOR_MINOR)) + return 0; + if (likely(allow_newer && (rt_version & MAJOR_MINOR) > (ct_version & MAJOR_MINOR))) + return 1; + { + char message[200]; + PyOS_snprintf(message, sizeof(message), + "compile time Python version %d.%d " + "of module '%.100s' " + "%s " + "runtime version %d.%d", + (int) (ct_version >> 24), (int) ((ct_version >> 16) & 0xFF), + __Pyx_MODULE_NAME, + (allow_newer) ? "was newer than" : "does not match", + (int) (rt_version >> 24), (int) ((rt_version >> 16) & 0xFF) + ); + return PyErr_WarnEx(NULL, message, 1); + } +} + +/* InitStrings */ + #if PY_MAJOR_VERSION >= 3 +static int __Pyx_InitString(__Pyx_StringTabEntry t, PyObject **str) { + if (t.is_unicode | t.is_str) { + if (t.intern) { + *str = PyUnicode_InternFromString(t.s); + } else if (t.encoding) { + *str = PyUnicode_Decode(t.s, t.n - 1, t.encoding, NULL); + } else { + *str = PyUnicode_FromStringAndSize(t.s, t.n - 1); + } + } else { + *str = PyBytes_FromStringAndSize(t.s, t.n - 1); + } + if (!*str) + return -1; + if (PyObject_Hash(*str) == -1) + return -1; + return 0; +} +#endif +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t) { + while (t->p) { + #if PY_MAJOR_VERSION >= 3 + __Pyx_InitString(*t, t->p); + #else + if (t->is_unicode) { + *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL); + } else if (t->intern) { + *t->p = PyString_InternFromString(t->s); + } else { + *t->p = PyString_FromStringAndSize(t->s, t->n - 1); + } + if (!*t->p) + return -1; + if (PyObject_Hash(*t->p) == -1) + return -1; + #endif + ++t; + } + return 0; +} + +#include +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s) { + size_t len = strlen(s); + if (unlikely(len > (size_t) PY_SSIZE_T_MAX)) { + PyErr_SetString(PyExc_OverflowError, "byte string is too long"); + return -1; + } + return (Py_ssize_t) len; +} +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return __Pyx_PyUnicode_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return PyByteArray_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) { + Py_ssize_t ignore; + return __Pyx_PyObject_AsStringAndSize(o, &ignore); +} +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#if !CYTHON_PEP393_ENABLED +static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + char* defenc_c; + PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL); + if (!defenc) return NULL; + defenc_c = PyBytes_AS_STRING(defenc); +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + { + char* end = defenc_c + PyBytes_GET_SIZE(defenc); + char* c; + for (c = defenc_c; c < end; c++) { + if ((unsigned char) (*c) >= 128) { + PyUnicode_AsASCIIString(o); + return NULL; + } + } + } +#endif + *length = PyBytes_GET_SIZE(defenc); + return defenc_c; +} +#else +static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL; +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + if (likely(PyUnicode_IS_ASCII(o))) { + *length = PyUnicode_GET_LENGTH(o); + return PyUnicode_AsUTF8(o); + } else { + PyUnicode_AsASCIIString(o); + return NULL; + } +#else + return PyUnicode_AsUTF8AndSize(o, length); +#endif +} +#endif +#endif +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) { +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT + if ( +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + __Pyx_sys_getdefaultencoding_not_ascii && +#endif + PyUnicode_Check(o)) { + return __Pyx_PyUnicode_AsStringAndSize(o, length); + } else +#endif +#if (!CYTHON_COMPILING_IN_PYPY && !CYTHON_COMPILING_IN_LIMITED_API) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE)) + if (PyByteArray_Check(o)) { + *length = PyByteArray_GET_SIZE(o); + return PyByteArray_AS_STRING(o); + } else +#endif + { + char* result; + int r = PyBytes_AsStringAndSize(o, &result, length); + if (unlikely(r < 0)) { + return NULL; + } else { + return result; + } + } +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) { + int is_true = x == Py_True; + if (is_true | (x == Py_False) | (x == Py_None)) return is_true; + else return PyObject_IsTrue(x); +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject* x) { + int retval; + if (unlikely(!x)) return -1; + retval = __Pyx_PyObject_IsTrue(x); + Py_DECREF(x); + return retval; +} +static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) { + __Pyx_TypeName result_type_name = __Pyx_PyType_GetName(Py_TYPE(result)); +#if PY_MAJOR_VERSION >= 3 + if (PyLong_Check(result)) { + if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, + "__int__ returned non-int (type " __Pyx_FMT_TYPENAME "). " + "The ability to return an instance of a strict subclass of int is deprecated, " + "and may be removed in a future version of Python.", + result_type_name)) { + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; + } + __Pyx_DECREF_TypeName(result_type_name); + return result; + } +#endif + PyErr_Format(PyExc_TypeError, + "__%.4s__ returned non-%.4s (type " __Pyx_FMT_TYPENAME ")", + type_name, type_name, result_type_name); + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) { +#if CYTHON_USE_TYPE_SLOTS + PyNumberMethods *m; +#endif + const char *name = NULL; + PyObject *res = NULL; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x) || PyLong_Check(x))) +#else + if (likely(PyLong_Check(x))) +#endif + return __Pyx_NewRef(x); +#if CYTHON_USE_TYPE_SLOTS + m = Py_TYPE(x)->tp_as_number; + #if PY_MAJOR_VERSION < 3 + if (m && m->nb_int) { + name = "int"; + res = m->nb_int(x); + } + else if (m && m->nb_long) { + name = "long"; + res = m->nb_long(x); + } + #else + if (likely(m && m->nb_int)) { + name = "int"; + res = m->nb_int(x); + } + #endif +#else + if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) { + res = PyNumber_Int(x); + } +#endif + if (likely(res)) { +#if PY_MAJOR_VERSION < 3 + if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) { +#else + if (unlikely(!PyLong_CheckExact(res))) { +#endif + return __Pyx_PyNumber_IntOrLongWrongResultType(res, name); + } + } + else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "an integer is required"); + } + return res; +} +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) { + Py_ssize_t ival; + PyObject *x; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_CheckExact(b))) { + if (sizeof(Py_ssize_t) >= sizeof(long)) + return PyInt_AS_LONG(b); + else + return PyInt_AsSsize_t(b); + } +#endif + if (likely(PyLong_CheckExact(b))) { + #if CYTHON_USE_PYLONG_INTERNALS + if (likely(__Pyx_PyLong_IsCompact(b))) { + return __Pyx_PyLong_CompactValue(b); + } else { + const digit* digits = __Pyx_PyLong_Digits(b); + const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(b); + switch (size) { + case 2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + } + } + #endif + return PyLong_AsSsize_t(b); + } + x = PyNumber_Index(b); + if (!x) return -1; + ival = PyInt_AsSsize_t(x); + Py_DECREF(x); + return ival; +} +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject* o) { + if (sizeof(Py_hash_t) == sizeof(Py_ssize_t)) { + return (Py_hash_t) __Pyx_PyIndex_AsSsize_t(o); +#if PY_MAJOR_VERSION < 3 + } else if (likely(PyInt_CheckExact(o))) { + return PyInt_AS_LONG(o); +#endif + } else { + Py_ssize_t ival; + PyObject *x; + x = PyNumber_Index(o); + if (!x) return -1; + ival = PyInt_AsLong(x); + Py_DECREF(x); + return ival; + } +} +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) { + return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False); +} +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { + return PyInt_FromSize_t(ival); +} + + +/* #### Code section: utility_code_pragmas_end ### */ +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + + + +/* #### Code section: end ### */ +#endif /* Py_PYTHON_H */ diff --git a/fairseq/data/data_utils_fast.pyx b/fairseq/data/data_utils_fast.pyx new file mode 100644 index 0000000000000000000000000000000000000000..c61f31d6b2113d4c6a03d6553335997098ba0c20 --- /dev/null +++ b/fairseq/data/data_utils_fast.pyx @@ -0,0 +1,178 @@ +# cython: language_level=3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +cimport cython +cimport numpy as np + +from libc.stdint cimport int32_t, int64_t +from libcpp cimport bool as bool_t + +ctypedef int64_t DTYPE_t + +@cython.cdivision(True) +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_vec( + np.ndarray[int64_t, ndim=1] indices, + np.ndarray[int64_t, ndim=1] num_tokens_vec, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + if indices.shape[0] == 0: + return [] + + assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, ( + f"Sentences lengths should not exceed max_tokens={max_tokens}" + ) + + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int32_t, ndim=1] batches_ends = \ + np.zeros(indices_len, dtype=np.int32) + cdef int32_t[:] batches_ends_view = batches_ends + cdef int64_t[:] num_tokens_view = num_tokens_vec + + cdef int32_t pos = 0 + cdef int32_t new_batch_end = 0 + + cdef int64_t new_batch_max_tokens = 0 + cdef int32_t new_batch_sentences = 0 + cdef int64_t new_batch_num_tokens = 0 + + cdef bool_t overflow = False + cdef bool_t size_matches_with_bsz_mult = False + + cdef int32_t batches_count = 0 + cdef int32_t batch_start = 0 + cdef int64_t tail_max_tokens = 0 + cdef int64_t batch_max_tokens = 0 + + for pos in range(indices_len): + # At every pos we keep stats about the last complete batch [batch_start:batch_end), + # and tail [batch_end:pos]. + # 1) Every time when (batch + tail) forms a valid batch + # (according to max_tokens, max_sentences and bsz_mult) we append tail to batch. + # 2) When (batch+tail) violates max_tokens or max_sentences constraints + # we finalize running batch, and tail becomes a new batch. + # 3) There is a corner case when tail also violates constraints. + # In that situation [batch_end:pos-1] (tail without the current pos) + # gets added to the finalized batches, while [pos:pos] becomes a new tail. + # + # Important: For the sake of performance try to avoid using function calls within this loop. + + tail_max_tokens = tail_max_tokens \ + if tail_max_tokens > num_tokens_view[pos] \ + else num_tokens_view[pos] + new_batch_end = pos + 1 + new_batch_max_tokens = batch_max_tokens \ + if batch_max_tokens > tail_max_tokens \ + else tail_max_tokens + new_batch_sentences = new_batch_end - batch_start + new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens + + overflow = (new_batch_sentences > max_sentences > 0 or + new_batch_num_tokens > max_tokens > 0) + size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or + new_batch_sentences % bsz_mult == 0) + + if overflow: + tail_num_tokens = tail_max_tokens * \ + (new_batch_end - batches_ends_view[batches_count]) + tail_overflow = tail_num_tokens > max_tokens > 0 + # In case of a tail overflow finalize two batches + if tail_overflow: + batches_count += 1 + batches_ends_view[batches_count] = pos + tail_max_tokens = num_tokens_view[pos] + batch_start = batches_ends_view[batches_count] + batches_count += 1 + new_batch_max_tokens = tail_max_tokens + + if overflow or size_matches_with_bsz_mult: + batches_ends_view[batches_count] = new_batch_end + batch_max_tokens = new_batch_max_tokens + tail_max_tokens = 0 + if batches_ends_view[batches_count] != indices_len: + batches_count += 1 + # Memory and time-efficient split + return np.split(indices, batches_ends[:batches_count]) + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef list batch_by_size_fn( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + int64_t max_tokens, + int64_t max_sentences, + int32_t bsz_mult, +): + cdef int32_t indices_len = indices.shape[0] + cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len, + dtype=np.int64) + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec + cdef int64_t pos + for pos in range(indices_len): + num_tokens_vec[pos] = num_tokens_fn(indices_view[pos]) + return batch_by_size_vec(indices, num_tokens_vec, max_tokens, + max_sentences, bsz_mult,) + + +cdef _find_valid_shape( + DTYPE_t[:, :] shapes_view, + int64_t num_sentences, + int64_t num_tokens, +): + """Return index of first valid shape of -1 if none is found.""" + for i in range(shapes_view.shape[0]): + if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]: + return i + return -1 + + +@cython.cdivision(True) +cpdef list batch_fixed_shapes_fast( + np.ndarray[DTYPE_t, ndim=1] indices, + num_tokens_fn, + np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted, +): + cdef int64_t sample_len = 0 + cdef list sample_lens = [] + cdef list batch = [] + cdef list batches = [] + cdef int64_t mod_len + cdef int64_t i + cdef int64_t idx + cdef int64_t num_tokens + cdef DTYPE_t[:] indices_view = indices + cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted + + for i in range(len(indices_view)): + idx = indices_view[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len) + if shape_idx == -1: + batches.append(batch) + batch = [] + sample_lens = [] + sample_len = 0 + shapes_view = fixed_shapes_sorted + elif shape_idx > 0: + # small optimization for the next call to _find_valid_shape + shapes_view = shapes_view[shape_idx:] + + batch.append(idx) + + if len(batch) > 0: + batches.append(batch) + + return batches diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a900fc6f960c7faa41173316df011e9bc5cb23c9 --- /dev/null +++ b/fairseq/data/denoising_dataset.py @@ -0,0 +1,443 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch + +from . import FairseqDataset, data_utils + + +def collate( + samples, + pad_idx, + eos_idx, + vocab, + left_pad_source=False, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, +): + assert input_feeding + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx=None, # use eos_idx of each sample instead of vocab.eos() + left_pad=left_pad, + move_eos_to_beginning=move_eos_to_beginning, + pad_to_length=pad_to_length, + ) + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + ntokens = sum(len(s["target"]) for s in samples) + + if input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + "nsentences": samples[0]["source"].size(0), + "sort_order": sort_order, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens + + return batch + + +class DenoisingDataset(FairseqDataset): + """ + A wrapper around TokenBlockDataset for BART dataset. + + Args: + dataset (TokenBlockDataset): dataset to wrap + sizes (List[int]): sentence lengths + vocab (~fairseq.data.Dictionary): vocabulary + mask_idx (int): dictionary index used for masked token + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + shuffle (bool, optional): shuffle the elements before batching. + Default: ``True`` + seed: Seed for random number generator for reproducibility. + """ + + def __init__( + self, + dataset, + sizes, + vocab, + mask_idx, + mask_whole_words, + shuffle, + seed, + mask, + mask_random, + insert, + rotate, + permute_sentences, + bpe, + replace_length, + mask_length, + poisson_lambda, + eos=None, + item_transform_func=None, + ): + self.dataset = dataset + + self.sizes = sizes + + self.vocab = vocab + self.shuffle = shuffle + self.seed = seed + self.mask_idx = mask_idx + self.mask_whole_word = mask_whole_words + self.mask_ratio = mask + self.random_ratio = mask_random + self.insert_ratio = insert + self.rotate_ratio = rotate + self.permute_sentence_ratio = permute_sentences + self.eos = eos if eos is not None else vocab.eos() + self.item_transform_func = item_transform_func + + if bpe != "gpt2": + self.full_stop_index = self.vocab.eos() + else: + assert bpe == "gpt2" + self.full_stop_index = self.vocab.index("13") + + self.replace_length = replace_length + if self.replace_length not in [-1, 0, 1]: + raise ValueError(f"invalid arg: replace_length={self.replace_length}") + if mask_length not in ["subword", "word", "span-poisson"]: + raise ValueError(f"invalid arg: mask-length={mask_length}") + if mask_length == "subword" and replace_length not in [0, 1]: + raise ValueError(f"if using subwords, use replace-length=1 or 0") + + self.mask_span_distribution = None + if mask_length == "span-poisson": + _lambda = poisson_lambda + + lambda_to_the_k = 1 + e_to_the_minus_lambda = math.exp(-_lambda) + k_factorial = 1 + ps = [] + for k in range(0, 128): + ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) + lambda_to_the_k *= _lambda + k_factorial *= k + 1 + if ps[-1] < 0.0000001: + break + ps = torch.FloatTensor(ps) + self.mask_span_distribution = torch.distributions.Categorical(ps) + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + tokens = self.dataset[index] + assert tokens[-1] == self.eos + source, target = tokens, tokens.clone() + + if self.permute_sentence_ratio > 0.0: + source = self.permute_sentences(source, self.permute_sentence_ratio) + + if self.mask_ratio > 0: + source = self.add_whole_word_mask(source, self.mask_ratio) + + if self.insert_ratio > 0: + source = self.add_insertion_noise(source, self.insert_ratio) + + if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio: + source = self.add_rolling_noise(source) + # there can additional changes to make: + if self.item_transform_func is not None: + source, target = self.item_transform_func(source, target) + + assert (source >= 0).all() + assert (source[1:-1] >= 1).all() + assert (source <= len(self.vocab)).all() + assert source[0] == self.vocab.bos() + assert source[-1] == self.eos + return { + "id": index, + "source": source, + "target": target, + } + + def __len__(self): + return len(self.dataset) + + def permute_sentences(self, source, p=1.0): + full_stops = source == self.full_stop_index + # Pretend it ends with a full stop so last span is a sentence + full_stops[-2] = 1 + + # Tokens that are full stops, where the previous token is not + sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2 + result = source.clone() + + num_sentences = sentence_ends.size(0) + num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0) + substitutions = torch.randperm(num_sentences)[:num_to_permute] + ordering = torch.arange(0, num_sentences) + ordering[substitutions] = substitutions[torch.randperm(num_to_permute)] + + # Ignore at start + index = 1 + for i in ordering: + sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]] + result[index : index + sentence.size(0)] = sentence + index += sentence.size(0) + return result + + def word_starts(self, source): + if self.mask_whole_word is not None: + is_word_start = self.mask_whole_word.gather(0, source) + else: + is_word_start = torch.ones(source.size()) + is_word_start[0] = 0 + is_word_start[-1] = 0 + return is_word_start + + def add_whole_word_mask(self, source, p): + is_word_start = self.word_starts(source) + num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) + num_inserts = 0 + if num_to_mask == 0: + return source + + if self.mask_span_distribution is not None: + lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) + + # Make sure we have enough to mask + cum_length = torch.cumsum(lengths, 0) + while cum_length[-1] < num_to_mask: + lengths = torch.cat( + [ + lengths, + self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), + ], + dim=0, + ) + cum_length = torch.cumsum(lengths, 0) + + # Trim to masking budget + i = 0 + while cum_length[i] < num_to_mask: + i += 1 + lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) + num_to_mask = i + 1 + lengths = lengths[:num_to_mask] + + # Handle 0-length mask (inserts) separately + lengths = lengths[lengths > 0] + num_inserts = num_to_mask - lengths.size(0) + num_to_mask -= num_inserts + if num_to_mask == 0: + return self.add_insertion_noise(source, num_inserts / source.size(0)) + + assert (lengths > 0).all() + else: + lengths = torch.ones((num_to_mask,)).long() + assert is_word_start[-1] == 0 + word_starts = is_word_start.nonzero(as_tuple=False) + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) + mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio + + source_length = source.size(0) + assert source_length - 1 not in indices + to_keep = torch.ones(source_length, dtype=torch.bool) + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc + if self.replace_length == 0: + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + + if self.mask_span_distribution is not None: + assert len(lengths.size()) == 1 + assert lengths.size() == indices.size() + lengths -= 1 + while indices.size(0) > 0: + assert lengths.size() == indices.size() + lengths -= is_word_start[indices + 1].long() + uncompleted = lengths >= 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + lengths = lengths[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + else: + # A bit faster when all lengths are 1 + while indices.size(0) > 0: + uncompleted = is_word_start[indices + 1] == 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = torch.randint( + 1, len(self.vocab), size=(mask_random.sum(),) + ) + + assert source_length - 1 not in indices + + source = source[to_keep] + + if num_inserts > 0: + source = self.add_insertion_noise(source, num_inserts / source.size(0)) + + return source + + def add_permuted_noise(self, tokens, p): + num_words = len(tokens) + num_to_permute = math.ceil(((num_words * 2) * p) / 2.0) + substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1 + tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]] + return tokens + + def add_rolling_noise(self, tokens): + offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1) + tokens = torch.cat( + (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]), + dim=0, + ) + return tokens + + def add_insertion_noise(self, tokens, p): + if p == 0.0: + return tokens + + num_tokens = len(tokens) + n = int(math.ceil(num_tokens * p)) + + noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1 + noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool) + noise_mask[noise_indices] = 1 + result = torch.LongTensor(n + len(tokens)).fill_(-1) + + num_random = int(math.ceil(n * self.random_ratio)) + result[noise_indices[num_random:]] = self.mask_idx + result[noise_indices[:num_random]] = torch.randint( + low=1, high=len(self.vocab), size=(num_random,) + ) + + result[~noise_mask] = tokens + + assert (result >= 0).all() + return result + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + Args: + samples (List[dict]): samples to collate + Returns: + dict: a mini-batch of data + """ + return collate( + samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + return indices[np.argsort(self.sizes[indices], kind="mergesort")] + + def prefetch(self, indices): + self.src.prefetch(indices) + self.tgt.prefetch(indices) + + @property + def supports_prefetch(self): + return ( + hasattr(self.src, "supports_prefetch") + and self.src.supports_prefetch + and hasattr(self.tgt, "supports_prefetch") + and self.tgt.supports_prefetch + ) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad590a19b26158bc345a3a66903006a414e2375 --- /dev/null +++ b/fairseq/data/dictionary.py @@ -0,0 +1,403 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from collections import Counter +from multiprocessing import Pool + +import torch +from fairseq import utils +from fairseq.data import data_utils +from fairseq.file_chunker_utils import Chunker, find_offsets +from fairseq.file_io import PathManager +from fairseq.tokenizer import tokenize_line + + +class Dictionary: + """A mapping from symbols to consecutive integers""" + + def __init__( + self, + *, # begin keyword-only arguments + bos="", + pad="", + eos="", + unk="", + extra_special_symbols=None, + add_special_symbols=True, + ): + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.symbols = [] + self.count = [] + self.indices = {} + if add_special_symbols: + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) + + def __eq__(self, other): + return self.indices == other.indices + + def __getitem__(self, idx): + if idx < len(self.symbols): + return self.symbols[idx] + return self.unk_word + + def get_count(self, idx): + return self.count[idx] + + def __len__(self): + """Returns the number of symbols in the dictionary""" + return len(self.symbols) + + def __contains__(self, sym): + return sym in self.indices + + def index(self, sym): + """Returns the index of the specified symbol""" + assert isinstance(sym, str) + if sym in self.indices: + return self.indices[sym] + return self.unk_index + + def string( + self, + tensor, + bpe_symbol=None, + escape_unk=False, + extra_symbols_to_ignore=None, + unk_string=None, + include_eos=False, + separator=" ", + ): + """Helper for converting a tensor of token indices to a string. + + Can optionally remove BPE symbols or escape words. + """ + if torch.is_tensor(tensor) and tensor.dim() == 2: + return "\n".join( + self.string( + t, + bpe_symbol, + escape_unk, + extra_symbols_to_ignore, + include_eos=include_eos, + ) + for t in tensor + ) + + extra_symbols_to_ignore = set(extra_symbols_to_ignore or []) + if not include_eos: + extra_symbols_to_ignore.add(self.eos()) + + def token_string(i): + if i == self.unk(): + if unk_string is not None: + return unk_string + else: + return self.unk_string(escape_unk) + else: + return self[i] + + if hasattr(self, "bos_index"): + extra_symbols_to_ignore.add(self.bos()) + + sent = separator.join( + token_string(i) + for i in tensor + if utils.item(i) not in extra_symbols_to_ignore + ) + + return data_utils.post_process(sent, bpe_symbol) + + def unk_string(self, escape=False): + """Return unknown string, optionally escaped as: <>""" + if escape: + return "<{}>".format(self.unk_word) + else: + return self.unk_word + + def add_symbol(self, word, n=1, overwrite=False): + """Adds a word to the dictionary""" + if word in self.indices and not overwrite: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def update(self, new_dict): + """Updates counts from new dictionary.""" + for word in new_dict.symbols: + idx2 = new_dict.indices[word] + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + new_dict.count[idx2] + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(new_dict.count[idx2]) + + def finalize(self, threshold=-1, nwords=-1, padding_factor=8): + """Sort symbols by frequency in descending order, ignoring special ones. + + Args: + - threshold defines the minimum word count + - nwords defines the total number of words in the final dictionary, + including special symbols + - padding_factor can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + if nwords <= 0: + nwords = len(self) + + new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial))) + new_symbols = self.symbols[: self.nspecial] + new_count = self.count[: self.nspecial] + + c = Counter( + dict( + sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :])) + ) + ) + for symbol, count in c.most_common(nwords - self.nspecial): + if count >= threshold: + new_indices[symbol] = len(new_symbols) + new_symbols.append(symbol) + new_count.append(count) + else: + break + + assert len(new_symbols) == len(new_indices) + + self.count = list(new_count) + self.symbols = list(new_symbols) + self.indices = new_indices + + self.pad_to_multiple_(padding_factor) + + def pad_to_multiple_(self, padding_factor): + """Pad Dictionary size to be a multiple of *padding_factor*.""" + if padding_factor > 1: + i = 0 + while len(self) % padding_factor != 0: + symbol = "madeupword{:04d}".format(i) + self.add_symbol(symbol, n=0) + i += 1 + + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.bos_index + + def pad(self): + """Helper to get index of pad symbol""" + return self.pad_index + + def eos(self): + """Helper to get index of end-of-sentence symbol""" + return self.eos_index + + def unk(self): + """Helper to get index of unk symbol""" + return self.unk_index + + @classmethod + def load(cls, f, add_special_symbols=True): + """Loads the dictionary from a text file with the format: + + ``` + + + ... + ``` + """ + d = cls(add_special_symbols=add_special_symbols) + d.add_from_file(f) + return d + + def add_from_file(self, f): + """ + Loads a pre-existing dictionary from a text file and adds its symbols + to this instance. + """ + if isinstance(f, str): + try: + with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd: + self.add_from_file(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) + return + + lines = f.readlines() + indices_start_line = self._load_meta(lines) + + for line in lines[indices_start_line:]: + try: + line, field = line.rstrip().rsplit(" ", 1) + if field == "#fairseq:overwrite": + overwrite = True + line, field = line.rsplit(" ", 1) + else: + overwrite = False + count = int(field) + word = line + if word in self and not overwrite: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. If using the Camembert model, please " + "download an updated copy of the model file.".format(word) + ) + self.add_symbol(word, n=count, overwrite=overwrite) + except ValueError: + raise ValueError( + f"Incorrect dictionary format, expected ' [flags]': \"{line}\"" + ) + + def _save(self, f, kv_iterator): + if isinstance(f, str): + PathManager.mkdirs(os.path.dirname(f)) + with PathManager.open(f, "w", encoding="utf-8") as fd: + return self.save(fd) + for k, v in kv_iterator: + print("{} {}".format(k, v), file=f) + + def _get_meta(self): + return [], [] + + def _load_meta(self, lines): + return 0 + + def save(self, f): + """Stores dictionary into a text file""" + ex_keys, ex_vals = self._get_meta() + self._save( + f, + zip( + ex_keys + self.symbols[self.nspecial :], + ex_vals + self.count[self.nspecial :], + ), + ) + + def dummy_sentence(self, length): + t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long() + t[-1] = self.eos() + return t + + def encode_line( + self, + line, + line_tokenizer=tokenize_line, + add_if_not_exist=True, + consumer=None, + append_eos=True, + reverse_order=False, + ) -> torch.IntTensor: + words = line_tokenizer(line) + if reverse_order: + words = list(reversed(words)) + nwords = len(words) + ids = torch.IntTensor(nwords + 1 if append_eos else nwords) + + for i, word in enumerate(words): + if add_if_not_exist: + idx = self.add_symbol(word) + else: + idx = self.index(word) + if consumer is not None: + consumer(word, idx) + ids[i] = idx + if append_eos: + ids[nwords] = self.eos_index + return ids + + @staticmethod + def _add_file_to_dictionary_single_worker( + filename, + tokenize, + eos_word, + start_offset, + end_offset, + ): + counter = Counter() + with Chunker(filename, start_offset, end_offset) as line_iterator: + for line in line_iterator: + for word in tokenize(line): + counter.update([word]) + counter.update([eos_word]) + return counter + + @staticmethod + def add_file_to_dictionary(filename, dict, tokenize, num_workers): + def merge_result(counter): + for w, c in sorted(counter.items()): + dict.add_symbol(w, c) + + local_file = PathManager.get_local_path(filename) + offsets = find_offsets(local_file, num_workers) + if num_workers > 1: + chunks = zip(offsets, offsets[1:]) + pool = Pool(processes=num_workers) + results = [] + for (start_offset, end_offset) in chunks: + results.append( + pool.apply_async( + Dictionary._add_file_to_dictionary_single_worker, + ( + local_file, + tokenize, + dict.eos_word, + start_offset, + end_offset, + ), + ) + ) + pool.close() + pool.join() + for r in results: + merge_result(r.get()) + else: + merge_result( + Dictionary._add_file_to_dictionary_single_worker( + local_file, tokenize, dict.eos_word, offsets[0], offsets[1] + ) + ) + + +class TruncatedDictionary(object): + def __init__(self, wrapped_dict, length): + self.__class__ = type( + wrapped_dict.__class__.__name__, + (self.__class__, wrapped_dict.__class__), + {}, + ) + self.__dict__ = wrapped_dict.__dict__ + self.wrapped_dict = wrapped_dict + self.length = min(len(self.wrapped_dict), length) + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i < self.length: + return self.wrapped_dict[i] + return self.wrapped_dict.unk() diff --git a/fairseq/data/encoders/__init__.py b/fairseq/data/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7cbe00a10520331709441e5e77991bd2edca8c06 --- /dev/null +++ b/fairseq/data/encoders/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os + +from fairseq import registry + + +build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry( + "--tokenizer", + default=None, +) + + +build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry( + "--bpe", + default=None, +) + + +# automatically import any Python files in the encoders/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.data.encoders." + module) diff --git a/fairseq/data/encoders/__pycache__/__init__.cpython-310.pyc b/fairseq/data/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f973889889b545735f7cdff39b6108eccc729840 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/byte_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/byte_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..640d97db529d5331845bb18b8a6762b8966a4e68 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/byte_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/byte_utils.cpython-310.pyc b/fairseq/data/encoders/__pycache__/byte_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d1b82fb5517ddce364d34fedd3c0f6a0ed42656 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/byte_utils.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/bytes.cpython-310.pyc b/fairseq/data/encoders/__pycache__/bytes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5d8b15b21050f2b69f3052e6a1474c0cc021c89 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/bytes.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/characters.cpython-310.pyc b/fairseq/data/encoders/__pycache__/characters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58aa40e1f6d52a2b260667a1440c625564e79ea5 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/characters.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/fastbpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/fastbpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd11bdbd6b891018ca828844ffd1d0e98cee0a7 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/fastbpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a17d57e10f9b90944b4e65a1ad196d9b5f87384 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc b/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c62369276f76b631e5d2fb805f1daca5c9f105af Binary files /dev/null and b/fairseq/data/encoders/__pycache__/gpt2_bpe_utils.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..298426f3d6640c3a70483369b20a3a388c12ba88 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29240d68e4be29985032505c8b8ebfc7d0984061 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/hf_byte_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-310.pyc b/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..212ec98f91f9908d1aa7b4166ca0c16344751861 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc b/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d7a7063084bf74b30583240b93de0eb54a0388c Binary files /dev/null and b/fairseq/data/encoders/__pycache__/nltk_tokenizer.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e45b4f3ecae11b720d0bc6bc124394c706cf7e Binary files /dev/null and b/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc b/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12a624cce29eb4724d56c0208270514a6a5d4193 Binary files /dev/null and b/fairseq/data/encoders/__pycache__/space_tokenizer.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc b/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15793b2094b6e3f21d9014f21649c459bb37500f Binary files /dev/null and b/fairseq/data/encoders/__pycache__/subword_nmt_bpe.cpython-310.pyc differ diff --git a/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc b/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51320a598fd9213ebffdabfb2bc6f19b02a3b89e Binary files /dev/null and b/fairseq/data/encoders/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..31e3a0627827f19ca7f0b58da45e46d40a80c3bf --- /dev/null +++ b/fairseq/data/encoders/byte_bpe.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class ByteBpeConfig(FairseqDataclass): + sentencepiece_model_path: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("byte_bpe", dataclass=ByteBpeConfig) +class ByteBPE(object): + def __init__(self, cfg): + vocab = file_utils.cached_path(cfg.sentencepiece_model_path) + try: + import sentencepiece as spm + + self.sp = spm.SentencePieceProcessor() + self.sp.Load(vocab) + except ImportError: + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) + + def encode(self, x: str) -> str: + byte_encoded = byte_encode(x) + return SPACE.join(self.sp.EncodeAsPieces(byte_encoded)) + + @staticmethod + def decode(x: str) -> str: + unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) + return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/byte_utils.py b/fairseq/data/encoders/byte_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a305c080926c2d094b7e8ae48f5331da82025a75 --- /dev/null +++ b/fairseq/data/encoders/byte_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +WHITESPACE_NORMALIZER = re.compile(r"\s+") +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) +# excluding non-breaking space (160) here +PRINTABLE_LATIN = set( + list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1)) +) +BYTE_TO_BCHAR = { + b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256) +} +BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()} + + +def byte_encode(x: str) -> str: + normalized = WHITESPACE_NORMALIZER.sub(SPACE, x) + return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")]) + + +def byte_decode(x: str) -> str: + try: + return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8") + except ValueError: + return "" + + +def smart_byte_decode(x: str) -> str: + output = byte_decode(x) + if output == "": + # DP the best recovery (max valid chars) if it's broken + n_bytes = len(x) + f = [0 for _ in range(n_bytes + 1)] + pt = [0 for _ in range(n_bytes + 1)] + for i in range(1, n_bytes + 1): + f[i], pt[i] = f[i - 1], i - 1 + for j in range(1, min(4, i) + 1): + if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0: + f[i], pt[i] = f[i - j] + 1, i - j + cur_pt = n_bytes + while cur_pt > 0: + if f[cur_pt] == f[pt[cur_pt]] + 1: + output = byte_decode(x[pt[cur_pt] : cur_pt]) + output + cur_pt = pt[cur_pt] + return output diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py new file mode 100644 index 0000000000000000000000000000000000000000..f88f8f6929f5b6bdb0db470be9ebedf8fe1f752d --- /dev/null +++ b/fairseq/data/encoders/bytes.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.data.encoders import register_bpe +from fairseq.data.encoders.byte_utils import ( + SPACE, + SPACE_ESCAPE, + byte_encode, + smart_byte_decode, +) + + +@register_bpe("bytes") +class Bytes(object): + def __init__(self, *unused): + pass + + @staticmethod + def add_args(parser): + pass + + @staticmethod + def encode(x: str) -> str: + encoded = byte_encode(x) + escaped = encoded.replace(SPACE, SPACE_ESCAPE) + return SPACE.join(list(escaped)) + + @staticmethod + def decode(x: str) -> str: + unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) + return smart_byte_decode(unescaped) diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py new file mode 100644 index 0000000000000000000000000000000000000000..494ea219392716dc75d2c1e19d71cd55b9b2f4ba --- /dev/null +++ b/fairseq/data/encoders/characters.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.data.encoders import register_bpe + + +SPACE = chr(32) +SPACE_ESCAPE = chr(9601) + + +@register_bpe("characters") +class Characters(object): + def __init__(self, *unused): + pass + + @staticmethod + def add_args(parser): + pass + + @staticmethod + def encode(x: str) -> str: + escaped = x.replace(SPACE, SPACE_ESCAPE) + return SPACE.join(list(escaped)) + + @staticmethod + def decode(x: str) -> str: + return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE) diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c21039549ea002e73d1ad7cde5735f215f11ee --- /dev/null +++ b/fairseq/data/encoders/fastbpe.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class fastBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) + + +@register_bpe("fastbpe", dataclass=fastBPEConfig) +class fastBPE(object): + def __init__(self, cfg): + if cfg.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=fastbpe") + codes = file_utils.cached_path(cfg.bpe_codes) + try: + import fastBPE + + self.bpe = fastBPE.fastBPE(codes) + self.bpe_symbol = "@@ " + except ImportError: + raise ImportError("Please install fastBPE with: pip install fastBPE") + + def encode(self, x: str) -> str: + return self.bpe.apply([x])[0] + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..e661426a73c7e735f7054bcb04281bf1649bb46c --- /dev/null +++ b/fairseq/data/encoders/gpt2_bpe.py @@ -0,0 +1,45 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + +from .gpt2_bpe_utils import get_encoder + + +DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json" +DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" + + +@dataclass +class GPT2BPEConfig(FairseqDataclass): + gpt2_encoder_json: str = field( + default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} + ) + gpt2_vocab_bpe: str = field( + default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} + ) + + +@register_bpe("gpt2", dataclass=GPT2BPEConfig) +class GPT2BPE(object): + def __init__(self, cfg): + encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) + vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) + self.bpe = get_encoder(encoder_json, vocab_bpe) + + def encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def decode(self, x: str) -> str: + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) + + def is_beginning_of_word(self, x: str) -> bool: + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/gpt2_bpe_utils.py b/fairseq/data/encoders/gpt2_bpe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..996d3d4a1175fb3841242c286f08faf43fc06e3e --- /dev/null +++ b/fairseq/data/encoders/gpt2_bpe_utils.py @@ -0,0 +1,140 @@ +""" +Byte pair encoding utilities from GPT-2. + +Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py +Original license: MIT +""" + +import json +from functools import lru_cache + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + try: + import regex as re + + self.re = re + except ImportError: + raise ImportError("Please install regex with: pip install regex") + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = self.re.compile( + r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + ) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + for token in self.re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder.get(token, token) for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode( + "utf-8", errors=self.errors + ) + return text + + +def get_encoder(encoder_json_path, vocab_bpe_path): + with open(encoder_json_path, "r") as f: + encoder = json.load(f) + with open(vocab_bpe_path, "r", encoding="utf-8") as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..a41c059343ec7e2914b2c9d2f53f526c33f9659d --- /dev/null +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional + +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class BertBPEConfig(FairseqDataclass): + bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) + bpe_vocab_file: Optional[str] = field( + default=None, metadata={"help": "bpe vocab file"} + ) + + +@register_bpe("bert", dataclass=BertBPEConfig) +class BertBPE(object): + def __init__(self, cfg): + try: + from transformers import BertTokenizer + except ImportError: + raise ImportError( + "Please install transformers with: pip install transformers" + ) + + if cfg.bpe_vocab_file: + self.bert_tokenizer = BertTokenizer( + cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased + ) + else: + vocab_file_name = ( + "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" + ) + self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) + + def encode(self, x: str) -> str: + return " ".join(self.bert_tokenizer.tokenize(x)) + + def decode(self, x: str) -> str: + return self.bert_tokenizer.clean_up_tokenization( + self.bert_tokenizer.convert_tokens_to_string(x.split(" ")) + ) + + def is_beginning_of_word(self, x: str) -> bool: + return not x.startswith("##") diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..c508578d41bf6b7ce0a847e0797d71b19beb393d --- /dev/null +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass +from fairseq import file_utils + + +@dataclass +class HuggingFaceByteLevelBPEConfig(FairseqDataclass): + bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"}) + bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"}) + bpe_add_prefix_space: bool = field( + default=False, metadata={"help": "add prefix space before encoding"} + ) + + +@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig) +class HuggingFaceByteLevelBPE(object): + def __init__(self, cfg): + try: + from tokenizers import ByteLevelBPETokenizer + except ImportError: + raise ImportError( + "Please install huggingface/tokenizers with: " "pip install tokenizers" + ) + + bpe_vocab = file_utils.cached_path(cfg.bpe_vocab) + bpe_merges = file_utils.cached_path(cfg.bpe_merges) + + self.bpe = ByteLevelBPETokenizer( + bpe_vocab, + bpe_merges, + add_prefix_space=cfg.bpe_add_prefix_space, + ) + + def encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x).ids)) + + def decode(self, x: str) -> str: + return self.bpe.decode( + [int(tok) if tok not in {"", ""} else tok for tok in x.split()] + ) + + def is_beginning_of_word(self, x: str) -> bool: + return self.decode(x).startswith(" ") diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e236dad167a037a8ed95f7fc8292b27b10d580b0 --- /dev/null +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + target_lang: str = field(default="en", metadata={"help": "target language"}) + moses_no_dash_splits: bool = field( + default=False, metadata={"help": "don't apply dash split rules"} + ) + moses_no_escape: bool = field( + default=False, + metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."}, + ) + + +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) +class MosesTokenizer(object): + def __init__(self, cfg: MosesTokenizerConfig): + self.cfg = cfg + + try: + from sacremoses import MosesTokenizer, MosesDetokenizer + + self.tok = MosesTokenizer(cfg.source_lang) + self.detok = MosesDetokenizer(cfg.target_lang) + except ImportError: + raise ImportError( + "Please install Moses tokenizer with: pip install sacremoses" + ) + + def encode(self, x: str) -> str: + return self.tok.tokenize( + x, + aggressive_dash_splits=(not self.cfg.moses_no_dash_splits), + return_str=True, + escape=(not self.cfg.moses_no_escape), + ) + + def decode(self, x: str) -> str: + return self.detok.detokenize(x.split()) diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ab92377b3a23bb48384c3f7acf299612e8b0775 --- /dev/null +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass + + +@register_tokenizer("nltk", dataclass=FairseqDataclass) +class NLTKTokenizer(object): + def __init__(self, *unused): + try: + from nltk.tokenize import word_tokenize + + self.word_tokenize = word_tokenize + except ImportError: + raise ImportError("Please install nltk with: pip install nltk") + + def encode(self, x: str) -> str: + return " ".join(self.word_tokenize(x)) + + def decode(self, x: str) -> str: + return x diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa6cd7681d0c3a91a6917640972d008db8faef7 --- /dev/null +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SentencepieceConfig(FairseqDataclass): + sentencepiece_model: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + sentencepiece_enable_sampling: bool = field( + default=False, metadata={"help": "enable sampling"} + ) + sentencepiece_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "soothing parameter for unigram sampling, " + "and merge probability for BPE-dropout" + }, + ) + + +@register_bpe("sentencepiece", dataclass=SentencepieceConfig) +class SentencepieceBPE(object): + def __init__(self, cfg): + self.enable_sampling = cfg.sentencepiece_enable_sampling + self.alpha = cfg.sentencepiece_alpha + sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model) + try: + import sentencepiece as spm + + self.sp = spm.SentencePieceProcessor() + self.sp.Load(sentencepiece_model) + except ImportError: + raise ImportError( + "Please install sentencepiece with: pip install sentencepiece" + ) + + def encode(self, x: str) -> str: + return " ".join( + self.sp.Encode( + x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha + ) + ) + + def decode(self, x: str) -> str: + return x.replace(" ", "").replace("\u2581", " ").strip() + + def is_beginning_of_word(self, x: str) -> bool: + if x in ["", "", "", ""]: + # special elements are always considered beginnings + # HACK: this logic is already present in fairseq/tasks/masked_lm.py + # but these special tokens are also contained in the sentencepiece + # vocabulary which causes duplicate special tokens. This hack makes + # sure that they are all taken into account. + return True + return x.startswith("\u2581") diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..925ad41b7c1aee6738c63938c36bd3ee16dca812 --- /dev/null +++ b/fairseq/data/encoders/space_tokenizer.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + +from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass + + +@register_tokenizer("space", dataclass=FairseqDataclass) +class SpaceTokenizer(object): + def __init__(self, *unused): + self.space_tok = re.compile(r"\s+") + + def encode(self, x: str) -> str: + return self.space_tok.sub(" ", x) + + def decode(self, x: str) -> str: + return x diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py new file mode 100644 index 0000000000000000000000000000000000000000..5d724d2730a5895ca55af2998c2ced471625b516 --- /dev/null +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq import file_utils +from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class SubwordNMTBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"}) + bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"}) + + +@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig) +class SubwordNMTBPE(object): + def __init__(self, cfg): + if cfg.bpe_codes is None: + raise ValueError("--bpe-codes is required for --bpe=subword_nmt") + codes = file_utils.cached_path(cfg.bpe_codes) + try: + from subword_nmt import apply_bpe + + bpe_parser = apply_bpe.create_parser() + bpe_args = bpe_parser.parse_args( + [ + "--codes", + codes, + "--separator", + cfg.bpe_separator, + ] + ) + self.bpe = apply_bpe.BPE( + bpe_args.codes, + bpe_args.merges, + bpe_args.separator, + None, + bpe_args.glossaries, + ) + self.bpe_symbol = bpe_args.separator + " " + except ImportError: + raise ImportError( + "Please install subword_nmt with: pip install subword-nmt" + ) + + def encode(self, x: str) -> str: + return self.bpe.process_line(x) + + def decode(self, x: str) -> str: + return (x + " ").replace(self.bpe_symbol, "").rstrip() diff --git a/fairseq/data/encoders/utils.py b/fairseq/data/encoders/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d93eb532ef84f0e2bc708b777229ab2cb76ca14b --- /dev/null +++ b/fairseq/data/encoders/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.data import encoders + + +def get_whole_word_mask(args, dictionary): + bpe = encoders.build_bpe(args) + if bpe is not None: + + def is_beginning_of_word(i): + if i < dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = dictionary[i] + if tok.startswith("madeupword"): + return True + try: + return bpe.is_beginning_of_word(tok) + except ValueError: + return True + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(dictionary)))) + ) + return mask_whole_words + return None diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2bde7fc57b99df2e14e2186a5f9cd98982870ddd --- /dev/null +++ b/fairseq/data/fairseq_dataset.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import numpy as np +import torch.utils.data +from fairseq.data import data_utils + +logger = logging.getLogger(__name__) + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class FairseqDataset(torch.utils.data.Dataset, EpochListening): + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + raise NotImplementedError + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def get_batch_shapes(self): + """ + Return a list of valid batch shapes, for example:: + + [(8, 512), (16, 256), (32, 128)] + + The first dimension of each tuple is the batch size and can be ``None`` + to automatically infer the max batch size based on ``--max-tokens``. + The second dimension of each tuple is the max supported length as given + by :func:`fairseq.data.FairseqDataset.num_tokens`. + + This will be used by :func:`fairseq.data.FairseqDataset.batch_by_size` + to restrict batch shapes. This is useful on TPUs to avoid too many + dynamic shapes (and recompilations). + """ + return None + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices, return batches according to + *max_tokens*, *max_sentences* and *required_batch_size_multiple*. + """ + from fairseq.data import data_utils + + fixed_shapes = self.get_batch_shapes() + if fixed_shapes is not None: + + def adjust_bsz(bsz, num_tokens): + if bsz is None: + assert max_tokens is not None, "Must specify --max-tokens" + bsz = max_tokens // num_tokens + if max_sentences is not None: + bsz = min(bsz, max_sentences) + elif ( + bsz >= required_batch_size_multiple + and bsz % required_batch_size_multiple != 0 + ): + bsz -= bsz % required_batch_size_multiple + return bsz + + fixed_shapes = np.array( + [ + [adjust_bsz(bsz, num_tokens), num_tokens] + for (bsz, num_tokens) in fixed_shapes + ] + ) + + try: + num_tokens_vec = self.num_tokens_vec(indices).astype("int64") + except NotImplementedError: + num_tokens_vec = None + + return data_utils.batch_by_size( + indices, + num_tokens_fn=self.num_tokens, + num_tokens_vec=num_tokens_vec, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + fixed_shapes=fixed_shapes, + ) + + def filter_indices_by_size(self, indices, max_sizes): + """ + Filter a list of sample indices. Remove those that are longer than + specified in *max_sizes*. + + WARNING: don't update, override method in child classes + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if isinstance(max_sizes, float) or isinstance(max_sizes, int): + if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray): + ignored = indices[self.sizes[indices] > max_sizes].tolist() + indices = indices[self.sizes[indices] <= max_sizes] + elif ( + hasattr(self, "sizes") + and isinstance(self.sizes, list) + and len(self.sizes) == 1 + ): + ignored = indices[self.sizes[0][indices] > max_sizes].tolist() + indices = indices[self.sizes[0][indices] <= max_sizes] + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + else: + indices, ignored = data_utils._filter_by_size_dynamic( + indices, self.size, max_sizes + ) + return indices, ignored + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True + + +class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. + """ + + def __iter__(self): + raise NotImplementedError diff --git a/fairseq/data/fasta_dataset.py b/fairseq/data/fasta_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..007011974a997fd7446dd29d7eba097d7513bab0 --- /dev/null +++ b/fairseq/data/fasta_dataset.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import subprocess +import threading +from pathlib import Path + +import numpy as np +import torch + + +def fasta_file_path(prefix_path): + return prefix_path + ".fasta" + + +class FastaDataset(torch.utils.data.Dataset): + """ + For loading protein sequence datasets in the common FASTA data format + """ + + def __init__(self, path: str, cache_indices=False): + self.fn = fasta_file_path(path) + self.threadlocal = threading.local() + self.cache = Path(f"{path}.fasta.idx.npy") + if cache_indices: + if self.cache.exists(): + self.offsets, self.sizes = np.load(self.cache) + else: + self.offsets, self.sizes = self._build_index(path) + np.save(self.cache, np.stack([self.offsets, self.sizes])) + else: + self.offsets, self.sizes = self._build_index(path) + + def _get_file(self): + if not hasattr(self.threadlocal, "f"): + self.threadlocal.f = open(self.fn, "r") + return self.threadlocal.f + + def __getitem__(self, idx): + f = self._get_file() + f.seek(self.offsets[idx]) + desc = f.readline().strip() + line = f.readline() + seq = "" + while line != "" and line[0] != ">": + seq += line.strip() + line = f.readline() + return desc, seq + + def __len__(self): + return self.offsets.size + + def _build_index(self, path: str): + # Use grep and awk to get 100M/s on local SSD. + # Should process your enormous 100G fasta in ~10 min single core... + path = fasta_file_path(path) + bytes_offsets = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| grep --byte-offset '^>' -o | cut -d: -f1", + shell=True, + ) + fasta_lengths = subprocess.check_output( + f"cat {path} | tqdm --bytes --total $(wc -c < {path})" + "| awk '/^>/ {print \"\";next;} { printf(\"%s\",$0);}' | tail -n+2 | awk '{print length($1)}'", + shell=True, + ) + bytes_np = np.fromstring(bytes_offsets, dtype=np.int64, sep=" ") + sizes_np = np.fromstring(fasta_lengths, dtype=np.int64, sep=" ") + return bytes_np, sizes_np + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "f"): + self.threadlocal.f.close() + del self.threadlocal.f + + @staticmethod + def exists(path): + return os.path.exists(fasta_file_path(path)) + + +class EncodedFastaDataset(FastaDataset): + """ + The FastaDataset returns raw sequences - this allows us to return + indices with a dictionary instead. + """ + + def __init__(self, path, dictionary): + super().__init__(path, cache_indices=True) + self.dictionary = dictionary + + def __getitem__(self, idx): + desc, seq = super().__getitem__(idx) + return self.dictionary.encode_line(seq, line_tokenizer=list).long() diff --git a/fairseq/data/huffman/__init__.py b/fairseq/data/huffman/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b61fafadba28f65fe78a28b2099368b83cfcf41 --- /dev/null +++ b/fairseq/data/huffman/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .huffman_coder import HuffmanCodeBuilder, HuffmanCoder +from .huffman_mmap_indexed_dataset import ( + HuffmanMMapIndex, + HuffmanMMapIndexedDataset, + HuffmanMMapIndexedDatasetBuilder, + vocab_file_path, +) + +__all__ = [ + "HuffmanCoder", + "HuffmanCodeBuilder", + "HuffmanMMapIndexedDatasetBuilder", + "HuffmanMMapIndexedDataset", + "HuffmanMMapIndex", + "vocab_file_path", +] diff --git a/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc b/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23588aacfe3afc56abc2f9d0b929dec6c298229f Binary files /dev/null and b/fairseq/data/huffman/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc b/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1771e6fe411839054849ae0786264101ad5bc68 Binary files /dev/null and b/fairseq/data/huffman/__pycache__/huffman_coder.cpython-310.pyc differ diff --git a/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc b/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e27e2388d0c4cf7f667ad1938276ad29c5347899 Binary files /dev/null and b/fairseq/data/huffman/__pycache__/huffman_mmap_indexed_dataset.cpython-310.pyc differ diff --git a/fairseq/data/huffman/huffman_coder.py b/fairseq/data/huffman/huffman_coder.py new file mode 100644 index 0000000000000000000000000000000000000000..c04f84564e6a22209439c67fed3cac31f010c6e9 --- /dev/null +++ b/fairseq/data/huffman/huffman_coder.py @@ -0,0 +1,267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re +import typing as tp +from collections import Counter, deque +from dataclasses import dataclass + +from bitarray import bitarray, util +from fairseq.data import Dictionary + +# basically we have to write to addressable bytes for the memory mapped +# dataset loader. Sentences that get encoded to a length that is not a +# multiple of BLOCKSIZE (a byte) will be padded to fit. (see _pad in the coder) +BLOCKSIZE = 8 + + +class HuffmanCoder: + def __init__( + self, root: "HuffmanNode", bos="", pad="", eos="", unk="" + ): + self.root = root + self.table = root.code_table() + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + + def _pad(self, a: bitarray) -> bitarray: + """ + bitpadding, 1 then 0. + + If the array is already a multiple of blocksize, we add a full block. + """ + pad_len = BLOCKSIZE - (len(a) % BLOCKSIZE) - 1 + padding = bitarray("1" + "0" * pad_len) + return a + padding + + def _unpad(self, a: bitarray) -> bitarray: + """ + remove the bitpadding. + + There will be a set of 0s preceded by a 1 at the end of the bitarray, we remove that + """ + # count the 0 padding at the end until we find the first 1 + # we want to remove the one too + remove_cnt = util.rindex(a, 1) + return a[:remove_cnt] + + def encode(self, iter: tp.List[str]) -> bytes: + """ + encode a list of tokens a return bytes. We use bitpadding to make sure the encoded bits fit in bytes. + """ + a = bitarray() + for token in iter: + code = self.get_code(token) + if code is None: + if self.unk_word is None: + raise Exception(f"unknown token {token} cannot be encoded.") + else: + token = self.unk_word + a = a + self.get_code(token) + return self._pad(a).tobytes() + + def decode(self, bits: bytes) -> tp.Iterator["HuffmanNode"]: + """ + take bitpadded bytes and decode it to a set of leaves. You can then use each node to find the symbol/id + """ + a = bitarray() + a.frombytes(bits) + return self.root.decode(self._unpad(a)) + + def get_code(self, symbol: str) -> tp.Optional[bitarray]: + node = self.get_node(symbol) + return None if node is None else node.code + + def get_node(self, symbol: str) -> "HuffmanNode": + return self.table.get(symbol) + + @classmethod + def from_file( + cls, + filename: str, + bos="", + pad="", + eos="", + unk="", + ) -> "HuffmanCoder": + builder = HuffmanCodeBuilder.from_file(filename) + return builder.build_code(bos=bos, pad=pad, eos=eos, unk=unk) + + def to_file(self, filename, sep="\t"): + nodes = list(self.table.values()) + nodes.sort(key=lambda n: n.id) + with open(filename, "w", encoding="utf-8") as output: + for n in nodes: + output.write(f"{n.symbol}{sep}{n.count}\n") + + def __iter__(self): + for n in self.table.values(): + yield n + + def merge(self, other_coder: "HuffmanCoder") -> "HuffmanCoder": + builder = HuffmanCodeBuilder() + for n in self: + builder.increment(n.symbol, n.count) + for n in other_coder: + builder.increment(n.symbol, n.count) + return builder.build_code() + + def __eq__(self, other: "HuffmanCoder") -> bool: + return self.table == other.table + + def __len__(self) -> int: + return len(self.table) + + def __contains__(self, sym: str) -> bool: + return sym in self.table + + def to_dictionary(self) -> Dictionary: + dictionary = Dictionary(bos=self.bos, unk=self.unk, pad=self.pad, eos=self.eos) + for n in self: + dictionary.add_symbol(n.symbol, n=n.count) + dictionary.finalize() + return dictionary + + +@dataclass +class HuffmanNode: + """ + a node in a Huffman tree + """ + + id: int + count: int + symbol: tp.Optional[str] = None + left: tp.Optional["HuffmanNode"] = None + right: tp.Optional["HuffmanNode"] = None + code: tp.Optional[bitarray] = None + + def is_leaf(self) -> bool: + return self.left is None and self.right is None + + def code_table( + self, prefix: tp.Optional[bitarray] = None + ) -> tp.Dict[str, "HuffmanNode"]: + defaulted_prefix = prefix if prefix is not None else bitarray() + if self.is_leaf(): + self.code = ( + defaulted_prefix if len(defaulted_prefix) > 0 else bitarray("0") + ) # leaf could be the root if there is only one symbol + return {self.symbol: self} + + codes_right = self.right.code_table(defaulted_prefix + bitarray([0])) + codes_left = self.left.code_table(defaulted_prefix + bitarray([1])) + return {**codes_left, **codes_right} + + def decode(self, bits: bitarray) -> tp.Iterator["HuffmanNode"]: + current_node = self + for bit in bits: + if bit == 0: # go right + current_node = current_node.right + else: # go left + current_node = current_node.left + if current_node is None: + # we shouldn't be on a leaf here + raise Exception("fell off a leaf") + if current_node.is_leaf(): + yield current_node + current_node = self + if current_node != self: + raise Exception("couldn't decode all the bits") + + +class HuffmanCodeBuilder: + """ + build a dictionary with occurence count and then build the Huffman code for it. + """ + + def __init__(self): + self.symbols = Counter() + + def add_symbols(self, *syms) -> None: + self.symbols.update(syms) + + def increment(self, symbol: str, cnt: int) -> None: + self.symbols[symbol] += cnt + + @classmethod + def from_file(cls, filename): + c = cls() + with open(filename, "r", encoding="utf-8") as input: + for line in input: + split = re.split(r"[\s]+", line) + c.increment(split[0], int(split[1])) + return c + + def to_file(self, filename, sep="\t"): + with open(filename, "w", encoding="utf-8") as output: + for (tok, cnt) in self.symbols.most_common(): + output.write(f"{tok}{sep}{cnt}\n") + + def _smallest(self, q1: deque, q2: deque) -> HuffmanNode: + if len(q1) == 0: + return q2.pop() + + if len(q2) == 0: + return q1.pop() + + if q1[-1].count < q2[-1].count: + return q1.pop() + + return q2.pop() + + def __add__(self, c: "HuffmanCodeBuilder") -> "HuffmanCodeBuilder": + new_c = self.symbols + c.symbols + new_b = HuffmanCodeBuilder() + new_b.symbols = new_c + return new_b + + def build_code( + self, + bos="", + pad="", + eos="", + unk="", + ) -> HuffmanCoder: + assert len(self.symbols) > 0, "cannot build code from empty list of symbols" + + if self.symbols[bos] == 0: + self.add_symbols(bos) + if self.symbols[pad] == 0: + self.add_symbols(pad) + if self.symbols[eos] == 0: + self.add_symbols(eos) + if self.symbols[unk] == 0: + self.add_symbols(unk) + + node_id = 0 + leaves_queue = deque( + [ + HuffmanNode(symbol=symbol, count=count, id=idx) + for idx, (symbol, count) in enumerate(self.symbols.most_common()) + ] + ) # left are the most common, right are the least common + + if len(leaves_queue) == 1: + root = leaves_queue.pop() + root.id = 0 + return HuffmanCoder(root) + + nodes_queue = deque() + + while len(leaves_queue) > 0 or len(nodes_queue) != 1: + # get the lowest two nodes at the head of each queue + node1 = self._smallest(leaves_queue, nodes_queue) + node2 = self._smallest(leaves_queue, nodes_queue) + + # add new node + nodes_queue.appendleft( + HuffmanNode( + count=node1.count + node2.count, left=node1, right=node2, id=node_id + ) + ) + node_id += 1 + + # we are left with the root + return HuffmanCoder(nodes_queue.pop(), bos=bos, pad=pad, eos=eos, unk=unk) diff --git a/fairseq/data/huffman/huffman_mmap_indexed_dataset.py b/fairseq/data/huffman/huffman_mmap_indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9b098f2c2be32ef65525dd773a6664d7823ada38 --- /dev/null +++ b/fairseq/data/huffman/huffman_mmap_indexed_dataset.py @@ -0,0 +1,287 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import mmap +import os +import shutil +import struct +import typing as tp +from functools import lru_cache + +import numpy as np +import torch +from fairseq.data import indexed_dataset +from fairseq.data.huffman import HuffmanCoder +from fairseq.file_io import PathManager + + +class HuffmanMMapIndex: + """ + keep an index of the offsets in the huffman binary file. + First a header, then the list of sizes (num tokens) for each instance and finally + the addresses of each instance. + """ + + _HDR_MAGIC = b"HUFFIDX\x00\x00" + _VERSION = 1 + + @classmethod + def writer(cls, path: str, data_len: int): + class _Writer: + def __enter__(self): + self._file = open(path, "wb") + + # write header (magic + version) + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack(" None: + self._path_prefix = path_prefix + self._coder = coder + self._sizes = [] + self._ptrs = [] + self._data_len = 0 + + def open(self): + self._coder.to_file(vocab_file_path(self._path_prefix)) + self._data_file = open(indexed_dataset.data_file_path(self._path_prefix), "wb") + + def __enter__(self) -> "HuffmanMMapIndexedDatasetBuilder": + self.open() + return self + + def add_item(self, tokens: tp.List[str]) -> None: + """ + add a list of tokens to the dataset, they will compressed with the + provided coder before being written to file. + """ + encoded = self._coder.encode(tokens) + code_len = len(encoded) + last_ptr = 0 + if len(self._ptrs) > 0: + last_ptr = self._ptrs[-1] + self._sizes.append(len(tokens)) + self._ptrs.append(last_ptr + code_len) + self._data_len += code_len + self._data_file.write(encoded) + + def append(self, other_dataset_path_prefix: str) -> None: + """ + append an existing dataset. + Beware, if it wasn't built with the same coder, you are in trouble. + """ + other_index = HuffmanMMapIndex( + indexed_dataset.index_file_path(other_dataset_path_prefix) + ) + for (ptr, size) in other_index: + self._ptrs.append(ptr + self._data_len) + self._sizes.append(size) + + # Concatenate data + with open(indexed_dataset.data_file_path(other_dataset_path_prefix), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + self._data_len += other_index.data_len + + def close(self): + self._data_file.close() + with HuffmanMMapIndex.writer( + indexed_dataset.index_file_path(self._path_prefix), self._data_len + ) as index: + index.write(self._sizes, self._ptrs) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() diff --git a/fairseq/data/id_dataset.py b/fairseq/data/id_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4d7969cf2a26e852b466f165a6fadabae3b35f --- /dev/null +++ b/fairseq/data/id_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class IdDataset(FairseqDataset): + def __getitem__(self, index): + return index + + def __len__(self): + return 0 + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1947d994081167ee78e9b2a5590882e5025b2244 --- /dev/null +++ b/fairseq/data/indexed_dataset.py @@ -0,0 +1,592 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import shutil +import struct +from functools import lru_cache + +import numpy as np +import torch +from fairseq.dataclass.constants import DATASET_IMPL_CHOICES +from fairseq.data.fasta_dataset import FastaDataset +from fairseq.file_io import PathManager +from fairseq.data.huffman import HuffmanMMapIndexedDataset, HuffmanMMapIndex + +from . import FairseqDataset + +from typing import Union + + +def best_fitting_int_dtype( + max_int_to_represent, +) -> Union[np.uint16, np.uint32, np.int64]: + + if max_int_to_represent is None: + return np.uint32 # Safe guess + elif max_int_to_represent < 65500: + return np.uint16 + elif max_int_to_represent < 4294967295: + return np.uint32 + else: + return np.int64 + # we avoid np.uint64 because it doesn't save space and its type promotion behaves unexpectedly + # https://github.com/numpy/numpy/issues/5745 + + +def get_available_dataset_impl(): + return list(map(str, DATASET_IMPL_CHOICES)) + + +def infer_dataset_impl(path): + if IndexedRawTextDataset.exists(path): + return "raw" + elif IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + elif magic == HuffmanMMapIndex._HDR_MAGIC[:8]: + return "huffman" + else: + return None + elif FastaDataset.exists(path): + return "fasta" + else: + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=best_fitting_int_dtype(vocab_size) + ) + elif impl == "fasta": + raise NotImplementedError + elif impl == "huffman": + raise ValueError( + "Use HuffmanCodeBuilder directly as it has a different interface." + ) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None): + if impl == "raw" and IndexedRawTextDataset.exists(path): + assert dictionary is not None + return IndexedRawTextDataset(path, dictionary) + elif impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path) + elif impl == "fasta" and FastaDataset.exists(path): + from fairseq.data.fasta_dataset import EncodedFastaDataset + + return EncodedFastaDataset(path, dictionary) + elif impl == "huffman" and HuffmanMMapIndexedDataset.exists(path): + return HuffmanMMapIndexedDataset(path) + return None + + +def dataset_exists(path, impl): + if impl == "raw": + return IndexedRawTextDataset.exists(path) + elif impl == "mmap": + return MMapIndexedDataset.exists(path) + elif impl == "huffman": + return HuffmanMMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +_code_to_dtype = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float64, + 7: np.double, + 8: np.uint16, + 9: np.uint32, + 10: np.uint64, +} + + +def _dtype_header_code(dtype) -> int: + for k in _code_to_dtype.keys(): + if _code_to_dtype[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +class IndexedDataset(FairseqDataset): + """Loader for TorchNet IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path, fix_lua_indexing=False): + super().__init__() + self.path = path + self.fix_lua_indexing = fix_lua_indexing + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." + ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + @lru_cache(maxsize=8) + def __getitem__(self, i) -> torch.Tensor: + if not self.data_file: + self.read_data(self.path) + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(index_file_path(path)) and PathManager.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path, fix_lua_indexing=False): + super().__init__(path, fix_lua_indexing=fix_lua_indexing) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + item = torch.from_numpy(a).long() + if self.fix_lua_indexing: + item -= 1 # subtract 1 for 0-based indexing + return item + + +class IndexedRawTextDataset(FairseqDataset): + """Takes a text file as input and binarizes it in memory at instantiation. + Original lines are also kept in memory""" + + def __init__(self, path, dictionary, append_eos=True, reverse_order=False): + self.tokens_list = [] + self.lines = [] + self.sizes = [] + self.append_eos = append_eos + self.reverse_order = reverse_order + self.read_data(path, dictionary) + self.size = len(self.tokens_list) + + def read_data(self, path, dictionary): + with open(path, "r", encoding="utf-8") as f: + for line in f: + self.lines.append(line.strip("\n")) + tokens = dictionary.encode_line( + line, + add_if_not_exist=False, + append_eos=self.append_eos, + reverse_order=self.reverse_order, + ).long() + self.tokens_list.append(tokens) + self.sizes.append(len(tokens)) + self.sizes = np.array(self.sizes) + + def check_index(self, i): + if i < 0 or i >= self.size: + raise IndexError("index out of range") + + @lru_cache(maxsize=8) + def __getitem__(self, i): + self.check_index(i) + return self.tokens_list[i] + + def get_original_text(self, i): + self.check_index(i) + return self.lines[i] + + def __del__(self): + pass + + def __len__(self): + return self.size + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return PathManager.exists(path) + + +class IndexedDatasetBuilder: + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float64: 4, + np.double: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + + def add_item(self, tensor): + # +1 for Lua compatibility + bytes = self.out_file.write(np.array(tensor.numpy() + 1, dtype=self.dtype)) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in tensor.size(): + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size())) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" str: + local_index_path = PathManager.get_local_path(index_file_path(path)) + local_data_path = PathManager.get_local_path(data_file_path(path)) + + assert local_index_path.endswith(".idx") and local_data_path.endswith(".bin"), ( + "PathManager.get_local_path does not return files with expected patterns: " + f"{local_index_path} and {local_data_path}" + ) + + local_path = local_data_path[:-4] # stripping surfix ".bin" + assert local_path == local_index_path[:-4] # stripping surfix ".idx" + return local_path + + +class MMapIndexedDatasetBuilder: + def __init__(self, out_file, dtype=np.int64): + self._data_file = open(out_file, "wb") + self._dtype = dtype + self._sizes = [] + + def add_item(self, tensor): + np_array = np.array(tensor.numpy(), dtype=self._dtype) + self._data_file.write(np_array.tobytes(order="C")) + self._sizes.append(np_array.size) + + def merge_file_(self, another_file): + # Concatenate index + index = MMapIndexedDataset.Index(index_file_path(another_file)) + assert index.dtype == self._dtype + + for size in index.sizes: + self._sizes.append(size) + + # Concatenate data + with open(data_file_path(another_file), "rb") as f: + shutil.copyfileobj(f, self._data_file) + + def finalize(self, index_file): + self._data_file.close() + + with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index: + index.write(self._sizes) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5a42a9cf1cee9e0559d4c3b00024992271128c --- /dev/null +++ b/fairseq/data/iterators.py @@ -0,0 +1,879 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import math +import operator +import os +import queue +import time +from threading import Thread +from typing import Iterator, List + +import numpy as np +import torch +from fairseq.data import data_utils + + +logger = logging.getLogger(__name__) + +# Object used by _background_consumer to signal the source is exhausted +# to the main thread. +_sentinel = object() + + +class CountingIterator(object): + """Wrapper around an iterable that maintains the iteration count. + + Args: + iterable (iterable): iterable to wrap + start (int): starting iteration count. Note that this doesn't + actually advance the iterator. + total (int): override the iterator length returned by ``__len``. + This can be used to truncate *iterator*. + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, start=None, total=None): + self._itr = iter(iterable) + self.n = start or getattr(iterable, "n", 0) + self.total = total if total is not None else self.n + len(iterable) + + def __len__(self): + return self.total + + def __iter__(self): + return self + + def __next__(self): + if not self.has_next(): + raise StopIteration + try: + x = next(self._itr) + except StopIteration: + raise IndexError( + f"Iterator expected to have length {self.total}, " + f"but exhausted at position {self.n}." + ) + self.n += 1 + return x + + def has_next(self): + """Whether the iterator has been exhausted.""" + return self.n < self.total + + def skip(self, n): + """Fast-forward the iterator by skipping n elements.""" + for _ in range(n): + next(self) + return self + + def take(self, n): + """Truncate the iterator to n elements at most.""" + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._itr, "take"): + self._itr.take(max(n - self.n, 0)) + return self + + +class EpochBatchIterating(object): + def __len__(self) -> int: + raise NotImplementedError + + @property + def next_epoch_idx(self): + raise NotImplementedError + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). + """ + raise NotImplementedError + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + raise NotImplementedError + + @property + def iterations_in_epoch(self) -> int: + """The number of consumed batches in the current epoch.""" + raise NotImplementedError + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + raise NotImplementedError + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + raise NotImplementedError + + @property + def first_batch(self): + return "DUMMY" + + +class StreamingEpochBatchIterator(EpochBatchIterating): + """A steaming-style iterator over a :class:`torch.utils.data.IterableDataset`. + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + max_sentences: batch size + collate_fn (callable): merges a list of samples to form a mini-batch + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + """ + + def __init__( + self, + dataset, + max_sentences=1, + collate_fn=None, + epoch=1, + num_workers=0, + buffer_size=0, + timeout=0, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.IterableDataset) + self.dataset = dataset + self.max_sentences = max_sentences + self.collate_fn = collate_fn + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + + self._current_epoch_iterator = None + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._current_epoch_iterator is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + self._current_epoch_iterator = self._get_iterator_for_epoch(self.epoch, shuffle) + return self._current_epoch_iterator + + def end_of_epoch(self) -> bool: + return not self._current_epoch_iterator.has_next() + + @property + def iterations_in_epoch(self) -> int: + if self._current_epoch_iterator is not None: + return self._current_epoch_iterator.n + return 0 + + def state_dict(self): + return { + "epoch": self.epoch, + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + + def _get_iterator_for_epoch(self, epoch, shuffle, offset=0): + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + worker_init_fn = getattr(self.dataset, "worker_init_fn", None) + itr = torch.utils.data.DataLoader( + self.dataset, + batch_size=self.max_sentences, + collate_fn=self.collate_fn, + num_workers=self.num_workers, + timeout=self.timeout, + worker_init_fn=worker_init_fn, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + return itr + + +class FrozenBatchSampler: + def __init__( + self, + ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset, + ): + self.ordered_batches = ordered_batches + self.fix_batches_to_gpus = fix_batches_to_gpus + self.shuffle = shuffle + self.make_batches_for_epoch(epoch, initial_offset) + + def make_batches_for_epoch(self, epoch, offset=0): + self.batches = self.ordered_batches( + epoch, self.fix_batches_to_gpus, self.shuffle + ) + if offset > 0: + self.batches = self.batches[offset:] + + def __iter__(self) -> Iterator[List[int]]: + return iter(self.batches) + + def __len__(self) -> int: + return len(self.batches) + + +class EpochBatchIterator(EpochBatchIterating): + """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. + + Compared to :class:`torch.utils.data.DataLoader`, this iterator: + + - can be reused across multiple epochs with the :func:`next_epoch_itr` + method (optionally shuffled between epochs) + - can be serialized/deserialized with the :func:`state_dict` and + :func:`load_state_dict` methods + - supports sharding with the *num_shards* and *shard_id* arguments + + Args: + dataset (~torch.utils.data.Dataset): dataset from which to load the data + collate_fn (callable): merges a list of samples to form a mini-batch + batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of + indices, or a callable to create such an iterator (~torch.utils.data.Sampler). + A callable batch_sampler will be called for each epoch to enable per epoch dynamic + batch iterators defined by this callable batch_sampler. + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + buffer_size (int, optional): the number of batches to keep ready in the + queue. Helps speeding up dataloading. When buffer_size is zero, the + default torch.utils.data.DataLoader preloading is used. + timeout (int, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative (default: ``0``). + disable_shuffling (bool, optional): force disable shuffling + (default: ``False``). + skip_remainder_batch (bool, optional): if set, discard the last batch in an epoch + for the sake of training stability, as the last batch is usually smaller than + local_batch_size * distributed_word_size (default: ``False``). + grouped_shuffling (bool, optional): enable shuffling batches in groups + of num_shards. Ensures that each GPU receives similar length sequences when + batches are sorted by length. + """ + + def __init__( + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, + disable_shuffling=False, + skip_remainder_batch=False, + grouped_shuffling=False, + reuse_dataloader=False, + persistent_workers=True, + ): + assert isinstance(dataset, torch.utils.data.Dataset) + self.dataset = dataset + self.collate_fn = collate_fn + self.batch_sampler = batch_sampler + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) + self.seed = seed + self.num_shards = num_shards + self.shard_id = shard_id + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + # This upper limit here is to prevent people from abusing this feature + # in a shared computing environment. + self.buffer_size = min(buffer_size, 20) + self.timeout = timeout + self.disable_shuffling = disable_shuffling + self.skip_remainder_batch = skip_remainder_batch + self.grouped_shuffling = grouped_shuffling + + self.epoch = max(epoch, 1) # we use 1-based indexing for epochs + self.shuffle = not disable_shuffling + self._cur_epoch_itr = None + self._next_epoch_itr = None + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + + self.dataloader = None + self.reuse_dataloader = reuse_dataloader + + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if getattr(self.dataset, "supports_fetch_outside_dataloader", True): + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" + + def __len__(self): + return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + + @property + def n(self): + return self.iterations_in_epoch + + @property + def next_epoch_idx(self): + """Return the epoch index after *next_epoch_itr* is called.""" + if self._next_epoch_itr is not None: + return self.epoch + elif self._cur_epoch_itr is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + """Return a new iterator over the dataset. + + Args: + shuffle (bool, optional): shuffle batches before returning the + iterator (default: True). + fix_batches_to_gpus (bool, optional): ensure that batches are always + allocated to the same shards across epochs. Requires + that :attr:`dataset` supports prefetching (default: False). + set_dataset_epoch (bool, optional): update the wrapped Dataset with + the new epoch number (default: True). + """ + if self.disable_shuffling: + shuffle = False + prev_epoch = self.epoch + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + if self._next_epoch_itr is not None: + self._cur_epoch_itr = self._next_epoch_itr + self._next_epoch_itr = None + else: + if callable(self.batch_sampler) and prev_epoch != self.epoch: + # reset _frozen_batches to refresh the next epoch + self._frozen_batches = None + self._cur_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, + ) + self.shuffle = shuffle + return self._cur_epoch_itr + + def end_of_epoch(self) -> bool: + """Returns whether the most recent epoch iterator has been exhausted""" + return not self._cur_epoch_itr.has_next() + + @property + def iterations_in_epoch(self): + """The number of consumed batches in the current epoch.""" + if self._cur_epoch_itr is not None: + return self._cur_epoch_itr.n + elif self._next_epoch_itr is not None: + return self._next_epoch_itr.n + return 0 + + def state_dict(self): + """Returns a dictionary containing a whole state of the iterator.""" + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch + return { + "version": 2, + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, + } + + def load_state_dict(self, state_dict): + """Copies the state of the iterator from the given *state_dict*.""" + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + version = state_dict.get("version", 1) + if itr_pos > 0: + # fast-forward epoch iterator + self._next_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, + ) + if self._next_epoch_itr is None: + if version == 1: + # legacy behavior: we finished the epoch, increment epoch counter + self.epoch += 1 + else: + raise RuntimeError( + "Cannot resume training due to dataloader mismatch, please " + "report this to the fairseq developers. You can relaunch " + "training with `--reset-dataloader` and it should work." + ) + else: + self._next_epoch_itr = None + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + if self.reuse_dataloader and self.dataloader is not None: + self.epoch_batch_sampler.make_batches_for_epoch(epoch, offset) + itr = self.dataloader + else: + self.epoch_batch_sampler = FrozenBatchSampler( + self.ordered_batches, + epoch, + fix_batches_to_gpus, + shuffle, + initial_offset=offset, + ) + + if offset > 0 and len(self.epoch_batch_sampler) == 0: + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # Create data loader + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=self.epoch_batch_sampler, + num_workers=self.num_workers, + timeout=self.timeout, + pin_memory=True, + persistent_workers=self.persistent_workers, + ) + + if self.reuse_dataloader: + self.dataloader = itr + + # Wrap with a BufferedIterator if needed + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # Wrap with CountingIterator + itr = CountingIterator(itr, start=offset) + + if self.skip_remainder_batch: + # TODO: Below is a lazy implementation which discard the final batch regardless + # of whether it is a full batch or not. + + total_num_itrs = len(itr) - 1 + itr.take(total_num_itrs) + logger.info(f"skip final residual batch, total_num_itrs = {total_num_itrs}") + + return itr + + def ordered_batches(self, epoch, fix_batches_to_gpus, shuffle): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + + if self.grouped_shuffling: + grouped_batches = [ + batches[(i * self.num_shards) : ((i + 1) * self.num_shards)] + for i in range((len(batches) // self.num_shards)) + ] + np.random.shuffle(grouped_batches) + batches = list(itertools.chain(*grouped_batches)) + else: + np.random.shuffle(batches) + + return batches + + if self._supports_prefetch: + batches = self.frozen_batches + + if shuffle and not fix_batches_to_gpus: + batches = shuffle_batches(list(batches), self.seed + epoch) + + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + self.dataset.prefetch([i for s in batches for i in s]) + + if shuffle and fix_batches_to_gpus: + batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) + else: + if shuffle: + batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) + else: + batches = self.frozen_batches + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + return batches + + +class GroupedIterator(CountingIterator): + """Wrapper around an iterable that returns groups (chunks) of items. + + Args: + iterable (iterable): iterable to wrap + chunk_size (int): size of each chunk + skip_remainder_batch (bool, optional): if set, discard the last grouped batch in + each training epoch, as the last grouped batch is usually smaller than + local_batch_size * distributed_word_size * chunk_size (default: ``False``). + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__(self, iterable, chunk_size, skip_remainder_batch=False): + if skip_remainder_batch: + total_num_itrs = int(math.floor(len(iterable) / float(chunk_size))) + logger.info( + f"skip final residual batch, grouped total_num_itrs = {total_num_itrs}" + ) + else: + total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size))) + logger.info(f"grouped total_num_itrs = {total_num_itrs}") + + itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), + total=total_num_itrs, + ) + self.chunk_size = chunk_size + + if skip_remainder_batch: + self.take(total_num_itrs) + # TODO: [Hack] Here the grouped iterator modifies the base iterator size so that + # training can move into the next epoch once the grouped iterator is exhausted. + # Double-check this implementation in case unexpected behavior occurs. + iterable.take(total_num_itrs * chunk_size) + + +def _chunk_iterator(itr, chunk_size, skip_remainder_batch=False): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if not skip_remainder_batch and len(chunk) > 0: + yield chunk + + +class ShardedIterator(CountingIterator): + """A sharded wrapper around an iterable, padded to length. + + Args: + iterable (iterable): iterable to wrap + num_shards (int): number of shards to split the iterable into + shard_id (int): which shard to iterator over + fill_value (Any, optional): padding value when the iterable doesn't + evenly divide *num_shards* (default: None). + + Attributes: + n (int): number of elements consumed from this iterator + """ + + def __init__( + self, iterable, num_shards, shard_id, fill_value=None, skip_remainder_batch=None + ): + """ + Args: + skip_remainder_batch: ignored""" + if shard_id < 0 or shard_id >= num_shards: + raise ValueError("shard_id must be between 0 and num_shards") + sharded_len = int(math.ceil(len(iterable) / float(num_shards))) + itr = map( + operator.itemgetter(1), + itertools.zip_longest( + range(sharded_len), + itertools.islice(iterable, shard_id, len(iterable), num_shards), + fillvalue=fill_value, + ), + ) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), + total=sharded_len, + ) + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len, cuda_device): + Thread.__init__(self) + + self._queue = queue + self._source = source + self._max_len = max_len + self.count = 0 + self.cuda_device = cuda_device + + def run(self): + # set_device to avoid creation of GPU0 context when using pin_memory + if self.cuda_device is not None: + torch.cuda.set_device(self.cuda_device) + + try: + for item in self._source: + self._queue.put(item) + + # Stop if we reached the maximum length + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + + # Signal the consumer we are done. + self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) + + +class BufferedIterator(object): + def __init__(self, size, iterable): + self._queue = queue.Queue(size) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.total, + torch.cuda.current_device() if torch.cuda.is_available() else None, + ) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def take(self, n): + self.total = min(self.total, n) + # Propagate this change to the underlying iterator + if hasattr(self._iterable, "take"): + self._iterable.take(n) + return self + + def __next__(self): + # Create consumer if not created yet + if self._consumer is None: + self._create_consumer() + + # Notify the user if there is a data loading bottleneck + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logger.debug( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + # Get next example + item = self._queue.get(True) + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration() + return item + + +class GroupedEpochBatchIterator(EpochBatchIterator): + """Grouped version of EpochBatchIterator + It takes several samplers from different datasets. + Each epoch shuffle the dataset wise sampler individually with different + random seed. The those sub samplers are combined with into + one big samplers with deterministic permutation to mix batches from + different datasets. It will act like EpochBatchIterator but make sure + 1) data from one data set each time + 2) for different workers, they use the same order to fetch the data + so they will use data from the same dataset everytime + mult_rate is used for update_freq > 1 case where we want to make sure update_freq + mini-batches come from same source + """ + + def __init__( + self, + dataset, + collate_fn, + batch_samplers, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + mult_rate=1, + buffer_size=0, + skip_remainder_batch=False, + ): + super().__init__( + dataset, + collate_fn, + batch_samplers, + seed, + num_shards, + shard_id, + num_workers, + epoch, + buffer_size, + skip_remainder_batch=skip_remainder_batch, + ) + # level 0: sub-samplers 1: batch_idx 2: batches + self._frozen_batches = tuple([tuple(sub_batch) for sub_batch in batch_samplers]) + self.step_size = mult_rate * num_shards + + self.lengths = [ + (len(x) // self.step_size) * self.step_size for x in self.frozen_batches + ] + + def __len__(self): + return sum(self.lengths) + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if self.dataset.supports_fetch_outside_dataloader: + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0][0]]) + else: + return "DUMMY" + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + def return_full_batches(batch_sets, seed, shuffle): + if shuffle: + batch_sets = [shuffle_batches(list(x), seed) for x in batch_sets] + + batch_sets = [ + batch_sets[i][: self.lengths[i]] for i in range(len(batch_sets)) + ] + batches = list(itertools.chain.from_iterable(batch_sets)) + + if shuffle: + with data_utils.numpy_seed(seed): + idx = np.random.permutation(len(batches) // self.step_size) + if len(idx) * self.step_size != len(batches): + raise ValueError( + "ERROR: %d %d %d %d" + % (len(idx), self.step_size, len(batches), self.shard_id), + ":".join(["%d" % x for x in self.lengths]), + ) + mini_shards = [ + batches[i * self.step_size : (i + 1) * self.step_size] + for i in idx + ] + batches = list(itertools.chain.from_iterable(mini_shards)) + + return batches + + if self._supports_prefetch: + raise NotImplementedError("To be implemented") + else: + batches = return_full_batches( + self.frozen_batches, self.seed + epoch, shuffle + ) + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + itr = torch.utils.data.DataLoader( + self.dataset, + collate_fn=self.collate_fn, + batch_sampler=batches[offset:], + num_workers=self.num_workers, + persistent_workers=self.persistent_workers, + ) + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + return CountingIterator(itr, start=offset) diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd356ddd044faebd7bb09aeb22499f6b70304216 --- /dev/null +++ b/fairseq/data/language_pair_dataset.py @@ -0,0 +1,477 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np +import torch +from fairseq.data import FairseqDataset, data_utils + + +logger = logging.getLogger(__name__) + + +def collate( + samples, + pad_idx, + eos_idx, + left_pad_source=True, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, + pad_to_multiple=1, +): + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad, + move_eos_to_beginning, + pad_to_length=pad_to_length, + pad_to_multiple=pad_to_multiple, + ) + + def check_alignment(alignment, src_len, tgt_len): + if alignment is None or len(alignment) == 0: + return False + if ( + alignment[:, 0].max().item() >= src_len - 1 + or alignment[:, 1].max().item() >= tgt_len - 1 + ): + logger.warning("alignment size mismatch found, skipping alignment!") + return False + return True + + def compute_alignment_weights(alignments): + """ + Given a tensor of shape [:, 2] containing the source-target indices + corresponding to the alignments, a weight vector containing the + inverse frequency of each target index is computed. + For e.g. if alignments = [[5, 7], [2, 3], [1, 3], [4, 2]], then + a tensor containing [1., 0.5, 0.5, 1] should be returned (since target + index 3 is repeated twice) + """ + align_tgt = alignments[:, 1] + _, align_tgt_i, align_tgt_c = torch.unique( + align_tgt, return_inverse=True, return_counts=True + ) + align_weights = align_tgt_c[align_tgt_i[np.arange(len(align_tgt))]] + return 1.0 / align_weights.float() + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor( + [s["source"].ne(pad_idx).long().sum() for s in samples] + ) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + tgt_lengths = torch.LongTensor( + [s["target"].ne(pad_idx).long().sum() for s in samples] + ).index_select(0, sort_order) + ntokens = tgt_lengths.sum().item() + + if samples[0].get("prev_output_tokens", None) is not None: + prev_output_tokens = merge("prev_output_tokens", left_pad=left_pad_target) + elif input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + else: + ntokens = src_lengths.sum().item() + + batch = { + "id": id, + "nsentences": len(samples), + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select( + 0, sort_order + ) + + if samples[0].get("alignment", None) is not None: + bsz, tgt_sz = batch["target"].shape + src_sz = batch["net_input"]["src_tokens"].shape[1] + + offsets = torch.zeros((len(sort_order), 2), dtype=torch.long) + offsets[:, 1] += torch.arange(len(sort_order), dtype=torch.long) * tgt_sz + if left_pad_source: + offsets[:, 0] += src_sz - src_lengths + if left_pad_target: + offsets[:, 1] += tgt_sz - tgt_lengths + + alignments = [ + alignment + offset + for align_idx, offset, src_len, tgt_len in zip( + sort_order, offsets, src_lengths, tgt_lengths + ) + for alignment in [samples[align_idx]["alignment"].view(-1, 2)] + if check_alignment(alignment, src_len, tgt_len) + ] + + if len(alignments) > 0: + alignments = torch.cat(alignments, dim=0) + align_weights = compute_alignment_weights(alignments) + + batch["alignments"] = alignments + batch["align_weights"] = align_weights + + if samples[0].get("constraints", None) is not None: + # Collate the packed constraints across the samples, padding to + # the length of the longest sample. + lens = [sample.get("constraints").size(0) for sample in samples] + max_len = max(lens) + constraints = torch.zeros((len(samples), max(lens))).long() + for i, sample in enumerate(samples): + constraints[i, 0 : lens[i]] = samples[i].get("constraints") + batch["constraints"] = constraints.index_select(0, sort_order) + + return batch + + +class LanguagePairDataset(FairseqDataset): + """ + A pair of torch.utils.data.Datasets. + + Args: + src (torch.utils.data.Dataset): source dataset to wrap + src_sizes (List[int]): source sentence lengths + src_dict (~fairseq.data.Dictionary): source vocabulary + tgt (torch.utils.data.Dataset, optional): target dataset to wrap + tgt_sizes (List[int], optional): target sentence lengths + tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary + left_pad_source (bool, optional): pad source tensors on the left side + (default: True). + left_pad_target (bool, optional): pad target tensors on the left side + (default: False). + shuffle (bool, optional): shuffle dataset elements before batching + (default: True). + input_feeding (bool, optional): create a shifted version of the targets + to be passed into the model for teacher forcing (default: True). + remove_eos_from_source (bool, optional): if set, removes eos from end + of source if it's present (default: False). + append_eos_to_target (bool, optional): if set, appends eos to end of + target if it's absent (default: False). + align_dataset (torch.utils.data.Dataset, optional): dataset + containing alignments. + constraints (Tensor, optional): 2d tensor with a concatenated, zero- + delimited list of constraints for each sentence. + append_bos (bool, optional): if set, appends bos to the beginning of + source/target sentence. + num_buckets (int, optional): if set to a value greater than 0, then + batches will be bucketed into the given number of batch shapes. + src_lang_id (int, optional): source language ID, if set, the collated batch + will contain a field 'src_lang_id' in 'net_input' which indicates the + source language of the samples. + tgt_lang_id (int, optional): target language ID, if set, the collated batch + will contain a field 'tgt_lang_id' which indicates the target language + of the samples. + """ + + def __init__( + self, + src, + src_sizes, + src_dict, + tgt=None, + tgt_sizes=None, + tgt_dict=None, + left_pad_source=True, + left_pad_target=False, + shuffle=True, + input_feeding=True, + remove_eos_from_source=False, + append_eos_to_target=False, + align_dataset=None, + constraints=None, + append_bos=False, + eos=None, + num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, + pad_to_multiple=1, + ): + if tgt_dict is not None: + assert src_dict.pad() == tgt_dict.pad() + assert src_dict.eos() == tgt_dict.eos() + assert src_dict.unk() == tgt_dict.unk() + if tgt is not None: + assert len(src) == len( + tgt + ), "Source and target must contain the same number of examples" + self.src = src + self.tgt = tgt + self.src_sizes = np.array(src_sizes) + self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None + self.sizes = ( + np.vstack((self.src_sizes, self.tgt_sizes)).T + if self.tgt_sizes is not None + else self.src_sizes + ) + self.src_dict = src_dict + self.tgt_dict = tgt_dict + self.left_pad_source = left_pad_source + self.left_pad_target = left_pad_target + self.shuffle = shuffle + self.input_feeding = input_feeding + self.remove_eos_from_source = remove_eos_from_source + self.append_eos_to_target = append_eos_to_target + self.align_dataset = align_dataset + if self.align_dataset is not None: + assert ( + self.tgt_sizes is not None + ), "Both source and target needed when alignments are provided" + self.constraints = constraints + self.append_bos = append_bos + self.eos = eos if eos is not None else src_dict.eos() + self.src_lang_id = src_lang_id + self.tgt_lang_id = tgt_lang_id + if num_buckets > 0: + from fairseq.data import BucketPadLengthDataset + + self.src = BucketPadLengthDataset( + self.src, + sizes=self.src_sizes, + num_buckets=num_buckets, + pad_idx=self.src_dict.pad(), + left_pad=self.left_pad_source, + ) + self.src_sizes = self.src.sizes + logger.info("bucketing source lengths: {}".format(list(self.src.buckets))) + if self.tgt is not None: + self.tgt = BucketPadLengthDataset( + self.tgt, + sizes=self.tgt_sizes, + num_buckets=num_buckets, + pad_idx=self.tgt_dict.pad(), + left_pad=self.left_pad_target, + ) + self.tgt_sizes = self.tgt.sizes + logger.info( + "bucketing target lengths: {}".format(list(self.tgt.buckets)) + ) + + # determine bucket sizes using self.num_tokens, which will return + # the padded lengths (thanks to BucketPadLengthDataset) + num_tokens = np.vectorize(self.num_tokens, otypes=[np.compat.long]) + self.bucketed_num_tokens = num_tokens(np.arange(len(self.src))) + self.buckets = [ + (None, num_tokens) for num_tokens in np.unique(self.bucketed_num_tokens) + ] + else: + self.buckets = None + self.pad_to_multiple = pad_to_multiple + + def get_batch_shapes(self): + return self.buckets + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + src_item = self.src[index] + # Append EOS to end of tgt sentence if it does not have an EOS and remove + # EOS from end of src sentence if it exists. This is useful when we use + # use existing datasets for opposite directions i.e., when we want to + # use tgt_dataset as src_dataset and vice versa + if self.append_eos_to_target: + eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() + if self.tgt and self.tgt[index][-1] != eos: + tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) + + if self.append_bos: + bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() + if self.tgt and self.tgt[index][0] != bos: + tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) + + bos = self.src_dict.bos() + if self.src[index][0] != bos: + src_item = torch.cat([torch.LongTensor([bos]), self.src[index]]) + + if self.remove_eos_from_source: + eos = self.src_dict.eos() + if self.src[index][-1] == eos: + src_item = self.src[index][:-1] + + example = { + "id": index, + "source": src_item, + "target": tgt_item, + } + if self.align_dataset is not None: + example["alignment"] = self.align_dataset[index] + if self.constraints is not None: + example["constraints"] = self.constraints[index] + return example + + def __len__(self): + return len(self.src) + + def collater(self, samples, pad_to_length=None): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + pad_to_length (dict, optional): a dictionary of + {'source': source_pad_to_length, 'target': target_pad_to_length} + to indicate the max length to pad to in source and target respectively. + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in + the source sentence of shape `(bsz, src_len)`. Padding will + appear on the left if *left_pad_source* is ``True``. + - `src_lengths` (LongTensor): 1D Tensor of the unpadded + lengths of each source sentence of shape `(bsz)` + - `prev_output_tokens` (LongTensor): a padded 2D Tensor of + tokens in the target sentence, shifted right by one + position for teacher forcing, of shape `(bsz, tgt_len)`. + This key will not be present if *input_feeding* is + ``False``. Padding will appear on the left if + *left_pad_target* is ``True``. + - `src_lang_id` (LongTensor): a long Tensor which contains source + language IDs of each sample in the batch + + - `target` (LongTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the left if *left_pad_target* is ``True``. + - `tgt_lang_id` (LongTensor): a long Tensor which contains target language + IDs of each sample in the batch + """ + res = collate( + samples, + pad_idx=self.src_dict.pad(), + eos_idx=self.eos, + left_pad_source=self.left_pad_source, + left_pad_target=self.left_pad_target, + input_feeding=self.input_feeding, + pad_to_length=pad_to_length, + pad_to_multiple=self.pad_to_multiple, + ) + if self.src_lang_id is not None or self.tgt_lang_id is not None: + src_tokens = res["net_input"]["src_tokens"] + bsz = src_tokens.size(0) + if self.src_lang_id is not None: + res["net_input"]["src_lang_id"] = ( + torch.LongTensor([[self.src_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + if self.tgt_lang_id is not None: + res["tgt_lang_id"] = ( + torch.LongTensor([[self.tgt_lang_id]]).expand(bsz, 1).to(src_tokens) + ) + return res + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return max( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + sizes = self.src_sizes[indices] + if self.tgt_sizes is not None: + sizes = np.maximum(sizes, self.tgt_sizes[indices]) + return sizes + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return ( + self.src_sizes[index], + self.tgt_sizes[index] if self.tgt_sizes is not None else 0, + ) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)).astype(np.int64) + else: + indices = np.arange(len(self), dtype=np.int64) + if self.buckets is None: + # sort by target length, then source length + if self.tgt_sizes is not None: + indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")] + return indices[np.argsort(self.src_sizes[indices], kind="mergesort")] + else: + # sort by bucketed_num_tokens, which is: + # max(padded_src_len, padded_tgt_len) + return indices[ + np.argsort(self.bucketed_num_tokens[indices], kind="mergesort") + ] + + @property + def supports_prefetch(self): + return getattr(self.src, "supports_prefetch", False) and ( + getattr(self.tgt, "supports_prefetch", False) or self.tgt is None + ) + + def prefetch(self, indices): + self.src.prefetch(indices) + if self.tgt is not None: + self.tgt.prefetch(indices) + if self.align_dataset is not None: + self.align_dataset.prefetch(indices) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + return data_utils.filter_paired_dataset_indices_by_size( + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, + ) diff --git a/fairseq/data/legacy/__init__.py b/fairseq/data/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd5c72b5e9d7f67fb7e4ef10808d7ec08967ff4 --- /dev/null +++ b/fairseq/data/legacy/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .block_pair_dataset import BlockPairDataset +from .masked_lm_dataset import MaskedLMDataset +from .masked_lm_dictionary import BertDictionary, MaskedLMDictionary + + +__all__ = [ + "BertDictionary", + "BlockPairDataset", + "MaskedLMDataset", + "MaskedLMDictionary", +] diff --git a/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc b/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bba5aae28e6a832e17acbdc6c4e0a4da86dde925 Binary files /dev/null and b/fairseq/data/legacy/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc b/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b73ec5e7f86e89bd29c399974a072dcaab4e9724 Binary files /dev/null and b/fairseq/data/legacy/__pycache__/block_pair_dataset.cpython-310.pyc differ diff --git a/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc b/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc0c9ac6cac230ba85d6deda449ac9a4bd440133 Binary files /dev/null and b/fairseq/data/legacy/__pycache__/masked_lm_dataset.cpython-310.pyc differ diff --git a/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc b/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fd75085d3318260393a3a3344adda547f8e24a3 Binary files /dev/null and b/fairseq/data/legacy/__pycache__/masked_lm_dictionary.cpython-310.pyc differ diff --git a/fairseq/data/legacy/block_pair_dataset.py b/fairseq/data/legacy/block_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ba069b46052286c531b4f9706d96788732cd2ad2 --- /dev/null +++ b/fairseq/data/legacy/block_pair_dataset.py @@ -0,0 +1,311 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import numpy as np +import torch +from fairseq.data import FairseqDataset + + +class BlockPairDataset(FairseqDataset): + """Break a Dataset of tokens into sentence pair blocks for next sentence + prediction as well as masked language model. + + High-level logics are: + 1. break input tensor to tensor blocks + 2. pair the blocks with 50% next sentence and 50% random sentence + 3. return paired blocks as well as related segment labels + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes: array of sentence lengths + dictionary: dictionary for the task + block_size: maximum block size + break_mode: mode for breaking copurs into block pairs. currently we support + 2 modes + doc: respect document boundaries and each part of the pair should belong to on document + none: don't respect any boundary and cut tokens evenly + short_seq_prob: probability for generating shorter block pairs + doc_break_size: Size for empty line separating documents. Typically 1 if + the sentences have eos, 0 otherwise. + """ + + def __init__( + self, + dataset, + dictionary, + sizes, + block_size, + break_mode="doc", + short_seq_prob=0.1, + doc_break_size=1, + ): + super().__init__() + self.dataset = dataset + self.pad = dictionary.pad() + self.eos = dictionary.eos() + self.cls = dictionary.cls() + self.mask = dictionary.mask() + self.sep = dictionary.sep() + self.break_mode = break_mode + self.dictionary = dictionary + self.short_seq_prob = short_seq_prob + self.block_indices = [] + + assert len(dataset) == len(sizes) + + if break_mode == "doc": + cur_doc = [] + for sent_id, sz in enumerate(sizes): + assert doc_break_size == 0 or sz != 0, ( + "when doc_break_size is non-zero, we expect documents to be" + "separated by a blank line with a single eos." + ) + # empty line as document separator + if sz == doc_break_size: + if len(cur_doc) == 0: + continue + self.block_indices.append(cur_doc) + cur_doc = [] + else: + cur_doc.append(sent_id) + max_num_tokens = block_size - 3 # Account for [CLS], [SEP], [SEP] + self.sent_pairs = [] + self.sizes = [] + for doc_id, doc in enumerate(self.block_indices): + self._generate_sentence_pair(doc, doc_id, max_num_tokens, sizes) + elif break_mode is None or break_mode == "none": + # each block should have half of the block size since we are constructing block pair + sent_length = (block_size - 3) // 2 + total_len = sum(dataset.sizes) + length = math.ceil(total_len / sent_length) + + def block_at(i): + start = i * sent_length + end = min(start + sent_length, total_len) + return (start, end) + + sent_indices = np.array([block_at(i) for i in range(length)]) + sent_sizes = np.array([e - s for s, e in sent_indices]) + dataset_index = self._sent_to_dataset_index(sent_sizes) + + # pair sentences + self._pair_sentences(dataset_index) + else: + raise ValueError("Invalid break_mode: " + break_mode) + + def _pair_sentences(self, dataset_index): + """ + Give a list of evenly cut blocks/sentences, pair these sentences with 50% + consecutive sentences and 50% random sentences. + This is used for none break mode + """ + # pair sentences + for sent_id, sent in enumerate(dataset_index): + next_sent_label = ( + 1 if np.random.rand() > 0.5 and sent_id != len(dataset_index) - 1 else 0 + ) + if next_sent_label: + next_sent = dataset_index[sent_id + 1] + else: + next_sent = dataset_index[ + self._skip_sampling(len(dataset_index), [sent_id, sent_id + 1]) + ] + self.sent_pairs.append((sent, next_sent, next_sent_label)) + + # The current blocks don't include the special tokens but the + # sizes already account for this + self.sizes.append(3 + sent[3] + next_sent[3]) + + def _sent_to_dataset_index(self, sent_sizes): + """ + Build index mapping block indices to the underlying dataset indices + """ + dataset_index = [] + ds_idx, ds_remaining = -1, 0 + for to_consume in sent_sizes: + sent_size = to_consume + if ds_remaining == 0: + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + start_ds_idx = ds_idx + start_offset = sent_sizes[ds_idx] - ds_remaining + while to_consume > ds_remaining: + to_consume -= ds_remaining + ds_idx += 1 + ds_remaining = sent_sizes[ds_idx] + ds_remaining -= to_consume + dataset_index.append( + ( + start_ds_idx, # starting index in dataset + start_offset, # starting offset within starting index + ds_idx, # ending index in dataset + sent_size, # sentence length + ) + ) + assert ds_remaining == 0 + assert ds_idx == len(self.dataset) - 1 + return dataset_index + + def _generate_sentence_pair(self, doc, doc_id, max_num_tokens, sizes): + """ + Go through a single document and genrate sentence paris from it + """ + current_chunk = [] + current_length = 0 + curr = 0 + # To provide more randomness, we decrease target seq length for parts of + # samples (10% by default). Note that max_num_tokens is the hard threshold + # for batching and will never be changed. + target_seq_length = max_num_tokens + if np.random.random() < self.short_seq_prob: + target_seq_length = np.random.randint(2, max_num_tokens) + # loop through all sentences in document + while curr < len(doc): + sent_id = doc[curr] + current_chunk.append(sent_id) + current_length = sum(sizes[current_chunk]) + # split chunk and generate pair when exceed target_seq_length or + # finish the loop + if curr == len(doc) - 1 or current_length >= target_seq_length: + # split the chunk into 2 parts + a_end = 1 + if len(current_chunk) > 2: + a_end = np.random.randint(1, len(current_chunk) - 1) + sent_a = current_chunk[:a_end] + len_a = sum(sizes[sent_a]) + # generate next sentence label, note that if there is only 1 sentence + # in current chunk, label is always 0 + next_sent_label = ( + 1 if np.random.rand() > 0.5 and len(current_chunk) != 1 else 0 + ) + if not next_sent_label: + # if next sentence label is 0, sample sent_b from a random doc + target_b_length = target_seq_length - len_a + rand_doc_id = self._skip_sampling(len(self.block_indices), [doc_id]) + random_doc = self.block_indices[rand_doc_id] + random_start = np.random.randint(0, len(random_doc)) + sent_b = [] + len_b = 0 + for j in range(random_start, len(random_doc)): + sent_b.append(random_doc[j]) + len_b = sum(sizes[sent_b]) + if len_b >= target_b_length: + break + # return the second part of the chunk since it's not used + num_unused_segments = len(current_chunk) - a_end + curr -= num_unused_segments + else: + # if next sentence label is 1, use the second part of chunk as sent_B + sent_b = current_chunk[a_end:] + len_b = sum(sizes[sent_b]) + # currently sent_a and sent_B may be longer than max_num_tokens, + # truncate them and return block idx and offsets for them + sent_a, sent_b = self._truncate_sentences( + sent_a, sent_b, max_num_tokens + ) + self.sent_pairs.append((sent_a, sent_b, next_sent_label)) + self.sizes.append(3 + sent_a[3] + sent_b[3]) + current_chunk = [] + curr += 1 + + def _skip_sampling(self, total, skip_ids): + """ + Generate a random integer which is not in skip_ids. Sample range is [0, total) + TODO: ids in skip_ids should be consecutive, we can extend it to more generic version later + """ + rand_id = np.random.randint(total - len(skip_ids)) + return rand_id if rand_id < min(skip_ids) else rand_id + len(skip_ids) + + def _truncate_sentences(self, sent_a, sent_b, max_num_tokens): + """ + Trancate a pair of sentence to limit total length under max_num_tokens + Logics: + 1. Truncate longer sentence + 2. Tokens to be truncated could be at the beginning or the end of the sentnce + Returns: + Truncated sentences represented by dataset idx + """ + len_a, len_b = sum(self.dataset.sizes[sent_a]), sum(self.dataset.sizes[sent_b]) + front_cut_a = front_cut_b = end_cut_a = end_cut_b = 0 + + while True: + total_length = ( + len_a + len_b - front_cut_a - front_cut_b - end_cut_a - end_cut_b + ) + if total_length <= max_num_tokens: + break + + if len_a - front_cut_a - end_cut_a > len_b - front_cut_b - end_cut_b: + if np.random.rand() < 0.5: + front_cut_a += 1 + else: + end_cut_a += 1 + else: + if np.random.rand() < 0.5: + front_cut_b += 1 + else: + end_cut_b += 1 + + # calculate ds indices as well as offsets and return + truncated_sent_a = self._cut_sentence(sent_a, front_cut_a, end_cut_a) + truncated_sent_b = self._cut_sentence(sent_b, front_cut_b, end_cut_b) + return truncated_sent_a, truncated_sent_b + + def _cut_sentence(self, sent, front_cut, end_cut): + """ + Cut a sentence based on the numbers of tokens to be cut from beginning and end + Represent the sentence as dataset idx and return + """ + start_ds_idx, end_ds_idx, offset = sent[0], sent[-1], 0 + target_len = sum(self.dataset.sizes[sent]) - front_cut - end_cut + while front_cut > 0: + if self.dataset.sizes[start_ds_idx] > front_cut: + offset += front_cut + break + else: + front_cut -= self.dataset.sizes[start_ds_idx] + start_ds_idx += 1 + while end_cut > 0: + if self.dataset.sizes[end_ds_idx] > end_cut: + break + else: + end_cut -= self.dataset.sizes[end_ds_idx] + end_ds_idx -= 1 + return start_ds_idx, offset, end_ds_idx, target_len + + def _fetch_block(self, start_ds_idx, offset, end_ds_idx, length): + """ + Fetch a block of tokens based on its dataset idx + """ + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + s, e = offset, offset + length + return buffer[s:e] + + def __getitem__(self, index): + block1, block2, next_sent_label = self.sent_pairs[index] + block1 = self._fetch_block(*block1) + block2 = self._fetch_block(*block2) + return block1, block2, next_sent_label + + def __len__(self): + return len(self.sizes) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + prefetch_idx = set() + for index in indices: + for block1, block2, _ in [self.sent_pairs[index]]: + for ds_idx in range(block1[0], block1[2] + 1): + prefetch_idx.add(ds_idx) + for ds_idx in range(block2[0], block2[2] + 1): + prefetch_idx.add(ds_idx) + self.dataset.prefetch(prefetch_idx) diff --git a/fairseq/data/legacy/masked_lm_dataset.py b/fairseq/data/legacy/masked_lm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dd8ea2c60aff306ab3a756223a298a28d41a4991 --- /dev/null +++ b/fairseq/data/legacy/masked_lm_dataset.py @@ -0,0 +1,303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Tuple + +import numpy as np +import torch +from fairseq.data import Dictionary, FairseqDataset, data_utils +from fairseq.data.concat_dataset import ConcatDataset +from fairseq.data.legacy.block_pair_dataset import BlockPairDataset +from fairseq.data.token_block_dataset import TokenBlockDataset + + +class MaskedLMDataset(FairseqDataset): + """ + A wrapper Dataset for masked language modelling. The dataset + wraps around TokenBlockDataset or BlockedPairDataset and creates a batch + where the input blocks are masked according to the specified masking + probability. Additionally the batch can also contain sentence level targets + if this is specified. + + Args: + dataset: Dataset which generates blocks of data. Only BlockPairDataset + and TokenBlockDataset are supported. + sizes: Sentence lengths + vocab: Dictionary with the vocabulary and special tokens. + pad_idx: Id of padding token in dictionary + mask_idx: Id of mask token in dictionary + classif_token_idx: Id of classification token in dictionary. This is the + token associated with the sentence embedding (Eg: CLS for BERT) + sep_token_idx: Id of separator token in dictionary + (Eg: SEP in BERT) + seed: Seed for random number generator for reproducibility. + shuffle: Shuffle the elements before batching. + has_pairs: Specifies whether the underlying dataset + generates a pair of blocks along with a sentence_target or not. + Setting it to True assumes that the underlying dataset generates a + label for the pair of sentences which is surfaced as + sentence_target. The default value assumes a single block with no + sentence target. + segment_id: An optional segment id for filling in the segment labels + when we are in the single block setting (Eg: XLM). Default is 0. + masking_ratio: specifies what percentage of the blocks should be masked. + masking_prob: specifies the probability of a given token being + replaced with the "MASK" token. + random_token_prob: specifies the probability of a given token being + replaced by a random token from the vocabulary. + """ + + def __init__( + self, + dataset: FairseqDataset, + sizes: np.ndarray, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + classif_token_idx: int, + sep_token_idx: int, + seed: int = 1, + shuffle: bool = True, + has_pairs: bool = True, + segment_id: int = 0, + masking_ratio: float = 0.15, + masking_prob: float = 0.8, + random_token_prob: float = 0.1, + ): + # Make sure the input datasets are the ones supported + assert ( + isinstance(dataset, TokenBlockDataset) + or isinstance(dataset, BlockPairDataset) + or isinstance(dataset, ConcatDataset) + ), ( + "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or " + "ConcatDataset" + ) + + self.dataset = dataset + self.sizes = np.array(sizes) + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.classif_token_idx = classif_token_idx + self.sep_token_idx = sep_token_idx + self.shuffle = shuffle + self.seed = seed + self.has_pairs = has_pairs + self.segment_id = segment_id + self.masking_ratio = masking_ratio + self.masking_prob = masking_prob + self.random_token_prob = random_token_prob + + # If we have only one block then sizes needs to be updated to include + # the classification token + if not has_pairs: + self.sizes = self.sizes + 1 + + def __getitem__(self, index: int): + # if has_pairs, then expect 2 blocks and a sentence target + if self.has_pairs: + (block_one, block_two, sentence_target) = self.dataset[index] + else: + block_one = self.dataset[index] + + return { + "id": index, + "block_one": block_one, + "block_two": block_two if self.has_pairs else None, + "sentence_target": sentence_target if self.has_pairs else None, + } + + def __len__(self): + return len(self.dataset) + + def _mask_block( + self, + sentence: np.ndarray, + mask_idx: int, + pad_idx: int, + dictionary_token_range: Tuple, + ): + """ + Mask tokens for Masked Language Model training + Samples mask_ratio tokens that will be predicted by LM. + + Note:This function may not be efficient enough since we had multiple + conversions between np and torch, we can replace them with torch + operators later. + + Args: + sentence: 1d tensor to be masked + mask_idx: index to use for masking the sentence + pad_idx: index to use for masking the target for tokens we aren't + predicting + dictionary_token_range: range of indices in dictionary which can + be used for random word replacement + (e.g. without special characters) + Return: + masked_sent: masked sentence + target: target with words which we are not predicting replaced + by pad_idx + """ + masked_sent = np.copy(sentence) + sent_length = len(sentence) + mask_num = math.ceil(sent_length * self.masking_ratio) + mask = np.random.choice(sent_length, mask_num, replace=False) + target = np.copy(sentence) + + for i in range(sent_length): + if i in mask: + rand = np.random.random() + + # replace with mask if probability is less than masking_prob + # (Eg: 0.8) + if rand < self.masking_prob: + masked_sent[i] = mask_idx + + # replace with random token if probability is less than + # masking_prob + random_token_prob (Eg: 0.9) + elif rand < (self.masking_prob + self.random_token_prob): + # sample random token from dictionary + masked_sent[i] = np.random.randint( + dictionary_token_range[0], dictionary_token_range[1] + ) + else: + target[i] = pad_idx + + return masked_sent, target + + def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int): + """ + Does the heavy lifting for creating a batch from the input list of + examples. The logic is as follows: + 1. Mask the input blocks. In case has_pair is True then we have 2 + blocks to mask. + 2. Prepend the first masked block tensor with the special token + used as sentence embedding. Eg: CLS in BERT. This happens + irrespective of the value of has_pair. + 3. If has_pair is True, then append the first masked block with the + special separator token (eg: SEP for BERT) and compute segment + label accordingly. In this case, also append the second masked + block with this special separator token and compute its segment + label. + 4. For the targets tensor, prepend and append with padding index + accordingly. + 5. Concatenate all tensors. + """ + if len(samples) == 0: + return {} + # To ensure determinism, we reset the state of the PRNG after every + # batch based on the seed and the first id of the batch. This ensures + # that across epochs we get the same mask for the same example. This + # is needed for reproducibility and is how BERT does masking + # TODO: Can we add deteminism without this constraint? + with data_utils.numpy_seed(self.seed + samples[0]["id"]): + for s in samples: + + # token range is needed for replacing with random token during + # masking + token_range = (self.vocab.nspecial, len(self.vocab)) + + # mask according to specified probabilities. + masked_blk_one, masked_tgt_one = self._mask_block( + s["block_one"], + self.mask_idx, + self.pad_idx, + token_range, + ) + + tokens = np.concatenate([[self.classif_token_idx], masked_blk_one]) + targets = np.concatenate([[self.pad_idx], masked_tgt_one]) + segments = np.ones(len(tokens)) * self.segment_id + + # if has_pairs is True then we need to add the SEP token to both + # the blocks after masking and re-compute segments based on the new + # lengths. + if self.has_pairs: + tokens_one = np.concatenate([tokens, [self.sep_token_idx]]) + targets_one = np.concatenate([targets, [self.pad_idx]]) + + masked_blk_two, masked_tgt_two = self._mask_block( + s["block_two"], self.mask_idx, self.pad_idx, token_range + ) + tokens_two = np.concatenate([masked_blk_two, [self.sep_token_idx]]) + targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]]) + + # block + 1 sep + 1 special (CLS) + segments_one = np.zeros(len(tokens_one)) + # block + 1 sep + segments_two = np.ones(len(tokens_two)) + + tokens = np.concatenate([tokens_one, tokens_two]) + targets = np.concatenate([targets_one, targets_two]) + segments = np.concatenate([segments_one, segments_two]) + + s["source"] = torch.LongTensor(tokens) + s["segment_labels"] = torch.LongTensor(segments) + s["lm_target"] = torch.LongTensor(targets) + + def merge(key): + return data_utils.collate_tokens( + [s[key] for s in samples], pad_idx, eos_idx, left_pad=False + ) + + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "ntokens": sum(len(s["source"]) for s in samples), + "net_input": { + "src_tokens": merge("source"), + "segment_labels": merge("segment_labels"), + }, + "lm_target": merge("lm_target"), + "sentence_target": torch.LongTensor([s["sentence_target"] for s in samples]) + if self.has_pairs + else None, + "nsentences": len(samples), + } + + def collater(self, samples: List[Dict]): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch of data + """ + return self._collate(samples, self.vocab.pad(), self.vocab.eos()) + + def num_tokens(self, index: int): + """ + Return the number of tokens in a sample. This value is used to + enforce max-tokens during batching. + """ + return self.sizes[index] + + def size(self, index: int): + """ + Return an example's size as a float or tuple. This value is used when + filtering a dataset with max-positions. + """ + return self.sizes[index] + + def ordered_indices(self): + """ + Return an ordered list of indices. Batches will be constructed based + on this order. + """ + if self.shuffle: + return np.random.permutation(len(self)) + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch(indices) diff --git a/fairseq/data/legacy/masked_lm_dictionary.py b/fairseq/data/legacy/masked_lm_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..dee88f7a3ed72ea465ea4e8ffe7b1c01ff6f57f1 --- /dev/null +++ b/fairseq/data/legacy/masked_lm_dictionary.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data import Dictionary + + +class MaskedLMDictionary(Dictionary): + """ + Dictionary for Masked Language Modelling tasks. This extends Dictionary by + adding the mask symbol. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + ): + super().__init__(pad=pad, eos=eos, unk=unk) + self.mask_word = mask + self.mask_index = self.add_symbol(mask) + self.nspecial = len(self.symbols) + + def mask(self): + """Helper to get index of mask symbol""" + return self.mask_index + + +class BertDictionary(MaskedLMDictionary): + """ + Dictionary for BERT task. This extends MaskedLMDictionary by adding support + for cls and sep symbols. + """ + + def __init__( + self, + pad="", + eos="", + unk="", + mask="", + cls="", + sep="", + ): + super().__init__(pad=pad, eos=eos, unk=unk, mask=mask) + self.cls_word = cls + self.sep_word = sep + self.cls_index = self.add_symbol(cls) + self.sep_index = self.add_symbol(sep) + self.nspecial = len(self.symbols) + + def cls(self): + """Helper to get index of cls symbol""" + return self.cls_index + + def sep(self): + """Helper to get index of sep symbol""" + return self.sep_index diff --git a/fairseq/data/list_dataset.py b/fairseq/data/list_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..12f00aa43661d6bad701c9e72653ba8779136906 --- /dev/null +++ b/fairseq/data/list_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class ListDataset(BaseWrapperDataset): + def __init__(self, dataset, sizes=None): + super().__init__(dataset) + self._sizes = sizes + + def __iter__(self): + for x in self.dataset: + yield x + + def collater(self, samples): + return samples + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + def set_epoch(self, epoch): + pass diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1a945927cf0d96719003685676a990737a3762b2 --- /dev/null +++ b/fairseq/data/lm_context_window_dataset.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from typing import Dict + +from fairseq.data.monolingual_dataset import MonolingualDataset + +from . import FairseqDataset + + +class LMContextWindowDataset(FairseqDataset): + """ + Wraps a MonolingualDataset and provides more context for evaluation. + + Each item in the new dataset will have a maximum size of + ``tokens_per_sample + context_window``. + + Args: + dataset: dataset to wrap + tokens_per_sample (int): the max number of tokens in each dataset item + context_window (int): the number of accumulated tokens to add to each + dataset item + pad_idx (int): padding symbol + """ + + def __init__( + self, + dataset: MonolingualDataset, + tokens_per_sample: int, + context_window: int, + pad_idx: int, + ): + assert context_window > 0 + self.dataset = dataset + self.tokens_per_sample = tokens_per_sample + self.context_window = context_window + self.pad_idx = pad_idx + self.prev_tokens = np.empty([0]) + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples) -> Dict: + sample = self.dataset.collater(samples) + + pad = self.pad_idx + max_sample_len = self.tokens_per_sample + self.context_window + + bsz, tsz = sample["net_input"]["src_tokens"].shape + start_idxs = [0] * bsz + toks = sample["net_input"]["src_tokens"] + lengths = sample["net_input"]["src_lengths"] + tgt = sample["target"] + new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64) + new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64) + sample_lens = toks.ne(pad).long().sum(dim=1).cpu() + for i in range(bsz): + sample_len = sample_lens[i] + extra = len(self.prev_tokens) + sample_len - max_sample_len + if extra > 0: + self.prev_tokens = self.prev_tokens[extra:] + pads = np.full(self.context_window - len(self.prev_tokens), pad) + new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) + new_tgt[ + i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i]) + ] = tgt[i] + start_idxs[i] = len(self.prev_tokens) + lengths[i] += len(self.prev_tokens) + self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window :] + sample["net_input"]["src_tokens"] = torch.from_numpy(new_toks) + sample["target"] = torch.from_numpy(new_tgt) + sample["start_indices"] = start_idxs + return sample + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + def ordered_indices(self): + # NOTE we don't shuffle the data to retain access to the previous dataset elements + return np.arange(len(self.dataset)) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/fairseq/data/lru_cache_dataset.py b/fairseq/data/lru_cache_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a7854ac1701392754ce5795cafe9c634671aebdf --- /dev/null +++ b/fairseq/data/lru_cache_dataset.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +from . import BaseWrapperDataset + + +class LRUCacheDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + + @lru_cache(maxsize=8) + def __getitem__(self, index): + return self.dataset[index] + + @lru_cache(maxsize=8) + def collater(self, samples): + return self.dataset.collater(samples) diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca9051c9ae9eaf4bb31a917ae115cbeb13879a7 --- /dev/null +++ b/fairseq/data/mask_tokens_dataset.py @@ -0,0 +1,226 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +import numpy as np +import torch +from fairseq.data import Dictionary, data_utils + +from . import BaseWrapperDataset, LRUCacheDataset + + +class MaskTokensDataset(BaseWrapperDataset): + """ + A wrapper Dataset for masked language modeling. + + Input items are masked according to the specified masking probability. + + Args: + dataset: Dataset to wrap. + sizes: Sentence lengths + vocab: Dictionary with the vocabulary and special tokens. + pad_idx: Id of pad token in vocab + mask_idx: Id of mask token in vocab + return_masked_tokens: controls whether to return the non-masked tokens + (the default) or to return a tensor with the original masked token + IDs (and *pad_idx* elsewhere). The latter is useful as targets for + masked LM training. + seed: Seed for random number generator for reproducibility. + mask_prob: probability of replacing a token with *mask_idx*. + leave_unmasked_prob: probability that a masked token is unmasked. + random_token_prob: probability of replacing a masked token with a + random token from the vocabulary. + freq_weighted_replacement: sample random replacement words based on + word frequencies in the vocab. + mask_whole_words: only mask whole words. This should be a byte mask + over vocab indices, indicating whether it is the beginning of a + word. We will extend any mask to encompass the whole word. + bpe: BPE to use for whole-word masking. + mask_multiple_length : repeat each mask index multiple times. Default + value is 1. + mask_stdev : standard deviation of masks distribution in case of + multiple masking. Default value is 0. + """ + + @classmethod + def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs): + """Return the source and target datasets for masked LM training.""" + dataset = LRUCacheDataset(dataset) + return ( + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)), + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)), + ) + + def __init__( + self, + dataset: torch.utils.data.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + return_masked_tokens: bool = False, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + freq_weighted_replacement: bool = False, + mask_whole_words: torch.Tensor = None, + mask_multiple_length: int = 1, + mask_stdev: float = 0.0, + skip_masking: bool = False, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + assert mask_multiple_length >= 1 + assert mask_stdev >= 0.0 + + self.dataset = dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.return_masked_tokens = return_masked_tokens + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + self.mask_whole_words = mask_whole_words + self.mask_multiple_length = mask_multiple_length + self.mask_stdev = mask_stdev + self.skip_masking = skip_masking + + if random_token_prob > 0.0: + if freq_weighted_replacement: + weights = np.array(self.vocab.count) + else: + weights = np.ones(len(self.vocab)) + weights[: self.vocab.nspecial] = 0 + self.weights = weights / weights.sum() + + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.seed, self.epoch, index) + + @lru_cache(maxsize=8) + def __getitem_cached__(self, seed: int, epoch: int, index: int): + seed = int(hash((seed, epoch, index)) % 1e6) + rng = np.random.default_rng(seed) + item = self.dataset[index] + sz = len(item) + + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) + if self.skip_masking: + return torch.from_numpy(np.copy(item)) + + if self.mask_whole_words is not None: + word_begins_mask = self.mask_whole_words.gather(0, item) + word_begins_idx = word_begins_mask.nonzero().view(-1) + sz = len(word_begins_idx) + words = np.split(word_begins_mask, word_begins_idx)[1:] + assert len(words) == sz + word_lens = list(map(len, words)) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz / float(self.mask_multiple_length) + + rng.random() + ) + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = rng.choice(sz, num_mask, replace=False) + if self.mask_stdev > 0.0: + lengths = rng.normal( + self.mask_multiple_length, self.mask_stdev, size=num_mask + ) + lengths = [max(0, int(round(x))) for x in lengths] + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ], + dtype=np.int64, + ) + else: + mask_idc = np.concatenate( + [mask_idc + i for i in range(self.mask_multiple_length)] + ) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print("Assigning mask indexes {} to mask {} failed!".format(mask_idc, mask)) + raise + + # if self.return_masked_tokens: + # print(( + # f"IDX={index}; seed={seed}; epoch={epoch}; is_tgt={self.return_masked_tokens}: " + # f"{np.nonzero(mask)[0].sum()}" + # )) + if self.return_masked_tokens: + # exit early if we're just returning the masked tokens + # (i.e., the targets for masked LM training) + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + new_item = np.full(len(mask), self.pad_idx) + new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] + return torch.from_numpy(new_item) + + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (rng.random(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = rng.random(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + if self.mask_whole_words is not None: + mask = np.repeat(mask, word_lens) + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + if self.mask_whole_words is not None: + rand_mask = np.repeat(rand_mask, word_lens) + num_rand = rand_mask.sum() + + new_item[rand_mask] = rng.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + + return torch.from_numpy(new_item) diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..54fd583b64a3a475324ade6eaaeccf593d747fdc --- /dev/null +++ b/fairseq/data/monolingual_dataset.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import FairseqDataset, data_utils + + +def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None): + if len(samples) == 0: + return {} + + def merge(key, is_list=False): + if is_list: + res = [] + for i in range(len(samples[0][key])): + res.append( + data_utils.collate_tokens( + [s[key][i] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + ) + return res + else: + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + pad_to_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + + src_tokens = merge("source") + if samples[0]["target"] is not None: + is_target_list = isinstance(samples[0]["target"], list) + target = merge("target", is_target_list) + else: + target = src_tokens + + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "nsentences": len(samples), + "ntokens": sum(len(s["source"]) for s in samples), + "net_input": { + "src_tokens": src_tokens, + "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]), + }, + "target": target, + } + + +class MonolingualDataset(FairseqDataset): + """ + A wrapper around torch.utils.data.Dataset for monolingual data. + + Args: + dataset (torch.utils.data.Dataset): dataset to wrap + sizes (List[int]): sentence lengths + vocab (~fairseq.data.Dictionary): vocabulary + shuffle (bool, optional): shuffle the elements before batching + (default: True). + """ + + def __init__( + self, + dataset, + sizes, + src_vocab, + tgt_vocab=None, + add_eos_for_other_targets=False, + shuffle=False, + targets=None, + add_bos_token=False, + fixed_pad_length=None, + pad_to_bsz=None, + src_lang_idx=None, + tgt_lang_idx=None, + ): + self.dataset = dataset + self.sizes = np.array(sizes) + self.vocab = src_vocab + self.tgt_vocab = tgt_vocab or src_vocab + self.add_eos_for_other_targets = add_eos_for_other_targets + self.shuffle = shuffle + self.add_bos_token = add_bos_token + self.fixed_pad_length = fixed_pad_length + self.pad_to_bsz = pad_to_bsz + self.src_lang_idx = src_lang_idx + self.tgt_lang_idx = tgt_lang_idx + + assert targets is None or all( + t in {"self", "future", "past"} for t in targets + ), "targets must be none or one of 'self', 'future', 'past'" + if targets is not None and len(targets) == 0: + targets = None + self.targets = targets + + def __getitem__(self, index): + if self.targets is not None: + # *future_target* is the original sentence + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + # + # Left-to-right language models should condition on *source* and + # predict *future_target*. + # Right-to-left language models should condition on *source* and + # predict *past_target*. + source, future_target, past_target = self.dataset[index] + source, target = self._make_source_target( + source, future_target, past_target + ) + else: + source = self.dataset[index] + target = None + source, target = self._maybe_add_bos(source, target) + return {"id": index, "source": source, "target": target} + + def __len__(self): + return len(self.dataset) + + def _make_source_target(self, source, future_target, past_target): + if self.targets is not None: + target = [] + + if ( + self.add_eos_for_other_targets + and (("self" in self.targets) or ("past" in self.targets)) + and source[-1] != self.vocab.eos() + ): + # append eos at the end of source + source = torch.cat([source, source.new([self.vocab.eos()])]) + + if "future" in self.targets: + future_target = torch.cat( + [future_target, future_target.new([self.vocab.pad()])] + ) + if "past" in self.targets: + # first token is before the start of sentence which is only used in "none" break mode when + # add_eos_for_other_targets is False + past_target = torch.cat( + [ + past_target.new([self.vocab.pad()]), + past_target[1:], + source[-2, None], + ] + ) + + for t in self.targets: + if t == "self": + target.append(source) + elif t == "future": + target.append(future_target) + elif t == "past": + target.append(past_target) + else: + raise Exception("invalid target " + t) + + if len(target) == 1: + target = target[0] + else: + target = future_target + + return source, self._filter_vocab(target) + + def _maybe_add_bos(self, source, target): + if self.add_bos_token: + source = torch.cat([source.new([self.vocab.bos()]), source]) + if target is not None: + target = torch.cat([target.new([self.tgt_vocab.bos()]), target]) + return source, target + + def num_tokens_vec(self, indices): + """Return the number of tokens for a set of positions defined by indices. + This value is used to enforce ``--max-tokens`` during batching.""" + return self.sizes[indices] + + def _filter_vocab(self, target): + if len(self.tgt_vocab) != len(self.vocab): + + def _filter(target): + mask = target.ge(len(self.tgt_vocab)) + if mask.any(): + target[mask] = self.tgt_vocab.unk() + return target + + if isinstance(target, list): + return [_filter(t) for t in target] + return _filter(target) + return target + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in + the source sentence of shape `(bsz, src_len)`. Padding will + appear on the right. + + - `target` (LongTensor): a padded 2D Tensor of tokens in the + target sentence of shape `(bsz, tgt_len)`. Padding will appear + on the right. + """ + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.fixed_pad_length, + self.pad_to_bsz, + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch(indices) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2fe074b2280c85706979614ba7abc5ad4c7bb5 --- /dev/null +++ b/fairseq/data/multi_corpus_dataset.py @@ -0,0 +1,285 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +import time +from collections import OrderedDict +from typing import Dict, List, Optional + +import numpy as np + +from fairseq.data import data_utils + +from . import FairseqDataset + +logger = logging.getLogger(__name__) + + +class MultiCorpusDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together. + Unless batch_sample=True, requires each instance + to be the same dataset, as the collate method needs to work on batches with + samples from each dataset. + + Allows specifying a distribution over the datasets to use. Note that unlike + MultiCorpusSampledDataset, this distribution allows sampling for each item, + rather than on a batch level. Note that datasets with sampling probabilty + of 0 will be skipped. + + Each time ordered_indices() is called, a new sample is generated with + the specified distribution. + + Args: + datasets: a OrderedDict of FairseqDataset instances. + distribution: a List containing the probability of getting an utterance from + corresponding dataset + seed: random seed for sampling the datsets + sort_indices: if true, will sort the ordered indices by size + batch_sample: if true, will ensure each batch is from a single dataset + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + distribution: List[float], + seed: int, + sort_indices: bool = False, + batch_sample: bool = False, + distributed_rank: Optional[int] = None, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + assert len(datasets) == len(distribution) + assert sum(distribution) == 1 + self.datasets = datasets + self.distribution = distribution + self.seed = seed + self.sort_indices = sort_indices + self.batch_sample = batch_sample + self.distributed_rank = distributed_rank + + # Avoid repeated conversions to list later + self.dataset_list = list(datasets.values()) + self.total_num_instances = 0 + + first_dataset = self.dataset_list[0] + + self.num_instances_per_dataset = [] + self.dataset_offsets = [] + for i, dataset in enumerate(self.dataset_list): + assert isinstance(dataset, FairseqDataset) + assert type(dataset) is type(first_dataset) + self.num_instances_per_dataset.append( + 0 if self.distribution[i] == 0 else len(dataset) + ) + self.dataset_offsets.append(self.total_num_instances) + self.total_num_instances += self.num_instances_per_dataset[i] + + def ordered_indices(self): + start = time.time() + with data_utils.numpy_seed(self.seed, self.epoch): + logger.info( + f"sampling new dataset with seed {self.seed} epoch {self.epoch}" + ) + sampled_indices = [] + num_selected_instances = 0 + + # For each dataset i, sample self.distribution[i] * self.total_num_instances + for i, key in enumerate(self.datasets): + if self.distribution[i] == 0: + # skip dataset if sampling probability is 0 + continue + + if i < len(self.datasets) - 1: + num_instances = int(self.distribution[i] * self.total_num_instances) + high = self.dataset_offsets[i + 1] + else: + num_instances = self.total_num_instances - num_selected_instances + high = self.total_num_instances + + logger.info(f"sampling {num_instances} from {key} dataset") + num_selected_instances += num_instances + + # First, add k copies of the dataset where k = num_instances // len(dataset). + # This ensures an equal distribution of the data points as much as possible. + # For the remaining entries randomly sample them + dataset_size = len(self.datasets[key]) + num_copies = num_instances // dataset_size + dataset_indices = ( + np.random.permutation(high - self.dataset_offsets[i]) + + self.dataset_offsets[i] + )[: num_instances - num_copies * dataset_size] + if num_copies > 0: + sampled_indices += list( + np.concatenate( + ( + np.repeat( + np.arange(self.dataset_offsets[i], high), num_copies + ), + dataset_indices, + ) + ) + ) + else: + sampled_indices += list(dataset_indices) + + assert ( + len(sampled_indices) == self.total_num_instances + ), f"{len(sampled_indices)} vs {self.total_num_instances}" + + np.random.shuffle(sampled_indices) + if self.sort_indices: + sampled_indices.sort(key=lambda i: self.num_tokens(i)) + + logger.info( + "multi_corpus_dataset ordered_indices took {}s".format( + time.time() - start + ) + ) + return np.array(sampled_indices, dtype=np.int64) + + def _map_index(self, index: int): + """ + If dataset A has length N and dataset B has length M + then index 1 maps to index 1 of dataset A, and index N + 1 + maps to index 1 of B. + """ + counter = 0 + for num_instances, key in zip(self.num_instances_per_dataset, self.datasets): + if index < counter + num_instances: + return index - counter, key + counter += num_instances + raise ValueError( + "Invalid index: {}, max: {}".format(index, self.total_num_instances) + ) + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + async def getitem(self, index): + new_index, key = self._map_index(index) + try: + if hasattr(self.datasets[key], "getitem"): + item = await self.datasets[key].getitem(new_index) + else: + item = self.datasets[key][new_index] + item["full_id"] = index + return item + except Exception as e: + e.args = (f"Error from {key} dataset", *e.args) + raise + + def __getitem__(self, index): + return asyncio.run(self.getitem(index)) + + async def getitems(self, indices): + # initialize a bunch of everstore read operations + # wait in the end to reduce overhead + # very helpful if io is latency bounded + + max_concurrency = 32 + sem = asyncio.Semaphore(max_concurrency) + + async def controlled_getitem(index): + async with sem: + return await self.getitem(index) + + coroutines = [] + for index in indices: + coroutines.append(controlled_getitem(index)) + results = await asyncio.gather(*coroutines) + return results + + def __getitems__(self, indices): + return asyncio.run(self.getitems(indices)) + + def collater(self, samples): + """ + If we are doing batch sampling, then pick the right collater to use. + + Otherwise we assume all collaters are the same. + """ + if len(samples) == 0: + return None + if "full_id" in samples[0]: + _, key = self._map_index(samples[0]["full_id"]) + try: + batch = self.datasets[key].collater(samples) + except Exception: + print(f"Collating failed for key {key}", flush=True) + raise + return batch + else: + # Subclasses may override __getitem__ to not specify full_id + return list(self.datasets.values())[0].collater(samples) + + def num_tokens(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].num_tokens(index) + + def size(self, index: int): + index, key = self._map_index(index) + return self.datasets[key].size(index) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + logger.info(f"setting epoch of multi_corpus_dataset to {epoch}") + self.epoch = epoch + + @property + def supports_prefetch(self): + return False + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) + + def batch_by_size( + self, + indices, + max_tokens=None, + max_sentences=None, + required_batch_size_multiple=1, + ): + if not self.batch_sample: + return super().batch_by_size( + indices, max_tokens, max_sentences, required_batch_size_multiple + ) + + dataset_indices = {key: [] for key in self.datasets} + for i in indices: + _, key = self._map_index(i) + dataset_indices[key].append(i) + + batches = [] + for key in dataset_indices: + cur_batches = super().batch_by_size( + np.array(dataset_indices[key], dtype=np.int64), + max_tokens, + max_sentences, + required_batch_size_multiple, + ) + logger.info(f"Created {len(cur_batches)} batches for dataset {key}") + batches += cur_batches + + # If this dataset is used in a distributed training setup, + # then shuffle such that the order is seeded by the distributed rank + # as well + if self.distributed_rank is not None: + with data_utils.numpy_seed(self.seed, self.epoch, self.distributed_rank): + np.random.shuffle(batches) + return batches diff --git a/fairseq/data/multi_corpus_sampled_dataset.py b/fairseq/data/multi_corpus_sampled_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e9fdf004dd1da519a170a5e8bc225775776f72 --- /dev/null +++ b/fairseq/data/multi_corpus_sampled_dataset.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +from typing import Callable, Dict, List + +import numpy as np + +from . import FairseqDataset + + +def uniform_sampler(x): + # Sample from uniform distribution + return np.random.choice(x, 1).item() + + +class MultiCorpusSampledDataset(FairseqDataset): + """ + Stores multiple instances of FairseqDataset together and in every iteration + creates a batch by first sampling a dataset according to a specified + probability distribution and then getting instances from that dataset. + + Args: + datasets: an OrderedDict of FairseqDataset instances. + sampling_func: A function for sampling over list of dataset keys. + The default strategy is to sample uniformly. + """ + + def __init__( + self, + datasets: Dict[str, FairseqDataset], + sampling_func: Callable[[List], int] = None, + ): + super().__init__() + assert isinstance(datasets, OrderedDict) + self.datasets = datasets + if sampling_func is None: + sampling_func = uniform_sampler + self.sampling_func = sampling_func + + self.total_num_instances = 0 + for _, dataset in datasets.items(): + assert isinstance(dataset, FairseqDataset) + self.total_num_instances += len(dataset) + + self._ordered_indices = None + + def __len__(self): + """ + Length of this dataset is the sum of individual datasets + """ + return self.total_num_instances + + def ordered_indices(self): + """ + Ordered indices for batching. Here we call the underlying + dataset's ordered_indices() so that we get the same random ordering + as we would have from using the underlying dataset directly. + """ + if self._ordered_indices is None: + self._ordered_indices = OrderedDict( + [ + (key, dataset.ordered_indices()) + for key, dataset in self.datasets.items() + ] + ) + return np.arange(len(self)) + + def _map_index_to_dataset(self, key: int, index: int): + """ + Different underlying datasets have different lengths. In order to ensure + we are not accessing an index outside the range of the current dataset + size, we wrap around. This function should be called after we have + created an ordering for this and all underlying datasets. + """ + assert ( + self._ordered_indices is not None + ), "Must call MultiCorpusSampledDataset.ordered_indices() first" + mapped_index = index % len(self.datasets[key]) + return self._ordered_indices[key][mapped_index] + + def __getitem__(self, index: int): + """ + Get the item associated with index from each underlying dataset. + Since index is in the range of [0, TotalNumInstances], we need to + map the index to the dataset before retrieving the item. + """ + return OrderedDict( + [ + (key, dataset[self._map_index_to_dataset(key, index)]) + for key, dataset in self.datasets.items() + ] + ) + + def collater(self, samples: List[Dict]): + """ + Generate a mini-batch for this dataset. + To convert this into a regular mini-batch we use the following + logic: + 1. Select a dataset using the specified probability distribution. + 2. Call the collater function of the selected dataset. + """ + if len(samples) == 0: + return None + + selected_key = self.sampling_func(list(self.datasets.keys())) + selected_samples = [sample[selected_key] for sample in samples] + return self.datasets[selected_key].collater(selected_samples) + + def num_tokens(self, index: int): + """ + Return an example's length (number of tokens), used for batching. Here + we return the max across all examples at index across all underlying + datasets. + """ + return max( + dataset.num_tokens(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + def size(self, index: int): + """ + Return an example's size as a float or tuple. Here we return the max + across all underlying datasets. This value is used when filtering a + dataset with max-positions. + """ + return max( + dataset.size(self._map_index_to_dataset(key, index)) + for key, dataset in self.datasets.items() + ) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, "supports_prefetch", False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch( + [self._map_index_to_dataset(key, index) for index in indices] + ) + + @property + def supports_fetch_outside_dataloader(self): + return all( + self.datasets[key].supports_fetch_outside_dataloader + for key in self.datasets + ) diff --git a/fairseq/data/multilingual/__init__.py b/fairseq/data/multilingual/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6264236915a7269a4d920ee8213004374dd86a9a --- /dev/null +++ b/fairseq/data/multilingual/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..594c569678f39b9bbdedf149a2f72ad5fda55c91 Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb9539b435045b8ff969512e1855ac37304d5fb Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/multilingual_data_manager.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab973580ef38028c9783e2bcfa3b2adfd6865090 Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/multilingual_utils.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe164a942393ab79f7da850ad1abae7a393b790 Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/sampled_multi_dataset.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cca1f2c605e1481b5c0c7a1c485ec6681b2016f6 Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/sampled_multi_epoch_dataset.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc b/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32850a4a4568187c83c69868ccbde24d5fb4100c Binary files /dev/null and b/fairseq/data/multilingual/__pycache__/sampling_method.cpython-310.pyc differ diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..876dfcec36e4cf9236c21e440e9657a68036a278 --- /dev/null +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -0,0 +1,1156 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import json +import logging +import math +import os +from collections import OrderedDict, defaultdict +from argparse import ArgumentError + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + Dictionary, + LanguagePairDataset, + PrependTokenDataset, + SampledMultiDataset, + SampledMultiEpochDataset, + StripTokenDataset, + TransformEosLangPairDataset, + TruncateDataset, + data_utils, + indexed_dataset, +) +from fairseq.data.multilingual.multilingual_utils import ( + EncoderLangtok, + LangTokSpec, + LangTokStyle, + augment_dictionary, + get_lang_tok, +) +from fairseq.data.multilingual.sampled_multi_dataset import CollateFormat +from fairseq.file_io import PathManager +from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict + + +logger = logging.getLogger(__name__) + +SRC_DICT_NAME = "src" +TGT_DICT_NAME = "tgt" + + +def _lang_id(dic: Dictionary, lang: str): + """Return language ID index.""" + idx = dic.index(lang) + assert idx != dic.unk_index, "cannot find language ID for lang {}".format(lang) + return idx + + +def load_sampling_weights(from_file): + with open(from_file) as f: + weights = json.load(f) + return weights + + +class MultilingualDatasetManager(object): + def __init__(self, args, lang_pairs, langs, dicts, sampling_method): + super().__init__() + self.args = args + self.seed = args.seed + self.lang_pairs = lang_pairs + self.extra_lang_pairs = ( + list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}) + if args.extra_lang_pairs + else [] + ) + self.src_langs = { + p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs + } + self.tgt_langs = { + p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs + } + self.langs = langs + self.dicts = dicts + self.lang_dict = self.create_lang_dictionary(self.langs) + self.sampling_method = sampling_method + self.sampling_scheduler = None + self._has_sharded_data = False + self._num_shards_dict = {} + self._training_data_sizes = defaultdict(lambda: {}) + + @classmethod + def setup_data_manager(cls, args, lang_pairs, langs, dicts, sampling_method): + return MultilingualDatasetManager( + args, lang_pairs, langs, dicts, sampling_method + ) + + @staticmethod + def add_args(parser): + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + action=FileContentsAction, + ) + parser.add_argument( + "--langs", + default=None, + type=csv_str_list, + help="a list of languages comma sperated languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs", + ) + parser.add_argument( + "--lang-dict", + default=None, + type=str, + help="an external file which contains a list of " + "languages which can appear in lang-pairs; " + "note that the ordering determines language token IDs; " + "--langs and --lang-dict are two exclusive options", + ) + parser.add_argument( + "--source-dict", + default=None, + type=str, + help="path to source dictionary; if specified it will override per language dictionary loading", + ) + parser.add_argument( + "--target-dict", + default=None, + type=str, + help="path to target dictionary; if specified it will override per language dictionary loading", + ) + parser.add_argument( + "--lang-tok-style", + default=LangTokStyle.multilingual.value, + type=str, + choices=[LangTokStyle.multilingual.value, LangTokStyle.mbart.value], + help="language token styles", + ) + + parser.add_argument( + "--load-alignments", + action="store_true", + help="load the binarized alignments", + ) + parser.add_argument( + "--left-pad-source", + default="True", + type=str, + metavar="BOOL", + help="pad the source on the left", + ) + parser.add_argument( + "--left-pad-target", + default="False", + type=str, + metavar="BOOL", + help="pad the target on the left", + ) + try: + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass + parser.add_argument( + "--upsample-primary", + default=1, + type=int, + help="amount to upsample primary dataset", + ) + parser.add_argument( + "--truncate-source", + action="store_true", + default=False, + help="truncate source to max-source-positions", + ) + parser.add_argument( + "--encoder-langtok", + default=None, + type=str, + choices=[EncoderLangtok.src.value, EncoderLangtok.tgt.value], + metavar="SRCTGT", + help="prepend to the beginning of source sentence the source or target " + "language token. (src/tgt)", + ) + parser.add_argument( + "--decoder-langtok", + action="store_true", + help="prepend to the beginning of target sentence the target language token", + ) + parser.add_argument( + "--lang-tok-replacing-bos-eos", action="store_true", default=False + ) + parser.add_argument( + "--enable-lang-ids", + default=False, + action="store_true", + help="whether to include language IDs in samples", + ) + parser.add_argument( + "--enable-reservsed-directions-shared-datasets", + default=False, + action="store_true", + help="whether to allow datasets be used in reversed directions", + ) + + parser.add_argument( + "--extra-data", + help='a dictionary of data name to this path, \ + e.g. {"mined", path_to_mined_data, "denoised": path_to_denoised_data}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--extra-lang-pairs", + help='a dictionary of data name to the language pairs they serve, \ + e.g. {"mined": comma-separated-lang-pairs, "denoised": comma-separated-lang-pairs}', + type=lambda uf: eval_str_dict(uf, type=str), + default=None, + ) + parser.add_argument( + "--fixed-dictionary", + help="Fixed dictionary to use with model path", + default=None, + type=str, + ) + parser.add_argument( + "--langtoks-specs", + help='a list of comma separated data types that a set of language tokens to be specialized for, \ + e.g. "main,dae,mined". There will be a set of language tokens added to the vocab to \ + distinguish languages in different training data types. If not specified, default language \ + tokens per languages will be added', + default=LangTokSpec.main.value, + type=csv_str_list, + ) + parser.add_argument( + "--langtoks", + help='a dictionary of how to add language tokens, \ + e.g. {"mined": (None, "tgt"), "mono_dae": ("src.dae", "tgt"), "main": \ + ("src", "tgt")}, or {"mined": ("src.mined", "tgt")}', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--sampling-weights-from-file", + help='a file contain a python dictionary of how to sample data sets, \ + e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', + default=None, + type=str, + ) + parser.add_argument( + "--sampling-weights", + help='a dictionary of how to sample data sets, \ + e.g. { "main:en_XX-es_XX": 0.2, "mined:en_XX-pt_XX": 0.5, \ + "mono_dae:es_XX-es_XX: 0.3, "main:en_xx-fr_XX": 0.8 }', + default=None, + type=lambda uf: eval_str_dict(uf, type=str), + ) + parser.add_argument( + "--virtual-epoch-size", + default=None, + type=int, + help="virtual epoch size to speed up data loading", + ) + parser.add_argument( + "--virtual-data-size", + default=None, + type=int, + help="virtual data size of the whole joint dataset to speed" + "up data loading and have specific dynamic sampling strategy interval", + ) + + @classmethod + def load_langs(cls, args, **kwargs): + if args.lang_dict and args.langs: + raise ValueError("--langs and --lang-dict can not both be specified") + if args.lang_dict is None and args.langs is None: + logger.warning( + "External language dictionary is not provided; " + "use lang-pairs to infer the set of supported languages. " + "The language ordering is not stable which might cause " + "misalignment in pretraining and finetuning." + ) + # infer from lang_pairs as it is + langs = list( + {x for lang_pair in args.lang_pairs for x in lang_pair.split("-")} + ) + langs = sorted(langs) + logger.info(f"inferred language list: {langs}") + elif args.lang_dict: + with open( + PathManager.get_local_path(args.lang_dict), "r", encoding="utf-8" + ) as f: + langs = [lang.strip() for lang in f.readlines() if lang.strip()] + logger.info( + f"loaded language list from {args.lang_dict} as they are ordered in file" + ) + elif args.langs: + langs = args.langs + logger.info( + f"parsed the language list as they are ordered in the option: {langs}" + ) + return langs + + def has_sharded_data(self, split): + return self._has_sharded_data and split == getattr( + self.args, "train_subset", None + ) + + def _shared_collater(self): + return not (self.args.extra_data and "mono_dae" in self.args.extra_data) and ( + not self.args.lang_tok_replacing_bos_eos + ) + + def estimate_global_pass_epoch(self, epoch): + if self.args.virtual_epoch_size is None or self.args.virtual_data_size is None: + return None + # one epoch more for remaining data in each shard + virtual_epochs_per_shard = math.ceil( + self.args.virtual_data_size / self.args.virtual_epoch_size + ) + # note that fairseq epoch / shard_epoch starts from 1 + shard_epoch = (epoch - 1) // virtual_epochs_per_shard + 1 + return shard_epoch + + @classmethod + def prepare(cls, load_dictionary, args, **kargs): + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) + + if not hasattr(args, "shuffle_instance"): + args.shuffle_instance = False + if args.langtoks is None: + args.langtoks = {} + if "main" not in args.langtoks: + src_langtok_spec = args.encoder_langtok if args.encoder_langtok else None + tgt_langtok_spec = "tgt" if args.decoder_langtok else None + args.langtoks["main"] = (src_langtok_spec, tgt_langtok_spec) + + def check_langs(langs, pairs): + messages = [] + for src, tgt in pairs: + if src not in langs or tgt not in langs: + messages.append( + f"language pair {src}-{tgt} contains languages " + "that are not in the language dictionary" + ) + if len(messages) > 0: + raise ValueError(" ".join(messages) + f"; langs: {langs}") + + if args.lang_pairs is None: + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective." + ) + if isinstance(args.lang_pairs, str): + args.lang_pairs = args.lang_pairs.split(",") + if args.source_lang is not None or args.target_lang is not None: + training = False + else: + training = True + language_list = cls.load_langs(args, **kargs) + check_langs( + language_list, + ( + [p.split("-") for p in args.lang_pairs] + if training + else [(args.source_lang, args.target_lang)] + ), + ) + + def load_dictionary_and_postproc(path): + d = load_dictionary(path) + augment_dictionary( + dictionary=d, + language_list=language_list, + lang_tok_style=args.lang_tok_style, + langtoks_specs=args.langtoks_specs, + extra_data=args.extra_data, + ) + return d + + dicts = cls.load_all_dictionaries( + args, language_list, load_dictionary_and_postproc, training + ) + return language_list, dicts, training + + @classmethod + def load_all_dictionaries(cls, args, language_list, load_dictionary, training): + dicts = OrderedDict() + if args.source_dict is not None: + dicts[SRC_DICT_NAME] = load_dictionary(args.source_dict) + if args.target_dict is not None: + dicts[TGT_DICT_NAME] = load_dictionary(args.target_dict) + + if training: + extra_lang_pairs = ( + list( + {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} + ) + if args.extra_lang_pairs + else [] + ) + src_langs_to_load_dicts = sorted( + {p.split("-")[0] for p in (args.lang_pairs + extra_lang_pairs)} + ) + tgt_langs_to_load_dicts = sorted( + {p.split("-")[1] for p in (args.lang_pairs + extra_lang_pairs)} + ) + else: + src_langs_to_load_dicts = [args.source_lang] + tgt_langs_to_load_dicts = [args.target_lang] + + paths = utils.split_paths(args.data) + assert len(paths) > 0 + + def load_dicts(langs_to_load_dicts): + for lang in langs_to_load_dicts: + dicts[lang] = load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) + if len(dicts) > 0: + dict0 = next(iter(dicts.values())) + assert dicts[lang].pad() == dict0.pad() + assert dicts[lang].eos() == dict0.eos() + assert dicts[lang].unk() == dict0.unk() + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) + + if args.fixed_dictionary is not None: + fixed_dict = load_dictionary(args.fixed_dictionary) + dicts = { + lang: fixed_dict + for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts + } + else: + if args.source_dict is None: + load_dicts(src_langs_to_load_dicts) + if args.target_dict is None: + load_dicts(tgt_langs_to_load_dicts) + return dicts + + def get_source_dictionary(self, lang): + if self.args.source_dict is not None: + return self.dicts[SRC_DICT_NAME] + else: + return self.dicts[lang] + + def get_target_dictionary(self, lang): + if self.args.target_dict is not None: + return self.dicts[TGT_DICT_NAME] + else: + return self.dicts[lang] + + @classmethod + def create_lang_dictionary(cls, langs): + unk = "" + # hack to remove symbols other than unk as they are not needed by lang dict + lang_dict = Dictionary(pad=unk, eos=unk, unk=unk, bos=unk) + for lang in langs: + lang_dict.add_symbol(lang) + return lang_dict + + @classmethod + def get_langtok_index(cls, lang_tok, dic): + idx = dic.index(lang_tok) + assert ( + idx != dic.unk_index + ), "cannot find language token {} in the dictionary".format(lang_tok) + return idx + + def get_encoder_langtok(self, src_lang, tgt_lang, spec=None): + if spec is None: + return None + if spec and spec.startswith("src"): + if src_lang is None: + return None + langtok = get_lang_tok( + lang=src_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + else: + if tgt_lang is None: + return None + langtok = get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + return self.get_langtok_index( + langtok, + self.get_source_dictionary(src_lang) + if src_lang + else self.get_target_dictionary(tgt_lang), + ) + + def get_decoder_langtok(self, tgt_lang, spec=None): + if spec is None: + return None + langtok = get_lang_tok( + lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec + ) + return self.get_langtok_index(langtok, self.get_target_dictionary(tgt_lang)) + + @classmethod + def load_data(cls, path, vdict, impl): + dataset = data_utils.load_indexed_dataset(path, vdict, impl) + return dataset + + @classmethod + def split_exists(cls, split, src, tgt, lang, data_path, dataset_impl): + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + def load_lang_dataset( + self, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + max_source_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + ): + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + + # infer langcode + if self.split_exists(split_k, src, tgt, src, data_path, dataset_impl): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) + elif self.split_exists(split_k, tgt, src, src, data_path, dataset_impl): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) + else: + if k > 0: + break + else: + logger.error( + f"Dataset not found: {data_path}, {split_k}, {src}, {tgt}" + ) + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + src_dataset = self.load_data(prefix + src, src_dict, dataset_impl) + if truncate_source: + src_dataset = AppendTokenDataset( + TruncateDataset( + StripTokenDataset(src_dataset, src_dict.eos()), + max_source_positions - 1, + ), + src_dict.eos(), + ) + src_datasets.append(src_dataset) + tgt_datasets.append(self.load_data(prefix + tgt, tgt_dict, dataset_impl)) + + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) + + if len(src_datasets) == 1: + src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0] + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + + if prepend_bos: + assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") + src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + + align_dataset = None + if load_alignments: + align_path = os.path.join( + data_path, "{}.align.{}-{}".format(split, src, tgt) + ) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) + + return src_dataset, tgt_dataset, align_dataset + + def load_langpair_dataset( + self, + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + src_dataset_transform_func=lambda dataset: dataset, + tgt_dataset_transform_func=lambda dataset: dataset, + src_lang_id=None, + tgt_lang_id=None, + langpairs_sharing_datasets=None, + ): + norm_direction = "-".join(sorted([src, tgt])) + if langpairs_sharing_datasets is not None: + src_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src), "NotInCache" + ) + tgt_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, tgt), "NotInCache" + ) + align_dataset = langpairs_sharing_datasets.get( + (data_path, split, norm_direction, src, tgt), "NotInCache" + ) + + # a hack: any one is not in cache, we need to reload them + if ( + langpairs_sharing_datasets is None + or src_dataset == "NotInCache" + or tgt_dataset == "NotInCache" + or align_dataset == "NotInCache" + or split != getattr(self.args, "train_subset", None) + ): + # source and target datasets can be reused in reversed directions to save memory + # reversed directions of valid and test data will not share source and target datasets + src_dataset, tgt_dataset, align_dataset = self.load_lang_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + max_source_positions=max_source_positions, + prepend_bos=prepend_bos, + load_alignments=load_alignments, + truncate_source=truncate_source, + ) + src_dataset = src_dataset_transform_func(src_dataset) + tgt_dataset = tgt_dataset_transform_func(tgt_dataset) + if langpairs_sharing_datasets is not None: + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src) + ] = src_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt) + ] = tgt_dataset + langpairs_sharing_datasets[ + (data_path, split, norm_direction, src, tgt) + ] = align_dataset + if align_dataset is None: + # no align data so flag the reverse direction as well in sharing + langpairs_sharing_datasets[ + (data_path, split, norm_direction, tgt, src) + ] = align_dataset + else: + logger.info( + f"Reusing source and target datasets of [{split}] {tgt}-{src} for reversed direction: " + f"[{split}] {src}-{tgt}: src length={len(src_dataset)}; tgt length={len(tgt_dataset)}" + ) + + return LanguagePairDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset.sizes if tgt_dataset is not None else None, + tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + align_dataset=align_dataset, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + ) + + def src_dataset_tranform_func(self, src_lang, tgt_lang, dataset, spec=None): + if self.args.lang_tok_replacing_bos_eos: + # it is handled by self.alter_dataset_langtok + # TODO: Unifiy with alter_dataset_langtok + return dataset + if spec is None: + return dataset + tok = self.get_encoder_langtok(src_lang, tgt_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def tgt_dataset_tranform_func(self, source_lang, target_lang, dataset, spec=None): + if dataset is None: + # note that target dataset can be None during inference time + return None + if self.args.lang_tok_replacing_bos_eos: + # TODO: Unifiy with alter_dataset_langtok + # It is handled by self.alter_dataset_langtok. + # The complication in self.alter_dataset_langtok + # makes a unified framework difficult. + return dataset + # if not self.args.decoder_langtok: + if not spec: + return dataset + tok = self.get_decoder_langtok(target_lang, spec) + if tok: + return PrependTokenDataset(dataset, tok) + return dataset + + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + src_langtok_spec=None, + tgt_langtok_spec=None, + ): + if src_langtok_spec is None and tgt_langtok_spec is None: + return lang_pair_dataset + + new_src_eos = None + if ( + src_langtok_spec is not None + and src_eos is not None + and (src_lang is not None or tgt_lang is not None) + ): + new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang, src_langtok_spec) + else: + src_eos = None + + new_tgt_bos = None + if tgt_langtok_spec and tgt_eos is not None and tgt_lang is not None: + new_tgt_bos = self.get_decoder_langtok(tgt_lang, tgt_langtok_spec) + else: + tgt_eos = None + + return TransformEosLangPairDataset( + lang_pair_dataset, + src_eos=src_eos, + new_src_eos=new_src_eos, + tgt_bos=tgt_eos, + new_tgt_bos=new_tgt_bos, + ) + + def load_a_dataset( + self, + split, + data_path, + src, + src_dict, + tgt, + tgt_dict, + combine, + prepend_bos=False, + langpairs_sharing_datasets=None, + data_category=None, + **extra_kwargs, + ): + dataset_impl = self.args.dataset_impl + upsample_primary = self.args.upsample_primary + left_pad_source = self.args.left_pad_source + left_pad_target = self.args.left_pad_target + max_source_positions = self.args.max_source_positions + max_target_positions = self.args.max_target_positions + load_alignments = self.args.load_alignments + truncate_source = self.args.truncate_source + src_dataset_transform_func = self.src_dataset_tranform_func + tgt_dataset_transform_func = self.tgt_dataset_tranform_func + enable_lang_ids = self.args.enable_lang_ids + lang_dictionary = self.lang_dict + src_langtok_spec, tgt_langtok_spec = extra_kwargs["langtok_spec"] + + src_langtok = self.get_encoder_langtok(src, tgt, src_langtok_spec) + tgt_langtok = self.get_decoder_langtok(tgt, tgt_langtok_spec) + logger.info( + f"{data_category}:{src}-{tgt} src_langtok: {src_langtok}; tgt_langtok: {tgt_langtok}" + ) + + langpair_ds = self.load_langpair_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos, + load_alignments, + truncate_source, + src_dataset_transform_func=lambda dataset: src_dataset_transform_func( + src, tgt, dataset, src_langtok_spec + ), + tgt_dataset_transform_func=lambda dataset: tgt_dataset_transform_func( + src, tgt, dataset, tgt_langtok_spec + ), + src_lang_id=_lang_id(lang_dictionary, src) + if enable_lang_ids and lang_dictionary is not None + else None, + tgt_lang_id=_lang_id(lang_dictionary, tgt) + if enable_lang_ids and lang_dictionary is not None + else None, + langpairs_sharing_datasets=langpairs_sharing_datasets, + ) + # TODO: handle modified lang toks for mined data and dae data + if self.args.lang_tok_replacing_bos_eos: + ds = self.alter_dataset_langtok( + langpair_ds, + src_eos=self.get_source_dictionary(src).eos() + if src + else self.get_target_dictionary(tgt).eos(), + src_lang=src, + tgt_eos=self.get_target_dictionary(tgt).eos(), + tgt_lang=tgt, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + ds = langpair_ds + return ds + + def load_split_langpair_datasets(self, split, data_param_list): + datasets = [] + langpairs_sharing_datasets = ( + {} if self.args.enable_reservsed_directions_shared_datasets else None + ) + for param in data_param_list: + ds = self.load_a_dataset( + split=split, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param, + ) + datasets.append(ds) + return datasets + + def get_data_paths_and_lang_pairs(self, split): + datapaths = {"main": self.args.data} + lang_pairs = {"main": self.lang_pairs} + if split == getattr(self.args, "train_subset", None): + # only training data can have extra data and extra language pairs + if self.args.extra_data: + extra_datapaths = self.args.extra_data + datapaths.update(extra_datapaths) + if self.args.extra_lang_pairs: + extra_lang_pairs = { + k: v.split(",") for k, v in self.args.extra_lang_pairs.items() + } + lang_pairs.update(extra_lang_pairs) + return datapaths, lang_pairs + + @classmethod + def get_dataset_key(cls, data_category, src, tgt): + return f"{data_category}:{src}-{tgt}" + + @classmethod + def _get_shard_num_dict(cls, split, paths): + shards = defaultdict(int) + for path in paths: + files = PathManager.ls(path) + directions = set() + for f in files: + if f.startswith(split) and f.endswith(".idx"): + # idx files of the form "{split}.{src}-{tgt}.{lang}.idx" + direction = f.split(".")[-3] + directions.add(direction) + for direction in directions: + shards[direction] += 1 + return shards + + def get_split_num_data_shards(self, split): + if split in self._num_shards_dict: + return self._num_shards_dict[split] + num_shards_dict = {} + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + paths = utils.split_paths(paths) + shards_dict = self._get_shard_num_dict(split, paths) + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + key = self.get_dataset_key(data_category, src, tgt) + if "mono_" in data_category: + # monolingual data requires tgt only + assert src is None or src == tgt, ( + f"error: src={src}, " + f"tgt={tgt} for data_category={data_category}" + ) + num_shards_dict[key] = shards_dict[tgt] + else: + if f"{src}-{tgt}" in shards_dict: + num_shards_dict[key] = shards_dict[f"{src}-{tgt}"] + elif f"{tgt}-{src}" in shards_dict: + # follow the fairseq tradition to use reversed direction data if it is not available + num_shards_dict[key] = shards_dict[f"{tgt}-{src}"] + self._num_shards_dict[split] = num_shards_dict + logger.info(f"[{split}] num of shards: {num_shards_dict}") + return num_shards_dict + + @classmethod + def get_shard_id(cls, num_shards, epoch, shard_epoch=None): + shard = epoch if shard_epoch is None else shard_epoch + shard = (shard - 1) % num_shards + return shard + + def get_split_data_path(self, paths, epoch, shard_epoch, num_shards): + path = paths[self.get_shard_id(num_shards, epoch, shard_epoch)] + return path + + def get_split_data_param_list(self, split, epoch, shard_epoch=None): + # TODO: to extend with extra datasets and keys and loop over different shard data paths + param_list = [] + data_paths, lang_pairs = self.get_data_paths_and_lang_pairs(split) + logger.info(f"langtoks settings: {self.args.langtoks}") + split_num_shards_dict = self.get_split_num_data_shards(split) + for data_category, paths in data_paths.items(): + if data_category not in lang_pairs: + continue + paths = utils.split_paths(paths) + assert len(paths) > 0 + if len(paths) > 1: + self._has_sharded_data = True + if split != getattr(self.args, "train_subset", None): + # if not training data set, use the first shard for valid and test + paths = paths[:1] + + if data_category in self.args.langtoks: + lang_tok_spec = self.args.langtoks[data_category] + else: + # default to None + lang_tok_spec = (None, None) + + # infer langcode + lang_dirs = [ + lang_pair.split("-") for lang_pair in lang_pairs[data_category] + ] + lang_dirs = [x if len(x) > 1 else (x[0], x[0]) for x in lang_dirs] + for src, tgt in lang_dirs: + assert src is not None or data_category == "mono_dae", ( + f"error: src={src}, " f"tgt={tgt} for data_category={data_category}" + ) + # logger.info(f"preparing param for {data_category}: {src} - {tgt}") + key = self.get_dataset_key(data_category, src, tgt) + data_path = self.get_split_data_path( + paths, epoch, shard_epoch, split_num_shards_dict[key] + ) + param_list.append( + { + "key": key, + "data_path": data_path, + "split": split, + "src": src, + "src_dict": self.get_source_dictionary(src) + if src and data_category != "mono_dae" + else None, + "tgt": tgt, + "tgt_dict": self.get_target_dictionary(tgt), + "data_category": data_category, + "langtok_spec": lang_tok_spec, + } + ) + return param_list + + def get_train_dataset_sizes( + self, data_param_list, datasets, epoch, shard_epoch=None + ): + num_shards = [ + self.get_split_num_data_shards(param["split"])[param["key"]] + for param in data_param_list + ] + data_sizes = [] + for (key, d), num_shard in zip(datasets, num_shards): + my_data_sizes = self._training_data_sizes[key] + shard_ind = self.get_shard_id(num_shard, epoch, shard_epoch) + if shard_ind not in my_data_sizes: + my_data_sizes[shard_ind] = len(d) + known_size = max(my_data_sizes.values()) + data_sizes.append( + # If we don't know the data size of the shard yet, + # use the the max known data size to approximate. + # Note that we preprocess shards by a designated shard size + # and put any remaining data at the end into the last shard so + # the max shard size approximation is almost correct before loading + # the last shard; after loading the last shard, it will have the + # exact data sizes of the whole data size. + (key, sum(my_data_sizes.get(i, known_size) for i in range(num_shard))) + ) + logger.info( + f"estimated total data sizes of all shards used in sampling ratios: {data_sizes}. " + "Note that if the data a shard has not been loaded yet, use the max known data size to approximate" + ) + return [s for _, s in data_sizes] + + def get_train_sampling_ratios( + self, data_param_list, datasets, epoch=1, shard_epoch=None + ): + data_sizes = self.get_train_dataset_sizes( + data_param_list, datasets, epoch, shard_epoch + ) + sampling_func = self.sampling_method.sampling_method_selector() + sample_ratios = sampling_func(data_sizes) if sampling_func is not None else None + return sample_ratios + + def get_sampling_ratios(self, data_param_list, datasets, epoch, shard_epoch=None): + if self.args.sampling_weights_from_file: + weights = load_sampling_weights(self.args.sampling_weights_from_file) + sample_ratios = [weights[k] for k, _ in datasets] + logger.info( + "| ignoring --sampling-weights when loadding sampling weights " + f"from file {self.args.sampling_weights_from_file}" + ) + elif self.args.sampling_weights: + sample_ratios = [self.args.sampling_weights[k] for k, _ in datasets] + else: + sample_ratios = self.get_train_sampling_ratios( + data_param_list, datasets, epoch, shard_epoch + ) + + if sample_ratios is not None: + logger.info( + "| Upsample ratios: {}".format( + list(zip(map(lambda x: x["key"], data_param_list), sample_ratios)) + ) + ) + assert len(sample_ratios) == len(datasets) + return sample_ratios + + def load_split_datasets( + self, split, training, epoch=1, combine=False, shard_epoch=None, **kwargs + ): + data_param_list = self.get_split_data_param_list( + split, epoch, shard_epoch=shard_epoch + ) + langpairs_sharing_datasets = ( + {} if self.args.enable_reservsed_directions_shared_datasets else None + ) + datasets = [ + ( + param["key"], + self.load_a_dataset( + combine=combine, + langpairs_sharing_datasets=langpairs_sharing_datasets, + **param, + ), + ) + for param in data_param_list + ] + return datasets, data_param_list + + def load_into_concat_dataset(self, split, datasets, data_param_list): + if self.args.lang_tok_replacing_bos_eos: + # TODO: to investigate why TransformEosLangPairDataset doesn't work with ConcatDataset + return SampledMultiDataset( + OrderedDict(datasets), + sampling_ratios=None, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=None, + split=split, + ) + return ConcatDataset([d for _, d in datasets]) + + def load_sampled_multi_epoch_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiEpochDataset( + OrderedDict(datasets), + epoch=epoch, + shard_epoch=shard_epoch, + # valid and test datasets will be degenerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + virtual_epoch_size=self.args.virtual_epoch_size, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_sampled_multi_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiDataset( + OrderedDict(datasets), + epoch=epoch, + # valid and test datasets will be degerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + if self.args.virtual_epoch_size is None: + return self.load_sampled_multi_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) + else: + return self.load_sampled_multi_epoch_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) diff --git a/fairseq/data/multilingual/multilingual_utils.py b/fairseq/data/multilingual/multilingual_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e0f9828cabfdbe375d05d9152b58bdbd6de7dc --- /dev/null +++ b/fairseq/data/multilingual/multilingual_utils.py @@ -0,0 +1,63 @@ +from enum import Enum +from typing import Dict, List, Optional, Sequence + +import torch +from fairseq.data import Dictionary + + +class EncoderLangtok(Enum): + """ + Prepend to the beginning of source sentence either the + source or target language token. (src/tgt). + """ + + src = "src" + tgt = "tgt" + + +class LangTokSpec(Enum): + main = "main" + mono_dae = "mono_dae" + + +class LangTokStyle(Enum): + multilingual = "multilingual" + mbart = "mbart" + + +@torch.jit.export +def get_lang_tok( + lang: str, lang_tok_style: str, spec: str = LangTokSpec.main.value +) -> str: + # TOKEN_STYLES can't be defined outside this fn since it needs to be + # TorchScriptable. + TOKEN_STYLES: Dict[str, str] = { + LangTokStyle.mbart.value: "[{}]", + LangTokStyle.multilingual.value: "__{}__", + } + + if spec.endswith("dae"): + lang = f"{lang}_dae" + elif spec.endswith("mined"): + lang = f"{lang}_mined" + style = TOKEN_STYLES[lang_tok_style] + return style.format(lang) + + +def augment_dictionary( + dictionary: Dictionary, + language_list: List[str], + lang_tok_style: str, + langtoks_specs: Sequence[str] = (LangTokSpec.main.value,), + extra_data: Optional[Dict[str, str]] = None, +) -> None: + for spec in langtoks_specs: + for language in language_list: + dictionary.add_symbol( + get_lang_tok(lang=language, lang_tok_style=lang_tok_style, spec=spec) + ) + + if lang_tok_style == LangTokStyle.mbart.value or ( + extra_data is not None and LangTokSpec.mono_dae.value in extra_data + ): + dictionary.add_symbol("") diff --git a/fairseq/data/multilingual/sampled_multi_dataset.py b/fairseq/data/multilingual/sampled_multi_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ece9a9721e453112553d7f41755133b1c937e14e --- /dev/null +++ b/fairseq/data/multilingual/sampled_multi_dataset.py @@ -0,0 +1,468 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import hashlib +import logging +import time +from bisect import bisect_right +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import List + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, data_utils +from fairseq.distributed import utils as distributed_utils + + +def get_time_gap(s, e): + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() + + +logger = logging.getLogger(__name__) + + +def default_virtual_size_func(datasets, ratios, max_scale_up=1.5): + sizes = [len(d) for d in datasets] + if ratios is None: + return sum(sizes) + largest_idx = np.argmax(sizes) + largest_r = ratios[largest_idx] + largest_s = sizes[largest_idx] + # set virtual sizes relative to the largest dataset + virtual_sizes = [(r / largest_r) * largest_s for r in ratios] + vsize = sum(virtual_sizes) + max_size = sum(sizes) * max_scale_up + return int(vsize if vsize < max_size else max_size) + + +class CollateFormat(Enum): + single = 1 + ordered_dict = 2 + + +class SampledMultiDataset(FairseqDataset): + """Samples from multiple sub-datasets according to given sampling ratios. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of probability of each dataset to be sampled + (default: None, which corresponds to concatenating all dataset together). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. + collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + shared_collater (bool): whether or not to all sub-datasets have the same collater. + shuffle (bool): whether or not to shuffle data (default: True). + """ + + def __init__( + self, + datasets, + sampling_ratios=None, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split="", + shared_collater=False, + shuffle=True, + ): + super().__init__() + self.shared_collater = shared_collater + self.shuffle = shuffle + + if isinstance(datasets, OrderedDict): + self.keys = list(datasets.keys()) + datasets = list(datasets.values()) + elif isinstance(datasets, List): + self.keys = list(range(len(datasets))) + else: + raise AssertionError() + self.datasets = datasets + self.split = split + + self.eval_key = eval_key + if self.eval_key is not None: + self.collate_format = CollateFormat.single + else: + self.collate_format = collate_format + + self.seed = seed + self._cur_epoch = None + + self.cumulated_sizes = None + # self.datasets[k][self._cur_indices[i]] is the data item i in this sampled dataset + # namely, data item i is sampled from the kth sub-dataset self.datasets[k] + # where self.cumulated_sizes[k-1] <= i < self.cumulated_sizes[k] + self._cur_indices = None + + self._sizes = None + self.virtual_size_per_dataset = None + # caching properties + self._reset_cached_properties() + self.setup_sampling(sampling_ratios, virtual_size) + self.set_epoch(epoch) + + def _clean_if_not_none(self, var_list): + for v in var_list: + if v is not None: + del v + + def _reset_cached_properties(self): + self._clean_if_not_none([self._sizes, self._cur_indices]) + self._sizes = None + self._cur_indices = None + + def setup_sampling(self, sample_ratios, virtual_size): + sizes = [len(d) for d in self.datasets] + if sample_ratios is None: + # default back to concating datasets + self.sample_ratios = None + self.virtual_size = sum(sizes) + else: + if not isinstance(sample_ratios, np.ndarray): + sample_ratios = np.array(sample_ratios) + self.sample_ratios = sample_ratios + virtual_size = ( + default_virtual_size_func if virtual_size is None else virtual_size + ) + self.virtual_size = ( + virtual_size(self.datasets, self.sample_ratios) + if callable(virtual_size) + else virtual_size + ) + + def adjust_sampling(self, epoch, sampling_ratios, virtual_size): + if sampling_ratios is not None: + sampling_ratios = self._sync_sample_ratios(sampling_ratios) + self.setup_sampling(sampling_ratios, virtual_size) + + def _sync_sample_ratios(self, ratios): + # in case the ratios are not precisely the same across processes + # also to ensure every procresses update the ratios in the same pace + ratios = torch.DoubleTensor(ratios) + if torch.distributed.is_initialized(): + if torch.cuda.is_available(): + distributed_utils.all_reduce( + ratios.cuda(), group=distributed_utils.get_data_parallel_group() + ) + else: + distributed_utils.all_reduce( + ratios, group=distributed_utils.get_data_parallel_group() + ) + ret = ratios.cpu() + ret = ret.numpy() + return ret + + def random_choice_in_dataset(self, rng, dataset, choice_size): + if hasattr(dataset, "random_choice_in_dataset"): + return dataset.random_choice_in_dataset(rng, choice_size) + dataset_size = len(dataset) + return rng.choice( + dataset_size, choice_size, replace=(choice_size > dataset_size) + ) + + def get_virtual_indices(self, rng, datasets, sample_ratios, virtual_size): + def get_counts(sample_ratios): + counts = np.array([virtual_size * r for r in sample_ratios], dtype=np.int64) + diff = virtual_size - counts.sum() + assert diff >= 0 + # due to round-offs, the size might not match the desired sizes + if diff > 0: + dataset_indices = rng.choice( + len(sample_ratios), size=diff, p=sample_ratios + ) + for i in dataset_indices: + counts[i] += 1 + return counts + + def get_in_dataset_indices(datasets, sizes, sample_ratios): + counts = get_counts(sample_ratios) + # uniformally sample desired counts for each dataset + # if the desired counts are large, sample with replacement: + indices = [ + self.random_choice_in_dataset(rng, d, c) + for c, d in zip(counts, datasets) + ] + return indices + + sizes = [len(d) for d in datasets] + if sample_ratios is None: + # default back to concating datasets + in_dataset_indices = [list(range(s)) for s in sizes] + virtual_sizes_per_dataset = sizes + else: + ratios = sample_ratios / sample_ratios.sum() + in_dataset_indices = get_in_dataset_indices(datasets, sizes, ratios) + virtual_sizes_per_dataset = [len(d) for d in in_dataset_indices] + virtual_sizes_per_dataset = np.array(virtual_sizes_per_dataset, np.int64) + cumulative_sizes = np.cumsum(virtual_sizes_per_dataset) + assert sum(virtual_sizes_per_dataset) == virtual_size + assert cumulative_sizes[-1] == virtual_size + if virtual_size < sum(sizes): + logger.warning( + f"virtual data size ({virtual_size}) is less than real data size ({sum(sizes)})." + " If virtual size << real data size, there could be data coverage issue." + ) + in_dataset_indices = np.hstack(in_dataset_indices) + return in_dataset_indices, cumulative_sizes, virtual_sizes_per_dataset + + def _get_dataset_and_index(self, index): + i = bisect_right(self.cumulated_sizes, index) + return i, self._cur_indices[index] + + def __getitem__(self, index): + # self.__getitem__(index) returns self.datasets[k][self._cur_indices[index]] + # where k satisfies self.cumulated_sizes[k - 1] <= k < self.cumulated_sizes[k] + ds_idx, ds_sample_idx = self._get_dataset_and_index(index) + ret = (ds_idx, self.datasets[ds_idx][ds_sample_idx]) + return ret + + def num_tokens(self, index): + return self.sizes[index].max() + + def num_tokens_vec(self, indices): + sizes_vec = self.sizes[np.array(indices)] + # max across all dimensions but first one + return np.amax(sizes_vec, axis=tuple(range(1, len(sizes_vec.shape)))) + + def size(self, index): + return self.sizes[index] + + def __len__(self): + return self.virtual_size + + def collater(self, samples, **extra_args): + """Merge a list of samples to form a mini-batch.""" + if len(samples) == 0: + return None + if self.collate_format == "ordered_dict": + collect_samples = [[] for _ in range(len(self.datasets))] + for (i, sample) in samples: + collect_samples[i].append(sample) + batch = OrderedDict( + [ + (self.keys[i], dataset.collater(collect_samples[i])) + for i, (key, dataset) in enumerate(zip(self.keys, self.datasets)) + if len(collect_samples[i]) > 0 + ] + ) + elif self.shared_collater: + batch = self.datasets[0].collater([s for _, s in samples]) + else: + samples_dict = defaultdict(list) + pad_to_length = ( + defaultdict(int) + if "pad_to_length" not in extra_args + else extra_args["pad_to_length"] + ) + for ds_idx, s in samples: + pad_to_length["source"] = max( + pad_to_length["source"], s["source"].size(0) + ) + if s["target"] is not None: + pad_to_length["target"] = max( + pad_to_length["target"], s["target"].size(0) + ) + samples_dict[ds_idx].append(s) + batches = [ + self.datasets[i].collater(samples_dict[i], pad_to_length=pad_to_length) + for i in range(len(self.datasets)) + if len(samples_dict[i]) > 0 + ] + + def straight_data(tensors): + batch = torch.cat(tensors, dim=0) + return batch + + src_lengths = straight_data( + [b["net_input"]["src_lengths"] for b in batches] + ) + src_lengths, sort_order = src_lengths.sort(descending=True) + + def straight_order(tensors): + batch = straight_data(tensors) + return batch.index_select(0, sort_order) + + batch = { + "id": straight_order([b["id"] for b in batches]), + "nsentences": sum(b["nsentences"] for b in batches), + "ntokens": sum(b["ntokens"] for b in batches), + "net_input": { + "src_tokens": straight_order( + [b["net_input"]["src_tokens"] for b in batches] + ), + "src_lengths": src_lengths, + }, + "target": straight_order([b["target"] for b in batches]) + if batches[0]["target"] is not None + else None, + } + if "prev_output_tokens" in batches[0]["net_input"]: + batch["net_input"]["prev_output_tokens"] = straight_order( + [b["net_input"]["prev_output_tokens"] for b in batches] + ) + if "src_lang_id" in batches[0]["net_input"]: + batch["net_input"]["src_lang_id"] = straight_order( + [b["net_input"]["src_lang_id"] for b in batches] + ) + if "tgt_lang_id" in batches[0]: + batch["tgt_lang_id"] = straight_order( + [b["tgt_lang_id"] for b in batches] + ) + return batch + + @property + def sizes(self): + if self._sizes is not None: + return self._sizes + start_time = time.time() + in_sub_dataset_indices = [ + self._cur_indices[ + 0 if i == 0 else self.cumulated_sizes[i - 1] : self.cumulated_sizes[i] + ] + for i in range(len(self.datasets)) + ] + sub_dataset_sizes = [ + d.sizes[indices] + for d, indices in zip(self.datasets, in_sub_dataset_indices) + ] + self._sizes = np.vstack(sub_dataset_sizes) + logger.info(f"sizes() calling time: {get_time_gap(start_time, time.time())}") + return self._sizes + + def ordered_indices(self): + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + + # sort by target length, then source length + if tgt_sizes is not None: + indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")] + sort_indices = indices[np.argsort(src_sizes[indices], kind="mergesort")] + return sort_indices + + def prefetch(self, indices): + prefetch_indices = [[] for _ in range(len(self.datasets))] + for i in indices: + ds_idx, ds_sample_idx = self._get_dataset_and_index(i) + prefetch_indices[ds_idx].append(ds_sample_idx) + for i in range(len(prefetch_indices)): + self.datasets[i].prefetch(prefetch_indices[i]) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if epoch == self._cur_epoch: + # re-enter so return + return + for d in self.datasets: + if hasattr(d, "set_epoch"): + d.set_epoch(epoch) + self._cur_epoch = epoch + self._establish_virtual_datasets() + + def _establish_virtual_datasets(self): + if self.sample_ratios is None and self._cur_indices is not None: + # not a samping dataset, no need to resample if indices are already established + return + self._reset_cached_properties() + + start_time = time.time() + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch. + rng = np.random.RandomState( + [ + int( + hashlib.sha1( + str(self.__class__.__name__).encode("utf-8") + ).hexdigest(), + 16, + ) + % (2**32), + self.seed % (2**32), # global seed + self._cur_epoch, # epoch index, + ] + ) + self._clean_if_not_none( + [self.cumulated_sizes, self.virtual_size_per_dataset, self._sizes] + ) + self._sizes = None + + indices, cumulated_sizes, virtual_size_per_dataset = self.get_virtual_indices( + rng, self.datasets, self.sample_ratios, self.virtual_size + ) + self._cur_indices = indices + self.cumulated_sizes = cumulated_sizes + self.virtual_size_per_dataset = virtual_size_per_dataset + + raw_sizes = [len(d) for d in self.datasets] + sampled_sizes = self.virtual_size_per_dataset + logger.info( + f"[{self.split}] Raw sizes: {str(dict(zip(self.keys, raw_sizes)))}; " + f"raw total size: {sum(raw_sizes)}" + ) + logger.info( + f"[{self.split}] Resampled sizes: {str(dict(zip(self.keys, sampled_sizes)))}; " + f"resampled total size: {sum(sampled_sizes)}" + ) + if self.sample_ratios is not None: + logger.info( + f"[{self.split}] Upsampling ratios: {str(dict(zip(self.keys, self.sample_ratios)))}" + ) + else: + logger.info(f"[{self.split}] A concat dataset") + logger.info( + f"[{self.split}] virtual dataset established time: {get_time_gap(start_time, time.time())}" + ) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + sizes = self.sizes + tgt_sizes = sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None + src_sizes = ( + sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes + ) + + return data_utils.filter_paired_dataset_indices_by_size( + src_sizes, tgt_sizes, indices, max_sizes + ) diff --git a/fairseq/data/multilingual/sampled_multi_epoch_dataset.py b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb187a8dc28c7647fe93cd4ba3d26f5a892ca7fd --- /dev/null +++ b/fairseq/data/multilingual/sampled_multi_epoch_dataset.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import hashlib +import logging +import math + +import numpy as np + +from fairseq.data import SampledMultiDataset + +from .sampled_multi_dataset import CollateFormat, default_virtual_size_func + +logger = logging.getLogger(__name__) + + +class SampledMultiEpochDataset(SampledMultiDataset): + """Samples from multiple sub-datasets according to sampling ratios + using virtual epoch sizes to speed up dataloading. + Args: + datasets ( + List[~torch.utils.data.Dataset] + or OrderedDict[str, ~torch.utils.data.Dataset] + ): datasets + sampling_ratios (List[float]): list of probability of each dataset to be sampled + (default: None, which corresponds to concating all dataset together). + seed (int): RNG seed to use (default: 2). + epoch (int): starting epoch number (default: 1). + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. + collate_format (CollateFormat): collater output format, either CollateFormat.ordered_dict or + CollateFormat.single (default: CollateFormat.single) where CollateFormat.single configures + the collater to output batches of data mixed from all sub-datasets, + and CollateFormat.ordered_dict configures the collater to output a dictionary of batches indexed by keys + of sub-datasets. + Note that not all sub-datasets will present in a single batch in both formats. + virtual_size (int, or callable): the expected virtual size of the dataset (default: default_virtual_size_func). + split (str): the split of the data, e.g. 'train', 'valid' or 'test'. + virtual_epoch_size (int): virtual epoch size, the dataset will go through the data by + this virtual epoch size one by one to speed up data loading, e.g. indicing and filtering + can be performed whenever a virtual epoch is loaded without waiting for the whole dataset to be loaded. + shared_collater (bool): whether or not to all sub-datasets have the same collater. + shard_epoch (int): the real epoch number for shard selection. + shuffle (bool): whether or not to shuffle data (default: True). + """ + + def __init__( + self, + datasets, + sampling_ratios=None, + seed=2, + epoch=1, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=default_virtual_size_func, + split="", + virtual_epoch_size=None, + shared_collater=False, + shard_epoch=1, + shuffle=True, + ): + self.virtual_epoch_size = virtual_epoch_size + self._current_epoch_start_index = None + self._random_global_indices = None + self.shard_epoch = shard_epoch if shard_epoch is not None else 1 + self.load_next_shard = None + self._epoch_sizes = None + super().__init__( + datasets=datasets, + sampling_ratios=sampling_ratios, + seed=seed, + epoch=epoch, + eval_key=eval_key, + collate_format=collate_format, + virtual_size=virtual_size, + split=split, + shared_collater=shared_collater, + shuffle=shuffle, + ) + + def _setup(self, epoch): + self.virtual_epoch_size = ( + self.virtual_epoch_size + if self.virtual_epoch_size is not None + else self.virtual_size + ) + if self.virtual_epoch_size > self.virtual_size: + logger.warning( + f"virtual epoch size {self.virtual_epoch_size} " + f"is greater than virtual dataset size {self.virtual_size}" + ) + self.virtual_epoch_size = self.virtual_size + self.num_virtual_epochs = math.ceil(self.virtual_size / self.virtual_epoch_size) + self._current_epoch_start_index = self._get_epoch_start_index(epoch) + logger.info( + f"virtual epoch size {self.virtual_epoch_size}; virtual dataset size {self.virtual_size}" + ) + + def _map_epoch_index_to_global(self, index): + index = self._current_epoch_start_index + index + # add randomness + return self._random_global_indices[index] + + @property + def sizes(self): + if self._epoch_sizes is not None: + return self._epoch_sizes + _sizes = super().sizes + indices = self._random_global_indices[ + self._current_epoch_start_index : self._current_epoch_start_index + + len(self) + ] + self._epoch_sizes = _sizes[indices] + # del super()._sizes to save memory + del self._sizes + self._sizes = None + return self._epoch_sizes + + def _get_dataset_and_index(self, index): + i = self._map_epoch_index_to_global(index) + return super()._get_dataset_and_index(i) + + def __len__(self): + return ( + self.virtual_epoch_size + if self._current_epoch_start_index + self.virtual_epoch_size + < self.virtual_size + else self.virtual_size - self._current_epoch_start_index + ) + + def set_epoch(self, epoch): + if self._current_epoch_start_index is None: + # initializing epoch idnices of a virtual dataset + self._setup(epoch) + self._next_virtual_epoch(epoch) + else: + # working on already intialized epoch indices + if epoch == self._cur_epoch: + # re-enter so return + return + self._next_virtual_epoch(epoch) + + def _get_epoch_start_index(self, epoch): + assert epoch >= 1 # fairseq is using 1-based epoch everywhere + return ((epoch - 1) % self.num_virtual_epochs) * self.virtual_epoch_size + + def _next_global_indices(self, epoch): + rng = np.random.RandomState( + [ + int( + hashlib.sha1( + str(self.__class__.__name__).encode("utf-8") + ).hexdigest(), + 16, + ) + % (2**32), + self.seed % (2**32), # global seed + epoch, # epoch index, + ] + ) + del self._random_global_indices + self._random_global_indices = rng.choice( + self.virtual_size, self.virtual_size, replace=False + ) + if self.load_next_shard is None: + self.load_next_shard = False + else: + # increase shard epoch for next loading + self.shard_epoch += 1 + self.load_next_shard = True + logger.info( + "to load next epoch/shard in next load_dataset: " + f"epoch={epoch}/shard_epoch={self.shard_epoch}" + ) + + def _next_virtual_epoch(self, epoch): + index = self._get_epoch_start_index(epoch) + if index == 0 or self._random_global_indices is None: + # need to start from the beginning, + # so call super().set_epoch(epoch) to establish the global virtual indices + logger.info( + "establishing a new set of global virtual indices for " + f"epoch={epoch}/shard_epoch={self.shard_epoch}" + ) + super().set_epoch(epoch) + self._next_global_indices(epoch) + else: + self._cur_epoch = epoch + + # reset cache sizes and ordered_indices for the epoch after moving to a new epoch + self._clean_if_not_none( + [ + self._epoch_sizes, + ] + ) + self._epoch_sizes = None + self._current_epoch_start_index = index diff --git a/fairseq/data/multilingual/sampling_method.py b/fairseq/data/multilingual/sampling_method.py new file mode 100644 index 0000000000000000000000000000000000000000..140c68f01d60e902ef88f11f30f8813dc15fc681 --- /dev/null +++ b/fairseq/data/multilingual/sampling_method.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List + + +logger = logging.getLogger(__name__) + + +def uniform(dataset_sizes: List[int]): + return [1.0] * len(dataset_sizes) + + +def temperature_sampling(dataset_sizes, temp): + total_size = sum(dataset_sizes) + return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes] + + +def make_temperature_sampling(temp=1.0): + def sampling_func(dataset_sizes): + return temperature_sampling(dataset_sizes, temp) + + return sampling_func + + +def make_ratio_sampling(ratios): + def sampling_func(dataset_sizes): + return ratios + + return sampling_func + + +class SamplingMethod: + @staticmethod + def add_arguments(parser): + parser.add_argument( + "--sampling-method", + choices=[ + "uniform", + "temperature", + "concat", + "RoundRobin", + ], + type=str, + default="concat", + help="The method to sample data per language pairs", + ) + parser.add_argument( + "--sampling-temperature", + default=1.5, + type=float, + help="only work with --sampling-method temperature", + ) + + @staticmethod + def build_sampler(args, task): + return SamplingMethod(args, task) + + def __init__(self, args, task): + self.args = args + self.task = task + + def is_adaptive(self): + return False + + def sampling_method_selector(self): + args = self.args + logger.info(f"selected sampler: {args.sampling_method}") + if args.sampling_method == "uniform": + return uniform + elif args.sampling_method == "temperature" or self.is_adaptive(): + return make_temperature_sampling(float(args.sampling_temperature)) + else: + # default to concating all data set together + return None diff --git a/fairseq/data/nested_dictionary_dataset.py b/fairseq/data/nested_dictionary_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..52e74abddacc923c5e29b0a0c41d7efc85482d3b --- /dev/null +++ b/fairseq/data/nested_dictionary_dataset.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import torch +from torch.utils.data.dataloader import default_collate + +from . import FairseqDataset + + +def _flatten(dico, prefix=None): + """Flatten a nested dictionary.""" + new_dico = OrderedDict() + if isinstance(dico, dict): + prefix = prefix + "." if prefix is not None else "" + for k, v in dico.items(): + if v is None: + continue + new_dico.update(_flatten(v, prefix + k)) + elif isinstance(dico, list): + for i, v in enumerate(dico): + new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]")) + else: + new_dico = OrderedDict({prefix: dico}) + return new_dico + + +def _unflatten(dico): + """Unflatten a flattened dictionary into a nested dictionary.""" + new_dico = OrderedDict() + for full_k, v in dico.items(): + full_k = full_k.split(".") + node = new_dico + for k in full_k[:-1]: + if k.startswith("[") and k.endswith("]"): + k = int(k[1:-1]) + if k not in node: + node[k] = OrderedDict() + node = node[k] + node[full_k[-1]] = v + return new_dico + + +class NestedDictionaryDataset(FairseqDataset): + def __init__(self, defn, sizes=None): + super().__init__() + self.defn = _flatten(defn) + self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes + + first = None + for v in self.defn.values(): + if not isinstance( + v, + ( + FairseqDataset, + torch.utils.data.Dataset, + ), + ): + raise ValueError("Expected Dataset but found: {}".format(v.__class__)) + first = first or v + if len(v) > 0: + assert len(v) == len(first), "dataset lengths must match" + + self._len = len(first) + + def __getitem__(self, index): + return OrderedDict((k, ds[index]) for k, ds in self.defn.items()) + + def __len__(self): + return self._len + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + if len(samples) == 0: + return {} + sample = OrderedDict() + for k, ds in self.defn.items(): + try: + sample[k] = ds.collater([s[k] for s in samples]) + except NotImplementedError: + sample[k] = default_collate([s[k] for s in samples]) + return _unflatten(sample) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return max(s[index] for s in self.sizes) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + if len(self.sizes) == 1: + return self.sizes[0][index] + else: + return (s[index] for s in self.sizes) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return any(ds.supports_prefetch for ds in self.defn.values()) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + for ds in self.defn.values(): + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values()) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.defn.values(): + ds.set_epoch(epoch) diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py new file mode 100644 index 0000000000000000000000000000000000000000..e92e83c2cd2e2950d387f93ae8a80acbc12f909f --- /dev/null +++ b/fairseq/data/noising.py @@ -0,0 +1,334 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from fairseq.data import data_utils + + +class WordNoising(object): + """Generate a noisy version of a sentence, without changing words themselves.""" + + def __init__(self, dictionary, bpe_cont_marker="@@", bpe_end_marker=None): + self.dictionary = dictionary + self.bpe_end = None + if bpe_cont_marker: + self.bpe_end = np.array( + [ + not self.dictionary[i].endswith(bpe_cont_marker) + for i in range(len(self.dictionary)) + ] + ) + elif bpe_end_marker: + self.bpe_end = np.array( + [ + self.dictionary[i].endswith(bpe_end_marker) + for i in range(len(self.dictionary)) + ] + ) + + self.get_word_idx = ( + self._get_bpe_word_idx if self.bpe_end is not None else self._get_token_idx + ) + + def noising(self, x, lengths, noising_prob=0.0): + raise NotImplementedError() + + def _get_bpe_word_idx(self, x): + """ + Given a list of BPE tokens, for every index in the tokens list, + return the index of the word grouping that it belongs to. + For example, for input x corresponding to ["how", "are", "y@@", "ou"], + return [[0], [1], [2], [2]]. + """ + # x: (T x B) + bpe_end = self.bpe_end[x] + + if x.size(0) == 1 and x.size(1) == 1: + # Special case when we only have one word in x. If x = [[N]], + # bpe_end is a scalar (bool) instead of a 2-dim array of bools, + # which makes the sum operation below fail. + return np.array([[0]]) + + # do a reduce front sum to generate word ids + word_idx = bpe_end[::-1].cumsum(0)[::-1] + word_idx = word_idx.max(0)[None, :] - word_idx + return word_idx + + def _get_token_idx(self, x): + """ + This is to extend noising functions to be able to apply to non-bpe + tokens, e.g. word or characters. + """ + x = torch.t(x) + word_idx = np.array([range(len(x_i)) for x_i in x]) + return np.transpose(word_idx) + + +class WordDropout(WordNoising): + """Randomly drop input words. If not passing blank_idx (default is None), + then dropped words will be removed. Otherwise, it will be replaced by the + blank_idx.""" + + def __init__( + self, + dictionary, + default_dropout_prob=0.1, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) + self.default_dropout_prob = default_dropout_prob + + def noising(self, x, lengths, dropout_prob=None, blank_idx=None): + if dropout_prob is None: + dropout_prob = self.default_dropout_prob + # x: (T x B), lengths: B + if dropout_prob == 0: + return x, lengths + + assert 0 < dropout_prob < 1 + + # be sure to drop entire words + word_idx = self.get_word_idx(x) + sentences = [] + modified_lengths = [] + for i in range(lengths.size(0)): + # Since dropout probabilities need to apply over non-pad tokens, + # it is not trivial to generate the keep mask without consider + # input lengths; otherwise, this could be done outside the loop + + # We want to drop whole words based on word_idx grouping + num_words = max(word_idx[:, i]) + 1 + + # ith example: [x0, x1, ..., eos, pad, ..., pad] + # We should only generate keep probs for non-EOS tokens. Thus if the + # input sentence ends in EOS, the last word idx is not included in + # the dropout mask generation and we append True to always keep EOS. + # Otherwise, just generate the dropout mask for all word idx + # positions. + has_eos = x[lengths[i] - 1, i] == self.dictionary.eos() + if has_eos: # has eos? + keep = np.random.rand(num_words - 1) >= dropout_prob + keep = np.append(keep, [True]) # keep EOS symbol + else: + keep = np.random.rand(num_words) >= dropout_prob + + words = x[: lengths[i], i].tolist() + + # TODO: speed up the following loop + # drop words from the input according to keep + new_s = [ + w if keep[word_idx[j, i]] else blank_idx for j, w in enumerate(words) + ] + new_s = [w for w in new_s if w is not None] + # we need to have at least one word in the sentence (more than the + # start / end sentence symbols) + if len(new_s) <= 1: + # insert at beginning in case the only token left is EOS + # EOS should be at end of list. + new_s.insert(0, words[np.random.randint(0, len(words))]) + assert len(new_s) >= 1 and ( + not has_eos # Either don't have EOS at end or last token is EOS + or (len(new_s) >= 2 and new_s[-1] == self.dictionary.eos()) + ), "New sentence is invalid." + sentences.append(new_s) + modified_lengths.append(len(new_s)) + # re-construct input + modified_lengths = torch.LongTensor(modified_lengths) + modified_x = torch.LongTensor( + modified_lengths.max(), modified_lengths.size(0) + ).fill_(self.dictionary.pad()) + for i in range(modified_lengths.size(0)): + modified_x[: modified_lengths[i], i].copy_(torch.LongTensor(sentences[i])) + + return modified_x, modified_lengths + + +class WordShuffle(WordNoising): + """Shuffle words by no more than k positions.""" + + def __init__( + self, + dictionary, + default_max_shuffle_distance=3, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary, bpe_cont_marker, bpe_end_marker) + self.default_max_shuffle_distance = 3 + + def noising(self, x, lengths, max_shuffle_distance=None): + if max_shuffle_distance is None: + max_shuffle_distance = self.default_max_shuffle_distance + # x: (T x B), lengths: B + if max_shuffle_distance == 0: + return x, lengths + + # max_shuffle_distance < 1 will return the same sequence + assert max_shuffle_distance > 1 + + # define noise word scores + noise = np.random.uniform( + 0, + max_shuffle_distance, + size=(x.size(0), x.size(1)), + ) + noise[0] = -1 # do not move start sentence symbol + # be sure to shuffle entire words + word_idx = self.get_word_idx(x) + x2 = x.clone() + for i in range(lengths.size(0)): + length_no_eos = lengths[i] + if x[lengths[i] - 1, i] == self.dictionary.eos(): + length_no_eos = lengths[i] - 1 + # generate a random permutation + scores = word_idx[:length_no_eos, i] + noise[word_idx[:length_no_eos, i], i] + # ensure no reordering inside a word + scores += 1e-6 * np.arange(length_no_eos.item()) + permutation = scores.argsort() + # shuffle words + x2[:length_no_eos, i].copy_( + x2[:length_no_eos, i][torch.from_numpy(permutation)] + ) + return x2, lengths + + +class UnsupervisedMTNoising(WordNoising): + """ + Implements the default configuration for noising in UnsupervisedMT + (github.com/facebookresearch/UnsupervisedMT) + """ + + def __init__( + self, + dictionary, + max_word_shuffle_distance, + word_dropout_prob, + word_blanking_prob, + bpe_cont_marker="@@", + bpe_end_marker=None, + ): + super().__init__(dictionary) + self.max_word_shuffle_distance = max_word_shuffle_distance + self.word_dropout_prob = word_dropout_prob + self.word_blanking_prob = word_blanking_prob + + self.word_dropout = WordDropout( + dictionary=dictionary, + bpe_cont_marker=bpe_cont_marker, + bpe_end_marker=bpe_end_marker, + ) + self.word_shuffle = WordShuffle( + dictionary=dictionary, + bpe_cont_marker=bpe_cont_marker, + bpe_end_marker=bpe_end_marker, + ) + + def noising(self, x, lengths): + # 1. Word Shuffle + noisy_src_tokens, noisy_src_lengths = self.word_shuffle.noising( + x=x, + lengths=lengths, + max_shuffle_distance=self.max_word_shuffle_distance, + ) + # 2. Word Dropout + noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising( + x=noisy_src_tokens, + lengths=noisy_src_lengths, + dropout_prob=self.word_dropout_prob, + ) + # 3. Word Blanking + noisy_src_tokens, noisy_src_lengths = self.word_dropout.noising( + x=noisy_src_tokens, + lengths=noisy_src_lengths, + dropout_prob=self.word_blanking_prob, + blank_idx=self.dictionary.unk(), + ) + + return noisy_src_tokens + + +class NoisingDataset(torch.utils.data.Dataset): + def __init__( + self, + src_dataset, + src_dict, + seed, + noiser=None, + noising_class=UnsupervisedMTNoising, + **kwargs + ): + """ + Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the + samples based on the supplied noising configuration. + + Args: + src_dataset (~torch.utils.data.Dataset): dataset to wrap. + to build self.src_dataset -- + a LanguagePairDataset with src dataset as the source dataset and + None as the target dataset. Should NOT have padding so that + src_lengths are accurately calculated by language_pair_dataset + collate function. + We use language_pair_dataset here to encapsulate the tgt_dataset + so we can re-use the LanguagePairDataset collater to format the + batches in the structure that SequenceGenerator expects. + src_dict (~fairseq.data.Dictionary): source dictionary + seed (int): seed to use when generating random noise + noiser (WordNoising): a pre-initialized :class:`WordNoising` + instance. If this is None, a new instance will be created using + *noising_class* and *kwargs*. + noising_class (class, optional): class to use to initialize a + default :class:`WordNoising` instance. + kwargs (dict, optional): arguments to initialize the default + :class:`WordNoising` instance given by *noiser*. + """ + self.src_dataset = src_dataset + self.src_dict = src_dict + self.seed = seed + self.noiser = ( + noiser + if noiser is not None + else noising_class( + dictionary=src_dict, + **kwargs, + ) + ) + self.sizes = src_dataset.sizes + + def __getitem__(self, index): + """ + Returns a single noisy sample. Multiple samples are fed to the collater + create a noising dataset batch. + """ + src_tokens = self.src_dataset[index] + src_lengths = torch.LongTensor([len(src_tokens)]) + src_tokens = src_tokens.unsqueeze(0) + + # Transpose src tokens to fit expected shape of x in noising function + # (batch size, sequence length) -> (sequence length, batch size) + src_tokens_t = torch.t(src_tokens) + + with data_utils.numpy_seed(self.seed + index): + noisy_src_tokens = self.noiser.noising(src_tokens_t, src_lengths) + + # Transpose back to expected src_tokens format + # (sequence length, 1) -> (1, sequence length) + noisy_src_tokens = torch.t(noisy_src_tokens) + return noisy_src_tokens[0] + + def __len__(self): + """ + The length of the noising dataset is the length of src. + """ + return len(self.src_dataset) + + @property + def supports_prefetch(self): + return self.src_dataset.supports_prefetch + + def prefetch(self, indices): + if self.src_dataset.supports_prefetch: + self.src_dataset.prefetch(indices) diff --git a/fairseq/data/num_samples_dataset.py b/fairseq/data/num_samples_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..99a17495c701d8a05e0268f98bf453905e11d078 --- /dev/null +++ b/fairseq/data/num_samples_dataset.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import FairseqDataset + + +class NumSamplesDataset(FairseqDataset): + def __getitem__(self, index): + return 1 + + def __len__(self): + return 0 + + def collater(self, samples): + return sum(samples) diff --git a/fairseq/data/numel_dataset.py b/fairseq/data/numel_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ac86dfd2f1d89055de909656d61d6aca85523f00 --- /dev/null +++ b/fairseq/data/numel_dataset.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class NumelDataset(BaseWrapperDataset): + def __init__(self, dataset, reduce=False): + super().__init__(dataset) + self.reduce = reduce + + def __getitem__(self, index): + item = self.dataset[index] + if torch.is_tensor(item): + return torch.numel(item) + else: + return np.size(item) + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if self.reduce: + return sum(samples) + else: + return torch.tensor(samples) diff --git a/fairseq/data/offset_tokens_dataset.py b/fairseq/data/offset_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6fabbdcdaa1a8f70d8d8c07db4cd53754503c194 --- /dev/null +++ b/fairseq/data/offset_tokens_dataset.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class OffsetTokensDataset(BaseWrapperDataset): + def __init__(self, dataset, offset): + super().__init__(dataset) + self.offset = offset + + def __getitem__(self, idx): + return self.dataset[idx] + self.offset diff --git a/fairseq/data/pad_dataset.py b/fairseq/data/pad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b512d370f94d6c4009b1cd42aa3d49279003e59c --- /dev/null +++ b/fairseq/data/pad_dataset.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.data import data_utils + +from . import BaseWrapperDataset + + +class PadDataset(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad, pad_length=None): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + self.pad_length = pad_length + + def collater(self, samples): + return data_utils.collate_tokens( + samples, self.pad_idx, left_pad=self.left_pad, pad_to_length=self.pad_length + ) + + +class LeftPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=True) + + +class RightPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=False) diff --git a/fairseq/data/padding_mask_dataset.py b/fairseq/data/padding_mask_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f7b88dbbb5cc101073652be2211b962b25418a --- /dev/null +++ b/fairseq/data/padding_mask_dataset.py @@ -0,0 +1,38 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from fairseq.data import data_utils +from . import BaseWrapperDataset + + +class PaddingMaskDataset(BaseWrapperDataset): + def __init__(self, dataset, left_pad, pad_length=None): + super().__init__(dataset) + self.left_pad = left_pad + self.pad_length = pad_length + + def __getitem__(self, index): + item = self.dataset[index] + return torch.zeros_like(item).bool() + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + return data_utils.collate_tokens( + samples, True, left_pad=self.left_pad, pad_to_length=self.pad_length + ) + + +class LeftPaddingMaskDataset(PaddingMaskDataset): + def __init__(self, dataset): + super().__init__(dataset, left_pad=True) + + +class RightPaddingMaskDataset(PaddingMaskDataset): + def __init__(self, dataset): + super().__init__(dataset, left_pad=False) diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..459fb8acd789e7b03c70201cb5cb2a9e7dc4f325 --- /dev/null +++ b/fairseq/data/plasma_utils.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import hashlib +import json +import subprocess +import tempfile +from typing import Hashable + +try: + import pyarrow.plasma as plasma + + PYARROW_AVAILABLE = True +except ImportError: + plasma = None + PYARROW_AVAILABLE = False + + +class PlasmaArray: + """ + Wrapper around numpy arrays that automatically moves the data to shared + memory upon serialization. This is particularly helpful when passing numpy + arrays through multiprocessing, so that data is not unnecessarily + duplicated or pickled. + """ + + def __init__(self, array): + super().__init__() + self.array = array + self.disable = array.nbytes < 134217728 # disable for arrays <128MB + self.object_id = None + self.path = None + + # variables with underscores shouldn't be pickled + self._client = None + self._server = None + self._server_tmp = None + self._plasma = None + + @property + def plasma(self): + if self._plasma is None and not self.disable: + self._plasma = plasma + return self._plasma + + def start_server(self): + if self.plasma is None or self._server is not None: + return + assert self.object_id is None + assert self.path is None + self._server_tmp = tempfile.NamedTemporaryFile() + self.path = self._server_tmp.name + self._server = subprocess.Popen( + ["plasma_store", "-m", str(int(1.05 * self.array.nbytes)), "-s", self.path] + ) + + @property + def client(self): + if self._client is None: + assert self.path is not None + self._client = self.plasma.connect(self.path, num_retries=200) + return self._client + + def __getstate__(self): + """Called on pickle load""" + if self.plasma is None: + return self.__dict__ + if self.object_id is None: + self.start_server() + self.object_id = self.client.put(self.array) + state = self.__dict__.copy() + del state["array"] + state["_client"] = None + state["_server"] = None + state["_server_tmp"] = None + state["_plasma"] = None + return state + + def __setstate__(self, state): + """Called on pickle save""" + self.__dict__.update(state) + if self.plasma is None: + return + self.array = self.client.get(self.object_id) + + def __del__(self): + if self._server is not None: + self._server.kill() + self._server = None + self._server_tmp.close() + self._server_tmp = None + + +DEFAULT_PLASMA_PATH = "/tmp/plasma" + + +class PlasmaView: + """Interface to write and read from shared memory. Whereas PlasmaArray writes to plasma on serialization, + PlasmaView writes to shared memory on instantiation.""" + + def __init__(self, array, split_path: str, hash_data: Hashable, plasma_path=None): + """ + Args: + array: numpy array to store. This can be read with ``PlasmaView().array`` + split_path: the path whence the data was read, used for hashing + hash_data: other metadata about the array that can be used to create a unique key. + as of writing, the 3 callers in ``TokenBlockDataset`` use:: + + hash_data = ((block_size, document_sep_len, str(break_mode), len(dataset)), 0|1|2) + + + """ + assert PYARROW_AVAILABLE + assert split_path is not None + if plasma_path is None: + plasma_path = DEFAULT_PLASMA_PATH + + self.path = plasma_path + self.split_path = split_path + self._client = None # Initialize lazily for pickle. plasma clients should not be deep copied or serialized. + self._n = None + + self.object_id = self.get_object_id(self.split_path, hash_data) + try: + self.client.put(array, object_id=self.object_id) + except plasma.PlasmaObjectExists: + pass + + @property + def client(self): + if self._client is None: + self._client = plasma.connect(self.path, num_retries=200) + return self._client + + @property + def array(self): + """Fetch a read only view of an np.array, stored in plasma.""" + ret = self.client.get(self.object_id) + return ret + + @staticmethod + def get_object_id(split_path: str, hash_data: Hashable): + """Returns plasma.ObjectID from hashing split_path and object_num.""" + hash = hashlib.blake2b(bytes(split_path, "utf-8"), digest_size=20) + harg = json.dumps(hash_data).encode("utf-8") + hash.update(harg) + return plasma.ObjectID(hash.digest()) + + def __getstate__(self): + """Called on pickle save""" + self.disconnect() + state = self.__dict__.copy() + assert state["_client"] is None + assert "object_id" in state + return state + + def __setstate__(self, state): + """Called on pickle load""" + self.__dict__.update(state) + + def __del__(self): + self.disconnect() + + def disconnect(self): + if self._client is not None: + self._client.disconnect() + self._client = None + + def __len__(self): + """Save reads by caching len""" + if self._n is None: + self._n = len(self.array) + return self._n + + +GB100 = (1024**3) * 100 + + +class PlasmaStore: + def __init__(self, path=DEFAULT_PLASMA_PATH, nbytes: int = GB100): + + self.server = self.start(path, nbytes) + + def __del__(self): + self.server.kill() + + @staticmethod + def start(path=DEFAULT_PLASMA_PATH, nbytes: int = GB100) -> subprocess.Popen: + if not PYARROW_AVAILABLE: + raise ImportError("please run pip install pyarrow to use --use_plasma_view") + # best practice is to allocate more space than we need. The limitation seems to be the size of /dev/shm + _server = subprocess.Popen(["plasma_store", "-m", str(nbytes), "-s", path]) + plasma.connect(path, num_retries=200) # If we can't connect we fail immediately + return _server diff --git a/fairseq/data/prepend_dataset.py b/fairseq/data/prepend_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ad74784d2d7920e4a6225282d95543ce16ea50d9 --- /dev/null +++ b/fairseq/data/prepend_dataset.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class PrependDataset(BaseWrapperDataset): + def __init__(self, dataset, prepend_getter, ensure_first_token_is=None): + super().__init__(dataset) + self.prepend_getter = prepend_getter + self.ensure_first_token = ensure_first_token_is + + def __getitem__(self, idx): + item = self.dataset[idx] + is_tuple = isinstance(item, tuple) + src = item[0] if is_tuple else item + + assert self.ensure_first_token is None or src[0] == self.ensure_first_token + prepend_idx = self.prepend_getter(self.dataset, idx) + assert isinstance(prepend_idx, int) + src[0] = prepend_idx + item = tuple((src,) + item[1:]) if is_tuple else src + return item diff --git a/fairseq/data/prepend_token_dataset.py b/fairseq/data/prepend_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fd1331f4c44c1595eb9bb78baa0cf5cf3bcce9ad --- /dev/null +++ b/fairseq/data/prepend_token_dataset.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import BaseWrapperDataset + + +class PrependTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + self.token = token + if token is not None: + self._sizes = np.array(dataset.sizes) + 1 + else: + self._sizes = dataset.sizes + + def __getitem__(self, idx): + item = self.dataset[idx] + if self.token is not None: + item = torch.cat([item.new([self.token]), item]) + return item + + @property + def sizes(self): + return self._sizes + + def num_tokens(self, index): + n = self.dataset.num_tokens(index) + if self.token is not None: + n += 1 + return n + + def size(self, index): + n = self.dataset.size(index) + if self.token is not None: + n += 1 + return n diff --git a/fairseq/data/raw_label_dataset.py b/fairseq/data/raw_label_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d054904f419bd64855d33a2a770b43f671c7c8d8 --- /dev/null +++ b/fairseq/data/raw_label_dataset.py @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class RawLabelDataset(FairseqDataset): + def __init__(self, labels): + super().__init__() + self.labels = labels + + def __getitem__(self, index): + return self.labels[index] + + def __len__(self): + return len(self.labels) + + def collater(self, samples): + return torch.tensor(samples) diff --git a/fairseq/data/replace_dataset.py b/fairseq/data/replace_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5aac2ba96bee0a8bb65f4c9e56fa0b17248ee1d9 --- /dev/null +++ b/fairseq/data/replace_dataset.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class ReplaceDataset(BaseWrapperDataset): + """Replaces tokens found in the dataset by a specified replacement token + + Args: + dataset (~torch.utils.data.Dataset): dataset to replace tokens in + replace_map(Dictionary[int,int]): map of token to replace -> replacement token + offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be + as many as the number of objects returned by the underlying dataset __getitem__ method. + """ + + def __init__(self, dataset, replace_map, offsets): + super().__init__(dataset) + assert len(replace_map) > 0 + self.replace_map = replace_map + self.offsets = offsets + + def __getitem__(self, index): + item = self.dataset[index] + is_tuple = isinstance(item, tuple) + srcs = item if is_tuple else [item] + + for offset, src in zip(self.offsets, srcs): + for k, v in self.replace_map.items(): + src_off = src[offset:] if offset >= 0 else src[:offset] + src_off.masked_fill_(src_off == k, v) + + item = srcs if is_tuple else srcs[0] + return item diff --git a/fairseq/data/resampling_dataset.py b/fairseq/data/resampling_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2d77ed79d7b917f44602eae609df7abbd15ff0fd --- /dev/null +++ b/fairseq/data/resampling_dataset.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import numpy as np + +from fairseq.data import BaseWrapperDataset, plasma_utils + +logger = logging.getLogger(__name__) + + +class ResamplingDataset(BaseWrapperDataset): + """Randomly samples from a given dataset at each epoch. + + Sampling is done with or without replacement, depending on the "replace" + parameter. + + Optionally, the epoch size can be rescaled. This is potentially desirable + to increase per-epoch coverage of the base dataset (since sampling with + replacement means that many items in the dataset will be left out). In the + case of sampling without replacement, size_ratio should be strictly less + than 1. + + Args: + dataset (~torch.utils.data.Dataset): dataset on which to sample. + weights (List[float]): list of probability weights + (default: None, which corresponds to uniform sampling). + replace (bool): sampling mode; True for "with replacement", or False + for "without replacement" (default: True) + size_ratio (float): the ratio to subsample to; must be positive + (default: 1.0). + batch_by_size (bool): whether or not to batch by sequence length + (default: True). + seed (int): RNG seed to use (default: 0). + epoch (int): starting epoch number (default: 1). + """ + + def __init__( + self, + dataset, + weights=None, + replace=True, + size_ratio=1.0, + batch_by_size=True, + seed=0, + epoch=1, + ): + super().__init__(dataset) + + if weights is None: + self.weights = None + + else: + assert len(weights) == len(dataset) + weights_arr = np.array(weights, dtype=np.float64) + weights_arr /= weights_arr.sum() + self.weights = plasma_utils.PlasmaArray(weights_arr) + + self.replace = replace + + assert size_ratio > 0.0 + if not self.replace: + assert size_ratio < 1.0 + self.size_ratio = float(size_ratio) + self.actual_size = np.ceil(len(dataset) * self.size_ratio).astype(int) + + self.batch_by_size = batch_by_size + self.seed = seed + + self._cur_epoch = None + self._cur_indices = None + + self.set_epoch(epoch) + + def __getitem__(self, index): + return self.dataset[self._cur_indices.array[index]] + + def __len__(self): + return self.actual_size + + @property + def sizes(self): + if isinstance(self.dataset.sizes, list): + return [s[self._cur_indices.array] for s in self.dataset.sizes] + return self.dataset.sizes[self._cur_indices.array] + + def num_tokens(self, index): + return self.dataset.num_tokens(self._cur_indices.array[index]) + + def size(self, index): + return self.dataset.size(self._cur_indices.array[index]) + + def ordered_indices(self): + if self.batch_by_size: + order = [ + np.arange(len(self)), + self.sizes, + ] # No need to handle `self.shuffle == True` + return np.lexsort(order) + else: + return np.arange(len(self)) + + def prefetch(self, indices): + self.dataset.prefetch(self._cur_indices.array[indices]) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch): + logger.debug("ResamplingDataset.set_epoch: {}".format(epoch)) + super().set_epoch(epoch) + + if epoch == self._cur_epoch: + return + + self._cur_epoch = epoch + + # Generate a weighted sample of indices as a function of the + # random seed and the current epoch. + + rng = np.random.RandomState( + [ + 42, # magic number + self.seed % (2**32), # global seed + self._cur_epoch, # epoch index + ] + ) + self._cur_indices = plasma_utils.PlasmaArray( + rng.choice( + len(self.dataset), + self.actual_size, + replace=self.replace, + p=(None if self.weights is None else self.weights.array), + ) + ) diff --git a/fairseq/data/roll_dataset.py b/fairseq/data/roll_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a2915eeb3e8fb4dfb4b2bb33e0464ad0783d854c --- /dev/null +++ b/fairseq/data/roll_dataset.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import BaseWrapperDataset + + +class RollDataset(BaseWrapperDataset): + def __init__(self, dataset, shifts): + super().__init__(dataset) + self.shifts = shifts + + def __getitem__(self, index): + item = self.dataset[index] + return torch.roll(item, self.shifts) diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb7447ea955a7c3ae7372f09ee426c08acd430e --- /dev/null +++ b/fairseq/data/round_robin_zip_datasets.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import OrderedDict +from typing import Dict, Sequence + +import numpy as np + +from . import FairseqDataset, LanguagePairDataset + +logger = logging.getLogger(__name__) + + +class RoundRobinZipDatasets(FairseqDataset): + """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together. + + Shorter datasets are repeated in a round-robin fashion to match the length + of the longest one. + + Args: + datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of + :class:`~fairseq.data.FairseqDataset` instances. + eval_key (str, optional): a key used at evaluation time that causes + this instance to pass-through batches from *datasets[eval_key]*. + """ + + def __init__(self, datasets, eval_key=None): + super().__init__() + if isinstance(datasets, dict): + datasets = OrderedDict(datasets) + assert isinstance(datasets, OrderedDict) + assert datasets, "Can't make a RoundRobinZipDatasets out of nothing" + for dataset in datasets.values(): + assert isinstance(dataset, FairseqDataset) + + self.datasets = datasets + self.eval_key = eval_key + + self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k])) + self.longest_dataset = datasets[self.longest_dataset_key] + self._ordered_indices: Dict[str, Sequence[int]] = None + + def _map_index(self, key, index): + assert ( + self._ordered_indices is not None + ), "Must call RoundRobinZipDatasets.ordered_indices() first" + o = self._ordered_indices[key] + return o[index % len(o)] + + def __getitem__(self, index): + if self.eval_key is None: + return OrderedDict( + [ + (key, dataset[self._map_index(key, index)]) + for key, dataset in self.datasets.items() + ] + ) + else: + # at evaluation time it's useful to pass-through batches from a single key + return self.datasets[self.eval_key][self._map_index(self.eval_key, index)] + + def __len__(self): + if self._ordered_indices is not None: + return len(self._ordered_indices[self.longest_dataset_key]) + return len(self.longest_dataset) + + def collater(self, samples): + """Merge a list of samples to form a mini-batch.""" + if len(samples) == 0: + return None + if self.eval_key is None: + return OrderedDict( + [ + (key, dataset.collater([sample[key] for sample in samples])) + for key, dataset in self.datasets.items() + ] + ) + else: + # at evaluation time it's useful to pass-through batches from a single key + return self.datasets[self.eval_key].collater(samples) + + def num_tokens(self, index): + """Return an example's length (number of tokens), used for batching.""" + # TODO make it configurable whether to use max() or sum() here + return max( + dataset.num_tokens(self._map_index(key, index)) + for key, dataset in self.datasets.items() + ) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return { + key: dataset.size(self._map_index(key, index)) + for key, dataset in self.datasets.items() + } + + def ordered_indices(self): + """Ordered indices for batching.""" + if self._ordered_indices is None: + # Call the underlying dataset's ordered_indices() here, so that we + # get the same random ordering as we would have from using the + # underlying sub-datasets directly. + self._ordered_indices = OrderedDict( + [ + (key, dataset.ordered_indices()) + for key, dataset in self.datasets.items() + ] + ) + return np.arange(len(self)) + + def filter_indices_by_size(self, indices, max_positions=None): + """ + Filter each sub-dataset independently, then update the round robin to work + on the filtered sub-datasets. + """ + + def _deep_until_language_pair(dataset): + if isinstance(dataset, LanguagePairDataset): + return dataset + if hasattr(dataset, "tgt_dataset"): + return _deep_until_language_pair(dataset.tgt_dataset) + if hasattr(dataset, "dataset"): + return _deep_until_language_pair(dataset.dataset) + raise Exception(f"Don't know how to unwrap this dataset: {dataset}") + + if not isinstance(max_positions, dict): + max_positions = {k: max_positions for k in self.datasets.keys()} + ignored_some = False + for key, dataset in self.datasets.items(): + dataset = _deep_until_language_pair(dataset) + self._ordered_indices[key], ignored = dataset.filter_indices_by_size( + self._ordered_indices[key], max_positions[key] + ) + if len(ignored) > 0: + ignored_some = True + logger.warning( + f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, " + f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}" + ) + # Since we are modifying in place the _ordered_indices, + # it's not possible anymore to return valid ignored indices. + # Hopefully the extra debug information print above should be enough to debug. + # Ideally we would receive ignore_invalid_inputs so that we could have + # a proper error message. + return (np.arange(len(self)), [0] if ignored_some else []) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, "supports_prefetch", False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch([self._map_index(key, index) for index in indices]) diff --git a/fairseq/data/shorten_dataset.py b/fairseq/data/shorten_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6ebb5d88feb3f29d1512a0873df304915d051209 --- /dev/null +++ b/fairseq/data/shorten_dataset.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from fairseq.data import data_utils + +from . import BaseWrapperDataset + + +class TruncateDataset(BaseWrapperDataset): + """Truncate a sequence by returning the first truncation_length tokens""" + + def __init__(self, dataset, truncation_length): + super().__init__(dataset) + assert truncation_length is not None + self.truncation_length = truncation_length + self.dataset = dataset + + def __getitem__(self, index): + item = self.dataset[index] + item_len = item.size(0) + if item_len > self.truncation_length: + item = item[: self.truncation_length] + return item + + @property + def sizes(self): + return np.minimum(self.dataset.sizes, self.truncation_length) + + def __len__(self): + return len(self.dataset) + + +class RandomCropDataset(TruncateDataset): + """Truncate a sequence by returning a random crop of truncation_length tokens""" + + def __init__(self, dataset, truncation_length, seed=1): + super().__init__(dataset, truncation_length) + self.seed = seed + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the crop changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + item = self.dataset[index] + item_len = item.size(0) + excess = item_len - self.truncation_length + if excess > 0: + start_idx = np.random.randint(0, excess) + item = item[start_idx : start_idx + self.truncation_length] + return item + + +def maybe_shorten_dataset( + dataset, + split, + shorten_data_split_list, + shorten_method, + tokens_per_sample, + seed, +): + truncate_split = ( + split in shorten_data_split_list.split(",") or len(shorten_data_split_list) == 0 + ) + if shorten_method == "truncate" and truncate_split: + dataset = TruncateDataset(dataset, tokens_per_sample) + elif shorten_method == "random_crop" and truncate_split: + dataset = RandomCropDataset(dataset, tokens_per_sample, seed) + return dataset diff --git a/fairseq/data/sort_dataset.py b/fairseq/data/sort_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b3890e7279e1f26db2e48ec0a91c639e9299d60f --- /dev/null +++ b/fairseq/data/sort_dataset.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +from . import BaseWrapperDataset + + +class SortDataset(BaseWrapperDataset): + def __init__(self, dataset, sort_order): + super().__init__(dataset) + if not isinstance(sort_order, (list, tuple)): + sort_order = [sort_order] + self.sort_order = sort_order + + assert all(len(so) == len(dataset) for so in sort_order) + + def ordered_indices(self): + return np.lexsort(self.sort_order) diff --git a/fairseq/data/span_mask_tokens_dataset.py b/fairseq/data/span_mask_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..72189bd3786ca4bbebb3889ced61a6875661e2d5 --- /dev/null +++ b/fairseq/data/span_mask_tokens_dataset.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from . import Dictionary, FairseqDataset, data_utils + + +def collate( + samples, + pad_idx, + eos_idx, + vocab, + left_pad_source=False, + left_pad_target=False, + input_feeding=True, + pad_to_length=None, +): + assert input_feeding + if len(samples) == 0: + return {} + + def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None): + return data_utils.collate_tokens( + [s[key] for s in samples], + pad_idx, + eos_idx=None, # use eos_idx of each sample instead of vocab.eos() + left_pad=left_pad, + move_eos_to_beginning=move_eos_to_beginning, + pad_to_length=pad_to_length, + ) + + id = torch.LongTensor([s["id"] for s in samples]) + src_tokens = merge( + "source", + left_pad=left_pad_source, + pad_to_length=pad_to_length["source"] if pad_to_length is not None else None, + ) + # sort by descending source length + src_lengths = torch.LongTensor([s["source"].numel() for s in samples]) + src_lengths, sort_order = src_lengths.sort(descending=True) + id = id.index_select(0, sort_order) + src_tokens = src_tokens.index_select(0, sort_order) + + prev_output_tokens = None + target = None + if samples[0].get("target", None) is not None: + target = merge( + "target", + left_pad=left_pad_target, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + target = target.index_select(0, sort_order) + ntokens = sum(len(s["target"]) for s in samples) + + if input_feeding: + # we create a shifted version of targets for feeding the + # previous output token(s) into the next decoder step + prev_output_tokens = merge( + "target", + left_pad=left_pad_target, + move_eos_to_beginning=True, + pad_to_length=pad_to_length["target"] + if pad_to_length is not None + else None, + ) + prev_output_tokens = prev_output_tokens.index_select(0, sort_order) + else: + ntokens = sum(len(s["source"]) for s in samples) + + batch = { + "id": id, + "ntokens": ntokens, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, + "target": target, + "target_lengths": torch.LongTensor([len(t) for t in target]), + "nsentences": samples[0]["source"].size(0), + "sort_order": sort_order, + } + if prev_output_tokens is not None: + batch["net_input"]["prev_output_tokens"] = prev_output_tokens + + return batch + + +class SpanMaskedTokensDataset(FairseqDataset): + """ + A wrapper around TokenBlockDataset for T5 dataset. + + Args: + dataset (~torch.utils.data.Dataset): dataset to wrap + vocab (~fairseq.data.Dictionary): vocabulary + noise_density (float): fraction of the tokens to select as noise. + mean_noise_span_length (float): mean noise span length. + shuffle (bool, optional): shuffle the elements before batching. + Default: ``True`` + seed: Seed for random number generator for reproducibility. + """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + vocab: Dictionary, + noise_density: float, + mean_noise_span_length: float, + shuffle: bool, + seed: int = 1, + ): + self.dataset = dataset + self.vocab = vocab + self.seed = seed + self.noise_density = noise_density + self.mean_noise_span_length = mean_noise_span_length + self.shuffle = shuffle + self.epoch = 0 + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + self.epoch = epoch + + def __getitem__(self, index): + with data_utils.numpy_seed(self.seed, self.epoch, index): + item = self.dataset[index] + assert item[-1] == self.vocab.eos() + + noise_mask = self.random_spans_noise_mask(len(item)) + + source_sentinel_ids = self.create_sentinel_ids(noise_mask.astype(np.int8)) + source = self.filter_input_ids(item, source_sentinel_ids) + + target_sentinel_ids = self.create_sentinel_ids( + (~noise_mask).astype(np.int8) + ) + target = self.filter_input_ids(item, target_sentinel_ids) + + return { + "id": index, + "source": torch.from_numpy(source), + "target": torch.from_numpy(target), + } + + def random_spans_noise_mask(self, length): + + """ + This function is copy of `random_spans_helper `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. + Args: + length: an int32 scalar (length of the incoming token sequence) + Returns: + a boolean tensor with shape [length] + """ + + orig_length = length + + num_noise_tokens = int(np.round(length * self.noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """ + Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segments: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segments] containing positive integers that add up to num_items + """ + mask_indices = np.arange(num_items - 1) < (num_segments - 1) + np.random.shuffle(mask_indices) + first_in_segment = np.pad(mask_indices, [[1, 0]]) + segment_id = np.cumsum(first_in_segment) + # count length of subsegments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) + nonnoise_span_lengths = _random_segmentation( + num_nonnoise_tokens, num_noise_spans + ) + + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), + [num_noise_spans * 2], + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros((length,), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] + + def create_sentinel_ids(self, mask_indices): + """ + Sentinel ids creation given the indices that should be masked. + The start indices of each mask are replaced by the sentinel ids in increasing + order. Consecutive mask indices to be deleted are replaced with `-1`. + """ + start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices + + sentinel_ids = np.where( + start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices + ) + # making sure all sentinel tokens are unique over the example + sentinel_ids = np.where(sentinel_ids != 0, len(self.vocab) - sentinel_ids, 0) + sentinel_ids -= mask_indices - start_indices + return sentinel_ids + + @staticmethod + def filter_input_ids(input_ids, sentinel_ids): + """ + Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. + This will reduce the sequence length from `expanded_inputs_length` to `input_length`. + """ + input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) + + # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are + # masked tokens coming after sentinel tokens and should be removed + return input_ids_full[input_ids_full >= 0] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples, pad_to_length=None): + """ + Merge a list of samples to form a mini-batch. + Args: + samples (List[dict]): samples to collate + Returns: + dict: a mini-batch of data + """ + return collate( + samples, + self.vocab.pad(), + self.vocab.eos(), + self.vocab, + pad_to_length=pad_to_length, + ) + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.dataset.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.dataset.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + else: + indices = np.arange(len(self)) + return indices[np.argsort(self.dataset.sizes[indices], kind="mergesort")] + + def prefetch(self, indices): + self.src.prefetch(indices) + self.tgt.prefetch(indices) + + @property + def supports_prefetch(self): + return ( + hasattr(self.src, "supports_prefetch") + and self.src.supports_prefetch + and hasattr(self.tgt, "supports_prefetch") + and self.tgt.supports_prefetch + ) diff --git a/fairseq/data/speech_dlm_dataset.py b/fairseq/data/speech_dlm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..06c4808f0aaacd3191eadfdb1e03d49add2c3827 --- /dev/null +++ b/fairseq/data/speech_dlm_dataset.py @@ -0,0 +1,307 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import numpy as np +import torch + +from fairseq.data import FairseqDataset, MonolingualDataset, data_utils + + +class SpeechDLMDataset(FairseqDataset): + """The dataset used to train the SpeechDLM model as described in the paper: + https://arxiv.org/pdf/2203.16502.pdf + + The input datasets is expected to be a dict over channel names with the values + being instances of :class:`~fairseq.data.MonolingualDataset`. + + Each element of SpeechDLMDataset is a dictionary with the following keys: + - `id` (int) : index of the item + - `source` (OrderedDict[str, Tensor of shape (seq_len,)]) : dictionary over + channels with the values containing the input unit tokens + - `target_next` (OrderedDict[str, Tensor of shape (seq_len,)]) : dictionary + over channels with the values containing the next unit tokens (input + tokens shifted by 1). + Its value is None if 'next' not in self.targets + - `target_edge` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) : dictionary + over channels with the values containing the edge unit tokens (input tokens + deduplicated). + Its value is None if 'edge' not in self.targets + - `target_duration` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) : + dictionary over channels with the values being the durations of the edge units. + Its value is None if 'duration' not in targets. + - `target_edge_indices` (OrderedDict[str, Tensor of shape (dedup_seq_len,)]) : + dictionary over channels with the values being the indices of the edge units + in the source sequence. + Its value is None if neither 'edge' or 'duration in targets. + + Args: + datasets (Dict[str, ~fairseq.data.MonolingualDataset]): a dictionary of + :class:`~fairseq.data.MonolingualDataset` instances. + targets (List[str]): list of the target types that the SpeechDLM model + should predict. Can be one of "next", "edge", "duration". + shuffle (bool, optional): shuffle the elements before batching + (default: True). + """ + + def __init__( + self, datasets, targets=None, max_target_durations=None, shuffle=False + ): + super().__init__() + if isinstance(datasets, dict): + datasets = OrderedDict(datasets) + assert isinstance( + datasets, OrderedDict + ), "datasets is expected to be an instance of Dictionary or OrderedDict" + assert datasets, "datasets is None" + for dataset in datasets.values(): + assert isinstance( + dataset, MonolingualDataset + ), "Each value of datasets is expected to be an instance of MonolingualDataset" + + self.datasets = datasets + self.targets = targets + if max_target_durations is not None and max_target_durations > 0: + self.max_target_durations = max_target_durations + else: + self.max_target_durations = float("inf") + self.sizes = next(iter(datasets.values())).sizes + self.vocab = next(iter(datasets.values())).vocab + self.length = len(next(iter(datasets.values()))) + self.shuffle = shuffle + + for channel, dataset in datasets.items(): + assert ( + len(dataset) == self.length + ), "[{}] length mismatch ({} vs {})".format( + channel, len(dataset), self.length + ) + assert (dataset.sizes == self.sizes).all(), "[{}] sizes mismatch".format( + channel + ) + + assert ( + dataset.vocab.pad() == self.vocab.pad() + ), "pad token is expected to be the same" + assert ( + dataset.vocab.eos() == self.vocab.eos() + ), "eos token is expected to be the same" + assert ( + dataset.vocab.bos() == self.vocab.bos() + ), "bos token is expected to be the same" + assert ( + dataset.vocab.unk() == self.vocab.unk() + ), "unk token is expected to be the same" + + def __getitem__(self, index): + source = OrderedDict( + [ + (key, dataset[index]["source"]) + for (key, dataset) in self.datasets.items() + ] + ) + + item = { + "id": index, + "source": source, + "target_next": None, + "target_edge": None, + "target_duration": None, + "target_edge_indices": None, + } + + if self.targets is not None: + for channel in self.datasets: + target = self._get_target(index, channel) + for t in target: + if item[f"target_{t}"] is None: + item[f"target_{t}"] = OrderedDict() + item[f"target_{t}"][channel] = target[t] + + return item + + def __len__(self): + return self.length + + def _get_target(self, index, channel): + """Get target in one of ['next', 'edge', 'duration'] + - 'next' is the future unit + - 'edge' is the edge unit + - 'duration' is the duration of the edge unit + """ + if self.targets is not None: + target = {} + pad_idx = self.vocab.pad() + max_dur = self.max_target_durations + future_target = self.datasets[channel][index]["target"] + if "edge" in self.targets or "duration" in self.targets: + edge_units, edge_unit_counts = torch.unique_consecutive( + future_target, return_counts=True + ) + padding_end = edge_units[-1] == pad_idx + if padding_end: + edge_units = edge_units[:-1] + edge_unit_counts = edge_unit_counts[:-1] + edge_indices = torch.cumsum(edge_unit_counts, 0) + edge_indices = torch.cat([torch.tensor([0]), edge_indices[:-1]]) + target["edge_indices"] = edge_indices + + for t in self.targets: + if t == "next": + target[t] = future_target + elif t == "edge": + target[t] = edge_units + elif t == "duration": + # count the remaining duration of the last edge indices in the next sentence + if not padding_end and index < len(self.datasets[channel]) - 1: + i = 0 + next_sentence_target = self.datasets[channel][index + 1][ + "target" + ] + while ( + next_sentence_target[i] == edge_units[-1] + and edge_unit_counts[-1] + i < max_dur + ): + i += 1 + edge_unit_counts[-1] += i + + # cut off to the maximal threshold + if max_dur: + edge_unit_counts[edge_unit_counts > max_dur] = max_dur + + target[t] = edge_unit_counts + else: + raise Exception("invalid target " + t) + + return target + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch with the following keys: + + - `id` (LongTensor): example IDs in the original input order + - `ntokens` (int): total number of tokens in the batch + - `net_input` (dict): the input to the Model, containing keys: + + - `src_tokens` (OrderedDict[str, LongTensor]): dictionary + over channel with the values being padded 2D Tensor of + samples `source` of shape `(bsz, src_len)`. + Padding will appear on the right. + - `src_lengths` (LongTensor): lengths of source sentences + in the mini-batch + + - `target` (dict): the target of the Model, containing keys: + + - `next` (OrderedDict[str, LongTensor]): dictionary + over channel with the values being padded 2D Tensor of + batch samples' `target_next` of shape `(bsz, tgt_len)`. + Padding will appear on the right. + - `edge` (OrderedDict[str, LongTensor]): dictionary + over channel with the values being the concatenated + 1D Tensor of batch samples' `target_edge` of shape + `(sum of dedup_tgt_len,)` + - `duration` (OrderedDict[str, LongTensor]): dictionary + over channel with the values being the concatenated + 1D Tensor of batch samples' `target_duration` of shape + `(sum of dedup_tgt_len,)` + - `edge_indices` (OrderedDict[str, LongTensor]): dictionary + over channel with the values being the concatenated + 1D Tensor of batch samples' `target_edge_indices` of + shape `(sum of dedup_tgt_len,)`. + The indices are added to multiplies of batch size + such that they are the actual indices in the flatten + `src_tokens` Tensor + """ + if len(samples) == 0: + return {} + + pad_idx = self.vocab.pad() + eos_idx = self.vocab.eos() + + def merge(key, max_size=None): + if samples[0][key] is None: + return None + res = OrderedDict() + for channel in samples[0][key]: + if key in ["source", "target_next"]: + # fill batch of shape: (batch_size, max_size) + res[channel] = data_utils.collate_tokens( + [s[key][channel] for s in samples], + pad_idx, + eos_idx, + left_pad=False, + ) + elif key in ["target_edge", "target_duration"]: + # concatenate the edge units/duration + res[channel] = torch.cat([s[key][channel] for s in samples]) + elif key == "target_edge_indices": + # increase the edge indices to the indices in the flatten batch + res[channel] = torch.cat( + [s[key][channel] + i * max_size for i, s in enumerate(samples)] + ) + + return res + + src_tokens = merge("source") + tgt_next = merge("target_next") + tgt_edge = merge("target_edge") + tgt_duration = merge("target_duration") + tgt_edge_indices = merge( + "target_edge_indices", max_size=next(iter(src_tokens.values())).size(-1) + ) + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "nsentences": len(samples), + "ntokens": sum(len(item) for s in samples for item in s["source"].values()), + "net_input": { + "src_tokens": src_tokens, + "src_lengths": torch.LongTensor( + [next(iter(s["source"].values())).numel() for s in samples] + ), + }, + "target": { + "next": tgt_next, + "edge": tgt_edge, + "duration": tgt_duration, + "edge_indices": tgt_edge_indices, + }, + } + + def num_tokens(self, index): + """Return the number of tokens in a sample. This value is used to + enforce ``--max-tokens`` during batching.""" + return self.sizes[index] + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return self.sizes[index] + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + @property + def supports_prefetch(self): + return all( + getattr(dataset, "supports_prefetch", False) + for dataset in self.datasets.values() + ) + + def prefetch(self, indices): + for key, dataset in self.datasets.items(): + dataset.prefetch(indices) diff --git a/fairseq/data/strip_token_dataset.py b/fairseq/data/strip_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cae39ba4d2f8106398eccd7eb0cf5c2194ec0db5 --- /dev/null +++ b/fairseq/data/strip_token_dataset.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import BaseWrapperDataset + + +class StripTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, id_to_strip): + super().__init__(dataset) + self.id_to_strip = id_to_strip + + def __getitem__(self, index): + item = self.dataset[index] + while len(item) > 0 and item[-1] == self.id_to_strip: + item = item[:-1] + while len(item) > 0 and item[0] == self.id_to_strip: + item = item[1:] + return item diff --git a/fairseq/data/subsample_dataset.py b/fairseq/data/subsample_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5c7e2ac864613a10b1886bca78cbc53f5bfd64 --- /dev/null +++ b/fairseq/data/subsample_dataset.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging + +import numpy as np +from fairseq.data.data_utils import numpy_seed + +from . import BaseWrapperDataset + + +logger = logging.getLogger(__name__) + + +class SubsampleDataset(BaseWrapperDataset): + """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples + + Args: + dataset (~torch.utils.data.Dataset): dataset to subsample + size_ratio(float): the ratio to subsample to. must be between 0 and 1 (exclusive) + """ + + def __init__(self, dataset, size_ratio, shuffle=False, seed=None): + super().__init__(dataset) + assert size_ratio < 1 + self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int) + with numpy_seed(seed) if seed is not None else contextlib.ExitStack(): + self.indices = np.random.choice( + list(range(len(self.dataset))), self.actual_size, replace=False + ) + self.shuffle = shuffle + logger.info( + "subsampled dataset from {} to {} (ratio={})".format( + len(self.dataset), self.actual_size, size_ratio + ) + ) + + def __getitem__(self, index): + return self.dataset[self.indices[index]] + + def __len__(self): + return self.actual_size + + def collater(self, samples): + return self.dataset.collater(samples) + + @property + def sizes(self): + return self.dataset.sizes[self.indices] + + @property + def name(self): + return self.dataset.name + + def num_tokens(self, index): + return self.dataset.num_tokens(self.indices[index]) + + def size(self, index): + return self.dataset.size(self.indices[index]) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + order.append(self.sizes) + return np.lexsort(order) + + def prefetch(self, indices): + self.dataset.prefetch(self.indices[indices]) diff --git a/fairseq/data/text_compressor.py b/fairseq/data/text_compressor.py new file mode 100644 index 0000000000000000000000000000000000000000..d699f2ea296f33cdc37ca152ab225d09cb04b5ea --- /dev/null +++ b/fairseq/data/text_compressor.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +class TextCompressionLevel(Enum): + none = 0 + low = 1 + high = 2 + + +class TextCompressor(object): + def __init__( + self, level: TextCompressionLevel, max_input_byte_length: int = 2**16 + ): + self.level = level + self.max_input_length = max_input_byte_length + + def compress(self, text: str) -> bytes: + if self.level == TextCompressionLevel.low: + import zlib + + # zlib: built-in, fast + return zlib.compress(text.encode(), level=0) + elif self.level == TextCompressionLevel.high: + try: + import unishox2 + + # unishox2: optimized for short text but slower + except ImportError: + raise ImportError( + "Please install unishox2 for the text compression feature: " + "pip install unishox2-py3" + ) + assert len(text.encode()) <= self.max_input_length + return unishox2.compress(text)[0] + else: + return text.encode() + + def decompress(self, compressed: bytes) -> str: + if self.level == TextCompressionLevel.low: + import zlib + + return zlib.decompress(compressed).decode() + elif self.level == TextCompressionLevel.high: + try: + import unishox2 + except ImportError: + raise ImportError( + "Please install unishox2 for the text compression feature: " + "pip install unishox2-py3" + ) + return unishox2.decompress(compressed, self.max_input_length) + else: + return compressed.decode() diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a414e7ef64193b4c9e285e357350c09663dd2d8f --- /dev/null +++ b/fairseq/data/token_block_dataset.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from fairseq.data import FairseqDataset, plasma_utils +from fairseq.data.indexed_dataset import best_fitting_int_dtype +from typing import Tuple + + +class TokenBlockDataset(FairseqDataset): + """Break a Dataset of tokens into blocks. + + Args: + dataset (~torch.utils.data.Dataset): dataset to break into blocks + sizes (List[int]): sentence lengths (required for 'complete' and 'eos') + block_size (int): maximum block size (ignored in 'eos' break mode) + break_mode (str, optional): Mode used for breaking tokens. Values can + be one of: + - 'none': break tokens into equally sized blocks (up to block_size) + - 'complete': break tokens into blocks (up to block_size) such that + blocks contains complete sentences, although block_size may be + exceeded if some sentences exceed block_size + - 'complete_doc': similar to 'complete' mode, but do not + cross document boundaries + - 'eos': each block contains one sentence (block_size is ignored) + include_targets (bool, optional): return next tokens as targets + (default: False). + document_sep_len (int, optional): document separator size (required for + 'complete_doc' break mode). Typically 1 if the sentences have eos + and 0 otherwise. + """ + + def __init__( + self, + dataset, + sizes, + block_size, + pad, + eos, + break_mode=None, + include_targets=False, + document_sep_len=1, + use_plasma_view=False, + split_path=None, + plasma_path=None, + ): + + super().__init__() + self.dataset = dataset + self.pad = pad + self.eos = eos + self.include_targets = include_targets + + assert len(dataset) > 0 + + assert len(dataset) == len(sizes) + _sizes, block_to_dataset_index, slice_indices = self._build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) + if use_plasma_view: + plasma_id = (block_size, document_sep_len, str(break_mode), len(dataset)) + self._slice_indices = plasma_utils.PlasmaView( + slice_indices, split_path, (plasma_id, 0), plasma_path=plasma_path + ) + self._sizes = plasma_utils.PlasmaView( + _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path + ) + self._block_to_dataset_index = plasma_utils.PlasmaView( + block_to_dataset_index, + split_path, + (plasma_id, 2), + plasma_path=plasma_path, + ) + else: + self._slice_indices = plasma_utils.PlasmaArray(slice_indices) + self._sizes = plasma_utils.PlasmaArray(_sizes) + self._block_to_dataset_index = plasma_utils.PlasmaArray( + block_to_dataset_index + ) + + @staticmethod + def _build_slice_indices( + sizes, break_mode, document_sep_len, block_size + ) -> Tuple[np.ndarray]: + """Use token_block_utils_fast to build arrays for indexing into self.dataset""" + try: + from fairseq.data.token_block_utils_fast import ( + _get_slice_indices_fast, + _get_block_to_dataset_index_fast, + ) + except ImportError: + raise ImportError( + "Please build Cython components with: `pip install --editable .` " + "or `python setup.py build_ext --inplace`" + ) + + if isinstance(sizes, list): + sizes = np.array(sizes, dtype=np.int64) + else: + if torch.is_tensor(sizes): + sizes = sizes.numpy() + sizes = sizes.astype(np.int64) + + break_mode = break_mode if break_mode is not None else "none" + + # For "eos" break-mode, block_size is not required parameters. + if break_mode == "eos" and block_size is None: + block_size = 0 + + slice_indices = _get_slice_indices_fast( + sizes, str(break_mode), block_size, document_sep_len + ) + _sizes = slice_indices[:, 1] - slice_indices[:, 0] + + # build index mapping block indices to the underlying dataset indices + if break_mode == "eos": + # much faster version for eos break mode + block_to_dataset_index = np.stack( + [ + np.arange(len(sizes)), # starting index in dataset + np.zeros( + len(sizes), dtype=np.compat.long + ), # starting offset within starting index + np.arange(len(sizes)), # ending index in dataset + ], + 1, + ) + else: + block_to_dataset_index = _get_block_to_dataset_index_fast( + sizes, + slice_indices, + ) + size_dtype = np.uint16 if block_size < 65535 else np.uint32 + num_tokens = slice_indices[-1].max() + slice_indices_dtype = best_fitting_int_dtype(num_tokens) + slice_indices = slice_indices.astype(slice_indices_dtype) + _sizes = _sizes.astype(size_dtype) + block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype) + return _sizes, block_to_dataset_index, slice_indices + + @property + def slice_indices(self): + return self._slice_indices.array + + @property + def sizes(self): + return self._sizes.array + + @property + def block_to_dataset_index(self): + return self._block_to_dataset_index.array + + def attr(self, attr: str, index: int): + start_ds_idx, _, _ = self.block_to_dataset_index[index] + return self.dataset.attr(attr, start_ds_idx) + + def __getitem__(self, index): + start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index] + + buffer = torch.cat( + [self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)] + ) + slice_s, slice_e = self.slice_indices[index] + length = slice_e - slice_s + s, e = start_offset, start_offset + length + item = buffer[s:e] + + if self.include_targets: + # *target* is the original sentence (=item) + # *source* is shifted right by 1 (maybe left-padded with eos) + # *past_target* is shifted right by 2 (left-padded as needed) + if s == 0: + source = torch.cat([item.new([self.eos]), buffer[0 : e - 1]]) + past_target = torch.cat( + [item.new([self.pad, self.eos]), buffer[0 : e - 2]] + ) + else: + source = buffer[s - 1 : e - 1] + if s == 1: + past_target = torch.cat([item.new([self.eos]), buffer[0 : e - 2]]) + else: + past_target = buffer[s - 2 : e - 2] + + return source, item, past_target + + return item + + def __len__(self): + return len(self.slice_indices) + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + self.dataset.prefetch( + { + ds_idx + for index in indices + for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]] + for ds_idx in range(start_ds_idx, end_ds_idx + 1) + } + ) diff --git a/fairseq/data/token_block_utils_fast.cpp b/fairseq/data/token_block_utils_fast.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c01da68b1ff2342e7fe634a2f3aa15252791507 --- /dev/null +++ b/fairseq/data/token_block_utils_fast.cpp @@ -0,0 +1,33818 @@ +/* Generated by Cython 3.0.8 */ + +/* BEGIN: Cython Metadata +{ + "distutils": { + "depends": [ + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/arrayobject.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/arrayscalars.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ndarrayobject.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ndarraytypes.h", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include/numpy/ufuncobject.h" + ], + "extra_compile_args": [ + "-std=c++11", + "-O3", + "-DTORCH_API_INCLUDE_EXTENSION_H", + "-DPYBIND11_COMPILER_TYPE=\"_gcc\"", + "-DPYBIND11_STDLIB=\"_libstdcpp\"", + "-DPYBIND11_BUILD_ABI=\"_cxxabi1011\"", + "-DTORCH_EXTENSION_NAME=token_block_utils_fast", + "-D_GLIBCXX_USE_CXX11_ABI=0" + ], + "include_dirs": [ + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include", + "/tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/core/include" + ], + "language": "c++", + "name": "fairseq.data.token_block_utils_fast", + "sources": [ + "fairseq/data/token_block_utils_fast.pyx" + ] + }, + "module_name": "fairseq.data.token_block_utils_fast" +} +END: Cython Metadata */ + +#ifndef PY_SSIZE_T_CLEAN +#define PY_SSIZE_T_CLEAN +#endif /* PY_SSIZE_T_CLEAN */ +#if defined(CYTHON_LIMITED_API) && 0 + #ifndef Py_LIMITED_API + #if CYTHON_LIMITED_API+0 > 0x03030000 + #define Py_LIMITED_API CYTHON_LIMITED_API + #else + #define Py_LIMITED_API 0x03030000 + #endif + #endif +#endif + +#include "Python.h" +#ifndef Py_PYTHON_H + #error Python headers needed to compile C extensions, please install development version of Python. +#elif PY_VERSION_HEX < 0x02070000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) + #error Cython requires Python 2.7+ or Python 3.3+. +#else +#if defined(CYTHON_LIMITED_API) && CYTHON_LIMITED_API +#define __PYX_EXTRA_ABI_MODULE_NAME "limited" +#else +#define __PYX_EXTRA_ABI_MODULE_NAME "" +#endif +#define CYTHON_ABI "3_0_8" __PYX_EXTRA_ABI_MODULE_NAME +#define __PYX_ABI_MODULE_NAME "_cython_" CYTHON_ABI +#define __PYX_TYPE_MODULE_PREFIX __PYX_ABI_MODULE_NAME "." +#define CYTHON_HEX_VERSION 0x030008F0 +#define CYTHON_FUTURE_DIVISION 1 +#include +#ifndef offsetof + #define offsetof(type, member) ( (size_t) & ((type*)0) -> member ) +#endif +#if !defined(_WIN32) && !defined(WIN32) && !defined(MS_WINDOWS) + #ifndef __stdcall + #define __stdcall + #endif + #ifndef __cdecl + #define __cdecl + #endif + #ifndef __fastcall + #define __fastcall + #endif +#endif +#ifndef DL_IMPORT + #define DL_IMPORT(t) t +#endif +#ifndef DL_EXPORT + #define DL_EXPORT(t) t +#endif +#define __PYX_COMMA , +#ifndef HAVE_LONG_LONG + #define HAVE_LONG_LONG +#endif +#ifndef PY_LONG_LONG + #define PY_LONG_LONG LONG_LONG +#endif +#ifndef Py_HUGE_VAL + #define Py_HUGE_VAL HUGE_VAL +#endif +#define __PYX_LIMITED_VERSION_HEX PY_VERSION_HEX +#if defined(GRAALVM_PYTHON) + /* For very preliminary testing purposes. Most variables are set the same as PyPy. + The existence of this section does not imply that anything works or is even tested */ + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 1 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(PYPY_VERSION) + #define CYTHON_COMPILING_IN_PYPY 1 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #if PY_VERSION_HEX < 0x03090000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1 && PYPY_VERSION_NUM >= 0x07030C00) + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(CYTHON_LIMITED_API) + #ifdef Py_LIMITED_API + #undef __PYX_LIMITED_VERSION_HEX + #define __PYX_LIMITED_VERSION_HEX Py_LIMITED_API + #endif + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 1 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_CLINE_IN_TRACEBACK + #define CYTHON_CLINE_IN_TRACEBACK 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 1 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #endif + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 1 + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif +#elif defined(Py_GIL_DISABLED) || defined(Py_NOGIL) + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 1 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #ifndef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 +#else + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 1 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #ifndef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 1 + #endif + #if PY_MAJOR_VERSION < 3 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #ifndef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 1 + #endif + #ifndef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 1 + #endif + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #elif !defined(CYTHON_USE_UNICODE_WRITER) + #define CYTHON_USE_UNICODE_WRITER 1 + #endif + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #ifndef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 1 + #endif + #ifndef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL (PY_MAJOR_VERSION < 3 || PY_VERSION_HEX >= 0x03060000 && PY_VERSION_HEX < 0x030C00A6) + #endif + #ifndef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL (PY_VERSION_HEX >= 0x030700A1) + #endif + #ifndef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 1 + #endif + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #endif + #if PY_VERSION_HEX < 0x030400a1 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #elif !defined(CYTHON_USE_TP_FINALIZE) + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #if PY_VERSION_HEX < 0x030600B1 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #elif !defined(CYTHON_USE_DICT_VERSIONS) + #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX < 0x030C00A5) + #endif + #if PY_VERSION_HEX < 0x030700A3 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #elif !defined(CYTHON_USE_EXC_INFO_STACK) + #define CYTHON_USE_EXC_INFO_STACK 1 + #endif + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 + #endif +#endif +#if !defined(CYTHON_FAST_PYCCALL) +#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) +#endif +#if !defined(CYTHON_VECTORCALL) +#define CYTHON_VECTORCALL (CYTHON_FAST_PYCCALL && PY_VERSION_HEX >= 0x030800B1) +#endif +#define CYTHON_BACKPORT_VECTORCALL (CYTHON_METH_FASTCALL && PY_VERSION_HEX < 0x030800B1) +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_MAJOR_VERSION < 3 + #include "longintrepr.h" + #endif + #undef SHIFT + #undef BASE + #undef MASK + #ifdef SIZEOF_VOID_P + enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) }; + #endif +#endif +#ifndef __has_attribute + #define __has_attribute(x) 0 +#endif +#ifndef __has_cpp_attribute + #define __has_cpp_attribute(x) 0 +#endif +#ifndef CYTHON_RESTRICT + #if defined(__GNUC__) + #define CYTHON_RESTRICT __restrict__ + #elif defined(_MSC_VER) && _MSC_VER >= 1400 + #define CYTHON_RESTRICT __restrict + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define CYTHON_RESTRICT restrict + #else + #define CYTHON_RESTRICT + #endif +#endif +#ifndef CYTHON_UNUSED + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(maybe_unused) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(maybe_unused) + #define CYTHON_UNUSED [[maybe_unused]] + #endif + #endif + #endif +#endif +#ifndef CYTHON_UNUSED +# if defined(__GNUC__) +# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_UNUSED_VAR +# if defined(__cplusplus) + template void CYTHON_UNUSED_VAR( const T& ) { } +# else +# define CYTHON_UNUSED_VAR(x) (void)(x) +# endif +#endif +#ifndef CYTHON_MAYBE_UNUSED_VAR + #define CYTHON_MAYBE_UNUSED_VAR(x) CYTHON_UNUSED_VAR(x) +#endif +#ifndef CYTHON_NCP_UNUSED +# if CYTHON_COMPILING_IN_CPYTHON +# define CYTHON_NCP_UNUSED +# else +# define CYTHON_NCP_UNUSED CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_USE_CPP_STD_MOVE + #if defined(__cplusplus) && (\ + __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1600)) + #define CYTHON_USE_CPP_STD_MOVE 1 + #else + #define CYTHON_USE_CPP_STD_MOVE 0 + #endif +#endif +#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None) +#ifdef _MSC_VER + #ifndef _MSC_STDINT_H_ + #if _MSC_VER < 1300 + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; + #else + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + #endif + #endif + #if _MSC_VER < 1300 + #ifdef _WIN64 + typedef unsigned long long __pyx_uintptr_t; + #else + typedef unsigned int __pyx_uintptr_t; + #endif + #else + #ifdef _WIN64 + typedef unsigned __int64 __pyx_uintptr_t; + #else + typedef unsigned __int32 __pyx_uintptr_t; + #endif + #endif +#else + #include + typedef uintptr_t __pyx_uintptr_t; +#endif +#ifndef CYTHON_FALLTHROUGH + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(fallthrough) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(fallthrough) + #define CYTHON_FALLTHROUGH [[fallthrough]] + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_cpp_attribute(clang::fallthrough) + #define CYTHON_FALLTHROUGH [[clang::fallthrough]] + #elif __has_cpp_attribute(gnu::fallthrough) + #define CYTHON_FALLTHROUGH [[gnu::fallthrough]] + #endif + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_attribute(fallthrough) + #define CYTHON_FALLTHROUGH __attribute__((fallthrough)) + #else + #define CYTHON_FALLTHROUGH + #endif + #endif + #if defined(__clang__) && defined(__apple_build_version__) + #if __apple_build_version__ < 7000000 + #undef CYTHON_FALLTHROUGH + #define CYTHON_FALLTHROUGH + #endif + #endif +#endif +#ifdef __cplusplus + template + struct __PYX_IS_UNSIGNED_IMPL {static const bool value = T(0) < T(-1);}; + #define __PYX_IS_UNSIGNED(type) (__PYX_IS_UNSIGNED_IMPL::value) +#else + #define __PYX_IS_UNSIGNED(type) (((type)-1) > 0) +#endif +#if CYTHON_COMPILING_IN_PYPY == 1 + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x030A0000) +#else + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000) +#endif +#define __PYX_REINTERPRET_FUNCION(func_pointer, other_pointer) ((func_pointer)(void(*)(void))(other_pointer)) + +#ifndef __cplusplus + #error "Cython files generated with the C++ option must be compiled with a C++ compiler." +#endif +#ifndef CYTHON_INLINE + #if defined(__clang__) + #define CYTHON_INLINE __inline__ __attribute__ ((__unused__)) + #else + #define CYTHON_INLINE inline + #endif +#endif +template +void __Pyx_call_destructor(T& x) { + x.~T(); +} +template +class __Pyx_FakeReference { + public: + __Pyx_FakeReference() : ptr(NULL) { } + __Pyx_FakeReference(const T& ref) : ptr(const_cast(&ref)) { } + T *operator->() { return ptr; } + T *operator&() { return ptr; } + operator T&() { return *ptr; } + template bool operator ==(const U& other) const { return *ptr == other; } + template bool operator !=(const U& other) const { return *ptr != other; } + template bool operator==(const __Pyx_FakeReference& other) const { return *ptr == *other.ptr; } + template bool operator!=(const __Pyx_FakeReference& other) const { return *ptr != *other.ptr; } + private: + T *ptr; +}; + +#define __PYX_BUILD_PY_SSIZE_T "n" +#define CYTHON_FORMAT_SSIZE_T "z" +#if PY_MAJOR_VERSION < 3 + #define __Pyx_BUILTIN_MODULE_NAME "__builtin__" + #define __Pyx_DefaultClassType PyClass_Type + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_BUILTIN_MODULE_NAME "builtins" + #define __Pyx_DefaultClassType PyType_Type +#if CYTHON_COMPILING_IN_LIMITED_API + static CYTHON_INLINE PyObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyObject *exception_table = NULL; + PyObject *types_module=NULL, *code_type=NULL, *result=NULL; + #if __PYX_LIMITED_VERSION_HEX < 0x030B0000 + PyObject *version_info; + PyObject *py_minor_version = NULL; + #endif + long minor_version = 0; + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + #if __PYX_LIMITED_VERSION_HEX >= 0x030B0000 + minor_version = 11; + #else + if (!(version_info = PySys_GetObject("version_info"))) goto end; + if (!(py_minor_version = PySequence_GetItem(version_info, 1))) goto end; + minor_version = PyLong_AsLong(py_minor_version); + Py_DECREF(py_minor_version); + if (minor_version == -1 && PyErr_Occurred()) goto end; + #endif + if (!(types_module = PyImport_ImportModule("types"))) goto end; + if (!(code_type = PyObject_GetAttrString(types_module, "CodeType"))) goto end; + if (minor_version <= 7) { + (void)p; + result = PyObject_CallFunction(code_type, "iiiiiOOOOOOiOO", a, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else if (minor_version <= 10) { + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else { + if (!(exception_table = PyBytes_FromStringAndSize(NULL, 0))) goto end; + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, name, fline, lnos, exception_table, fv, cell); + } + end: + Py_XDECREF(code_type); + Py_XDECREF(exception_table); + Py_XDECREF(types_module); + if (type) { + PyErr_Restore(type, value, traceback); + } + return result; + } + #ifndef CO_OPTIMIZED + #define CO_OPTIMIZED 0x0001 + #endif + #ifndef CO_NEWLOCALS + #define CO_NEWLOCALS 0x0002 + #endif + #ifndef CO_VARARGS + #define CO_VARARGS 0x0004 + #endif + #ifndef CO_VARKEYWORDS + #define CO_VARKEYWORDS 0x0008 + #endif + #ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x0200 + #endif + #ifndef CO_GENERATOR + #define CO_GENERATOR 0x0020 + #endif + #ifndef CO_COROUTINE + #define CO_COROUTINE 0x0080 + #endif +#elif PY_VERSION_HEX >= 0x030B0000 + static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyCodeObject *result; + PyObject *empty_bytes = PyBytes_FromStringAndSize("", 0); + if (!empty_bytes) return NULL; + result = + #if PY_VERSION_HEX >= 0x030C0000 + PyUnstable_Code_NewWithPosOnlyArgs + #else + PyCode_NewWithPosOnlyArgs + #endif + (a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, name, fline, lnos, empty_bytes); + Py_DECREF(empty_bytes); + return result; + } +#elif PY_VERSION_HEX >= 0x030800B2 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_NewWithPosOnlyArgs(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#endif +#endif +#if PY_VERSION_HEX >= 0x030900A4 || defined(Py_IS_TYPE) + #define __Pyx_IS_TYPE(ob, type) Py_IS_TYPE(ob, type) +#else + #define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is) + #define __Pyx_Py_Is(x, y) Py_Is(x, y) +#else + #define __Pyx_Py_Is(x, y) ((x) == (y)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone) + #define __Pyx_Py_IsNone(ob) Py_IsNone(ob) +#else + #define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue) + #define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob) +#else + #define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse) + #define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob) +#else + #define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False) +#endif +#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj)) +#if PY_VERSION_HEX >= 0x030900F0 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_GC_IsFinalized(o) PyObject_GC_IsFinalized(o) +#else + #define __Pyx_PyObject_GC_IsFinalized(o) _PyGC_FINALIZED(o) +#endif +#ifndef CO_COROUTINE + #define CO_COROUTINE 0x80 +#endif +#ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x200 +#endif +#ifndef Py_TPFLAGS_CHECKTYPES + #define Py_TPFLAGS_CHECKTYPES 0 +#endif +#ifndef Py_TPFLAGS_HAVE_INDEX + #define Py_TPFLAGS_HAVE_INDEX 0 +#endif +#ifndef Py_TPFLAGS_HAVE_NEWBUFFER + #define Py_TPFLAGS_HAVE_NEWBUFFER 0 +#endif +#ifndef Py_TPFLAGS_HAVE_FINALIZE + #define Py_TPFLAGS_HAVE_FINALIZE 0 +#endif +#ifndef Py_TPFLAGS_SEQUENCE + #define Py_TPFLAGS_SEQUENCE 0 +#endif +#ifndef Py_TPFLAGS_MAPPING + #define Py_TPFLAGS_MAPPING 0 +#endif +#ifndef METH_STACKLESS + #define METH_STACKLESS 0 +#endif +#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL) + #ifndef METH_FASTCALL + #define METH_FASTCALL 0x80 + #endif + typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs); + typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args, + Py_ssize_t nargs, PyObject *kwnames); +#else + #define __Pyx_PyCFunctionFast _PyCFunctionFast + #define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords +#endif +#if CYTHON_METH_FASTCALL + #define __Pyx_METH_FASTCALL METH_FASTCALL + #define __Pyx_PyCFunction_FastCall __Pyx_PyCFunctionFast + #define __Pyx_PyCFunction_FastCallWithKeywords __Pyx_PyCFunctionFastWithKeywords +#else + #define __Pyx_METH_FASTCALL METH_VARARGS + #define __Pyx_PyCFunction_FastCall PyCFunction + #define __Pyx_PyCFunction_FastCallWithKeywords PyCFunctionWithKeywords +#endif +#if CYTHON_VECTORCALL + #define __pyx_vectorcallfunc vectorcallfunc + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET PY_VECTORCALL_ARGUMENTS_OFFSET + #define __Pyx_PyVectorcall_NARGS(n) PyVectorcall_NARGS((size_t)(n)) +#elif CYTHON_BACKPORT_VECTORCALL + typedef PyObject *(*__pyx_vectorcallfunc)(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames); + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET ((size_t)1 << (8 * sizeof(size_t) - 1)) + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(((size_t)(n)) & ~__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET)) +#else + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET 0 + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(n)) +#endif +#if PY_MAJOR_VERSION >= 0x030900B1 +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_CheckExact(func) +#else +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_Check(func) +#endif +#define __Pyx_CyOrPyCFunction_Check(func) PyCFunction_Check(func) +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) (((PyCFunctionObject*)(func))->m_ml->ml_meth) +#elif !CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) PyCFunction_GET_FUNCTION(func) +#endif +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FLAGS(func) (((PyCFunctionObject*)(func))->m_ml->ml_flags) +static CYTHON_INLINE PyObject* __Pyx_CyOrPyCFunction_GET_SELF(PyObject *func) { + return (__Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_STATIC) ? NULL : ((PyCFunctionObject*)func)->m_self; +} +#endif +static CYTHON_INLINE int __Pyx__IsSameCFunction(PyObject *func, void *cfunc) { +#if CYTHON_COMPILING_IN_LIMITED_API + return PyCFunction_Check(func) && PyCFunction_GetFunction(func) == (PyCFunction) cfunc; +#else + return PyCFunction_Check(func) && PyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +#endif +} +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCFunction(func, cfunc) +#if __PYX_LIMITED_VERSION_HEX < 0x030900B1 + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) ((void)m, PyType_FromSpecWithBases(s, b)) + typedef PyObject *(*__Pyx_PyCMethod)(PyObject *, PyTypeObject *, PyObject *const *, size_t, PyObject *); +#else + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) PyType_FromModuleAndSpec(m, s, b) + #define __Pyx_PyCMethod PyCMethod +#endif +#ifndef METH_METHOD + #define METH_METHOD 0x200 +#endif +#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc) + #define PyObject_Malloc(s) PyMem_Malloc(s) + #define PyObject_Free(p) PyMem_Free(p) + #define PyObject_Realloc(p) PyMem_Realloc(p) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) +#else + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyThreadState_Current PyThreadState_Get() +#elif !CYTHON_FAST_THREAD_STATE + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#elif PY_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyThreadState_Current PyThreadState_GetUnchecked() +#elif PY_VERSION_HEX >= 0x03060000 + #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet() +#elif PY_VERSION_HEX >= 0x03000000 + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#else + #define __Pyx_PyThreadState_Current _PyThreadState_Current +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE void *__Pyx_PyModule_GetState(PyObject *op) +{ + void *result; + result = PyModule_GetState(op); + if (!result) + Py_FatalError("Couldn't find the module state"); + return result; +} +#endif +#define __Pyx_PyObject_GetSlot(obj, name, func_ctype) __Pyx_PyType_GetSlot(Py_TYPE(obj), name, func_ctype) +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((func_ctype) PyType_GetSlot((type), Py_##name)) +#else + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((type)->name) +#endif +#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT) +#include "pythread.h" +#define Py_tss_NEEDS_INIT 0 +typedef int Py_tss_t; +static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) { + *key = PyThread_create_key(); + return 0; +} +static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) { + Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t)); + *key = Py_tss_NEEDS_INIT; + return key; +} +static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) { + PyObject_Free(key); +} +static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) { + return *key != Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) { + PyThread_delete_key(*key); + *key = Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) { + return PyThread_set_key_value(*key, value); +} +static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) { + return PyThread_get_key_value(*key); +} +#endif +#if PY_MAJOR_VERSION < 3 + #if CYTHON_COMPILING_IN_PYPY + #if PYPY_VERSION_NUM < 0x07030600 + #if defined(__cplusplus) && __cplusplus >= 201402L + [[deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")]] + #elif defined(__GNUC__) || defined(__clang__) + __attribute__ ((__deprecated__("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6"))) + #elif defined(_MSC_VER) + __declspec(deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")) + #endif + static CYTHON_INLINE int PyGILState_Check(void) { + return 0; + } + #else // PYPY_VERSION_NUM < 0x07030600 + #endif // PYPY_VERSION_NUM < 0x07030600 + #else + static CYTHON_INLINE int PyGILState_Check(void) { + PyThreadState * tstate = _PyThreadState_Current; + return tstate && (tstate == PyGILState_GetThisThreadState()); + } + #endif +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030d0000 || defined(_PyDict_NewPresized) +#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n)) +#else +#define __Pyx_PyDict_NewPresized(n) PyDict_New() +#endif +#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION + #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y) +#else + #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX > 0x030600B4 && PY_VERSION_HEX < 0x030d0000 && CYTHON_USE_UNICODE_INTERNALS +#define __Pyx_PyDict_GetItemStrWithError(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash) +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStr(PyObject *dict, PyObject *name) { + PyObject *res = __Pyx_PyDict_GetItemStrWithError(dict, name); + if (res == NULL) PyErr_Clear(); + return res; +} +#elif PY_MAJOR_VERSION >= 3 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07020000) +#define __Pyx_PyDict_GetItemStrWithError PyDict_GetItemWithError +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#else +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStrWithError(PyObject *dict, PyObject *name) { +#if CYTHON_COMPILING_IN_PYPY + return PyDict_GetItem(dict, name); +#else + PyDictEntry *ep; + PyDictObject *mp = (PyDictObject*) dict; + long hash = ((PyStringObject *) name)->ob_shash; + assert(hash != -1); + ep = (mp->ma_lookup)(mp, name, hash); + if (ep == NULL) { + return NULL; + } + return ep->me_value; +#endif +} +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#endif +#if CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyType_GetFlags(tp) (((PyTypeObject *)tp)->tp_flags) + #define __Pyx_PyType_HasFeature(type, feature) ((__Pyx_PyType_GetFlags(type) & (feature)) != 0) + #define __Pyx_PyObject_GetIterNextFunc(obj) (Py_TYPE(obj)->tp_iternext) +#else + #define __Pyx_PyType_GetFlags(tp) (PyType_GetFlags((PyTypeObject *)tp)) + #define __Pyx_PyType_HasFeature(type, feature) PyType_HasFeature(type, feature) + #define __Pyx_PyObject_GetIterNextFunc(obj) PyIter_Next +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyObject_GenericSetAttr((PyObject*)tp, k, v) +#else + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyDict_SetItem(tp->tp_dict, k, v) +#endif +#if CYTHON_USE_TYPE_SPECS && PY_VERSION_HEX >= 0x03080000 +#define __Pyx_PyHeapTypeObject_GC_Del(obj) {\ + PyTypeObject *type = Py_TYPE((PyObject*)obj);\ + assert(__Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE));\ + PyObject_GC_Del(obj);\ + Py_DECREF(type);\ +} +#else +#define __Pyx_PyHeapTypeObject_GC_Del(obj) PyObject_GC_Del(obj) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define CYTHON_PEP393_ENABLED 1 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GetLength(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_ReadChar(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((void)u, 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((void)u, (0)) + #define __Pyx_PyUnicode_DATA(u) ((void*)u) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)k, PyUnicode_ReadChar((PyObject*)(d), i)) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GetLength(u)) +#elif PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND) + #define CYTHON_PEP393_ENABLED 1 + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_READY(op) (0) + #else + #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\ + 0 : _PyUnicode_Ready((PyObject *)(op))) + #endif + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u) + #define __Pyx_PyUnicode_KIND(u) ((int)PyUnicode_KIND(u)) + #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u) + #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, (Py_UCS4) ch) + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u)) + #else + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length)) + #else + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u))) + #endif + #endif +#else + #define CYTHON_PEP393_ENABLED 0 + #define PyUnicode_1BYTE_KIND 1 + #define PyUnicode_2BYTE_KIND 2 + #define PyUnicode_4BYTE_KIND 4 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i])) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535U : 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((int)sizeof(Py_UNICODE)) + #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u)) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i])) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = (Py_UNICODE) ch) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b) +#else + #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\ + PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #if !defined(PyUnicode_DecodeUnicodeEscape) + #define PyUnicode_DecodeUnicodeEscape(s, size, errors) PyUnicode_Decode(s, size, "unicode_escape", errors) + #endif + #if !defined(PyUnicode_Contains) || (PY_MAJOR_VERSION == 2 && PYPY_VERSION_NUM < 0x07030500) + #undef PyUnicode_Contains + #define PyUnicode_Contains(u, s) PySequence_Contains(u, s) + #endif + #if !defined(PyByteArray_Check) + #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type) + #endif + #if !defined(PyObject_Format) + #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt) + #endif +#endif +#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b)) +#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b) +#else + #define __Pyx_PyString_Format(a, b) PyString_Format(a, b) +#endif +#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII) + #define PyObject_ASCII(o) PyObject_Repr(o) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBaseString_Type PyUnicode_Type + #define PyStringObject PyUnicodeObject + #define PyString_Type PyUnicode_Type + #define PyString_Check PyUnicode_Check + #define PyString_CheckExact PyUnicode_CheckExact +#ifndef PyObject_Unicode + #define PyObject_Unicode PyObject_Str +#endif +#endif +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj) + #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj) +#else + #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj)) + #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj)) +#endif +#if CYTHON_COMPILING_IN_CPYTHON + #define __Pyx_PySequence_ListKeepNew(obj)\ + (likely(PyList_CheckExact(obj) && Py_REFCNT(obj) == 1) ? __Pyx_NewRef(obj) : PySequence_List(obj)) +#else + #define __Pyx_PySequence_ListKeepNew(obj) PySequence_List(obj) +#endif +#ifndef PySet_CheckExact + #define PySet_CheckExact(obj) __Pyx_IS_TYPE(obj, &PySet_Type) +#endif +#if PY_VERSION_HEX >= 0x030900A4 + #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size) +#else + #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size) +#endif +#if CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_PySequence_ITEM(o, i) PySequence_ITEM(o, i) + #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) (PyTuple_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyList_SET_ITEM(o, i, v) (PyList_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_GET_SIZE(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_GET_SIZE(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_GET_SIZE(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_GET_SIZE(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_GET_SIZE(o) +#else + #define __Pyx_PySequence_ITEM(o, i) PySequence_GetItem(o, i) + #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) PyTuple_SetItem(o, i, v) + #define __Pyx_PyList_SET_ITEM(o, i, v) PyList_SetItem(o, i, v) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_Size(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_Size(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_Size(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_Size(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_Size(o) +#endif +#if PY_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyImport_AddModuleRef(name) PyImport_AddModuleRef(name) +#else + static CYTHON_INLINE PyObject *__Pyx_PyImport_AddModuleRef(const char *name) { + PyObject *module = PyImport_AddModule(name); + Py_XINCREF(module); + return module; + } +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyIntObject PyLongObject + #define PyInt_Type PyLong_Type + #define PyInt_Check(op) PyLong_Check(op) + #define PyInt_CheckExact(op) PyLong_CheckExact(op) + #define __Pyx_Py3Int_Check(op) PyLong_Check(op) + #define __Pyx_Py3Int_CheckExact(op) PyLong_CheckExact(op) + #define PyInt_FromString PyLong_FromString + #define PyInt_FromUnicode PyLong_FromUnicode + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyInt_FromSsize_t PyLong_FromSsize_t + #define PyInt_AsLong PyLong_AsLong + #define PyInt_AS_LONG PyLong_AS_LONG + #define PyInt_AsSsize_t PyLong_AsSsize_t + #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask + #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask + #define PyNumber_Int PyNumber_Long +#else + #define __Pyx_Py3Int_Check(op) (PyLong_Check(op) || PyInt_Check(op)) + #define __Pyx_Py3Int_CheckExact(op) (PyLong_CheckExact(op) || PyInt_CheckExact(op)) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBoolObject PyLongObject +#endif +#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY + #ifndef PyUnicode_InternFromString + #define PyUnicode_InternFromString(s) PyUnicode_FromString(s) + #endif +#endif +#if PY_VERSION_HEX < 0x030200A4 + typedef long Py_hash_t; + #define __Pyx_PyInt_FromHash_t PyInt_FromLong + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t +#else + #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t +#endif +#if CYTHON_USE_ASYNC_SLOTS + #if PY_VERSION_HEX >= 0x030500B1 + #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods + #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async) + #else + #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved)) + #endif +#else + #define __Pyx_PyType_AsAsync(obj) NULL +#endif +#ifndef __Pyx_PyAsyncMethodsStruct + typedef struct { + unaryfunc am_await; + unaryfunc am_aiter; + unaryfunc am_anext; + } __Pyx_PyAsyncMethodsStruct; +#endif + +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + #if !defined(_USE_MATH_DEFINES) + #define _USE_MATH_DEFINES + #endif +#endif +#include +#ifdef NAN +#define __PYX_NAN() ((float) NAN) +#else +static CYTHON_INLINE float __PYX_NAN() { + float value; + memset(&value, 0xFF, sizeof(value)); + return value; +} +#endif +#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL) +#define __Pyx_truncl trunc +#else +#define __Pyx_truncl truncl +#endif + +#define __PYX_MARK_ERR_POS(f_index, lineno) \ + { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; } +#define __PYX_ERR(f_index, lineno, Ln_error) \ + { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; } + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. + #endif +#else + #define __PYX_EXTERN_C extern "C++" +#endif + +#define __PYX_HAVE__fairseq__data__token_block_utils_fast +#define __PYX_HAVE_API__fairseq__data__token_block_utils_fast +/* Early includes */ +#include +#include +#include + + /* Using NumPy API declarations from "numpy/__init__.cython-30.pxd" */ + +#include "numpy/arrayobject.h" +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/arrayscalars.h" +#include "numpy/ufuncobject.h" +#include +#include "pythread.h" +#include +#ifdef _OPENMP +#include +#endif /* _OPENMP */ + +#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS) +#define CYTHON_WITHOUT_ASSERTIONS +#endif + +typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding; + const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry; + +#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8) +#define __PYX_DEFAULT_STRING_ENCODING "" +#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString +#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#define __Pyx_uchar_cast(c) ((unsigned char)c) +#define __Pyx_long_cast(x) ((long)x) +#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\ + (sizeof(type) < sizeof(Py_ssize_t)) ||\ + (sizeof(type) > sizeof(Py_ssize_t) &&\ + likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX) &&\ + (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\ + v == (type)PY_SSIZE_T_MIN))) ||\ + (sizeof(type) == sizeof(Py_ssize_t) &&\ + (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX))) ) +static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) { + return (size_t) i < (size_t) limit; +} +#if defined (__cplusplus) && __cplusplus >= 201103L + #include + #define __Pyx_sst_abs(value) std::abs(value) +#elif SIZEOF_INT >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) abs(value) +#elif SIZEOF_LONG >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) labs(value) +#elif defined (_MSC_VER) + #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value)) +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define __Pyx_sst_abs(value) llabs(value) +#elif defined (__GNUC__) + #define __Pyx_sst_abs(value) __builtin_llabs(value) +#else + #define __Pyx_sst_abs(value) ((value<0) ? -value : value) +#endif +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s); +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*); +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length); +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char*); +#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l) +#define __Pyx_PyBytes_FromString PyBytes_FromString +#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*); +#if PY_MAJOR_VERSION < 3 + #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#else + #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize +#endif +#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyObject_AsWritableString(s) ((char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableSString(s) ((signed char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s) +#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s) +#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s) +#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s) +#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s) +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const wchar_t *u) +{ + const wchar_t *u_end = u; + while (*u_end++) ; + return (size_t)(u_end - u - 1); +} +#else +static CYTHON_INLINE size_t __Pyx_Py_UNICODE_strlen(const Py_UNICODE *u) +{ + const Py_UNICODE *u_end = u; + while (*u_end++) ; + return (size_t)(u_end - u - 1); +} +#endif +#define __Pyx_PyUnicode_FromOrdinal(o) PyUnicode_FromOrdinal((int)o) +#define __Pyx_PyUnicode_FromUnicode(u) PyUnicode_FromUnicode(u, __Pyx_Py_UNICODE_strlen(u)) +#define __Pyx_PyUnicode_FromUnicodeAndLength PyUnicode_FromUnicode +#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode +#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj) +#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None) +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b); +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*); +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*); +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x); +#define __Pyx_PySequence_Tuple(obj)\ + (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj)) +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*); +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t); +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*); +#if CYTHON_ASSUME_SAFE_MACROS +#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x)) +#else +#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x) +#endif +#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x)) +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x)) +#else +#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x)) +#endif +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_VERSION_HEX >= 0x030C00A7 + #ifndef _PyLong_SIGN_MASK + #define _PyLong_SIGN_MASK 3 + #endif + #ifndef _PyLong_NON_SIZE_BITS + #define _PyLong_NON_SIZE_BITS 3 + #endif + #define __Pyx_PyLong_Sign(x) (((PyLongObject*)x)->long_value.lv_tag & _PyLong_SIGN_MASK) + #define __Pyx_PyLong_IsNeg(x) ((__Pyx_PyLong_Sign(x) & 2) != 0) + #define __Pyx_PyLong_IsNonNeg(x) (!__Pyx_PyLong_IsNeg(x)) + #define __Pyx_PyLong_IsZero(x) (__Pyx_PyLong_Sign(x) & 1) + #define __Pyx_PyLong_IsPos(x) (__Pyx_PyLong_Sign(x) == 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) (__Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) ((Py_ssize_t) (((PyLongObject*)x)->long_value.lv_tag >> _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_SignedDigitCount(x)\ + ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * __Pyx_PyLong_DigitCount(x)) + #if defined(PyUnstable_Long_IsCompact) && defined(PyUnstable_Long_CompactValue) + #define __Pyx_PyLong_IsCompact(x) PyUnstable_Long_IsCompact((PyLongObject*) x) + #define __Pyx_PyLong_CompactValue(x) PyUnstable_Long_CompactValue((PyLongObject*) x) + #else + #define __Pyx_PyLong_IsCompact(x) (((PyLongObject*)x)->long_value.lv_tag < (2 << _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_CompactValue(x) ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * (Py_ssize_t) __Pyx_PyLong_Digits(x)[0]) + #endif + typedef Py_ssize_t __Pyx_compact_pylong; + typedef size_t __Pyx_compact_upylong; + #else + #define __Pyx_PyLong_IsNeg(x) (Py_SIZE(x) < 0) + #define __Pyx_PyLong_IsNonNeg(x) (Py_SIZE(x) >= 0) + #define __Pyx_PyLong_IsZero(x) (Py_SIZE(x) == 0) + #define __Pyx_PyLong_IsPos(x) (Py_SIZE(x) > 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) ((Py_SIZE(x) == 0) ? 0 : __Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) __Pyx_sst_abs(Py_SIZE(x)) + #define __Pyx_PyLong_SignedDigitCount(x) Py_SIZE(x) + #define __Pyx_PyLong_IsCompact(x) (Py_SIZE(x) == 0 || Py_SIZE(x) == 1 || Py_SIZE(x) == -1) + #define __Pyx_PyLong_CompactValue(x)\ + ((Py_SIZE(x) == 0) ? (sdigit) 0 : ((Py_SIZE(x) < 0) ? -(sdigit)__Pyx_PyLong_Digits(x)[0] : (sdigit)__Pyx_PyLong_Digits(x)[0])) + typedef sdigit __Pyx_compact_pylong; + typedef digit __Pyx_compact_upylong; + #endif + #if PY_VERSION_HEX >= 0x030C00A5 + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->long_value.ob_digit) + #else + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->ob_digit) + #endif +#endif +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII +#include +static int __Pyx_sys_getdefaultencoding_not_ascii; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + PyObject* ascii_chars_u = NULL; + PyObject* ascii_chars_b = NULL; + const char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + if (strcmp(default_encoding_c, "ascii") == 0) { + __Pyx_sys_getdefaultencoding_not_ascii = 0; + } else { + char ascii_chars[128]; + int c; + for (c = 0; c < 128; c++) { + ascii_chars[c] = (char) c; + } + __Pyx_sys_getdefaultencoding_not_ascii = 1; + ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL); + if (!ascii_chars_u) goto bad; + ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL); + if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) { + PyErr_Format( + PyExc_ValueError, + "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.", + default_encoding_c); + goto bad; + } + Py_DECREF(ascii_chars_u); + Py_DECREF(ascii_chars_b); + } + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + Py_XDECREF(ascii_chars_u); + Py_XDECREF(ascii_chars_b); + return -1; +} +#endif +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3 +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL) +#else +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL) +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#include +static char* __PYX_DEFAULT_STRING_ENCODING; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1); + if (!__PYX_DEFAULT_STRING_ENCODING) goto bad; + strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c); + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + return -1; +} +#endif +#endif + + +/* Test for GCC > 2.95 */ +#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))) + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) +#else /* !__GNUC__ or GCC < 2.95 */ + #define likely(x) (x) + #define unlikely(x) (x) +#endif /* __GNUC__ */ +static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; } + +#if !CYTHON_USE_MODULE_STATE +static PyObject *__pyx_m = NULL; +#endif +static int __pyx_lineno; +static int __pyx_clineno = 0; +static const char * __pyx_cfilenm = __FILE__; +static const char *__pyx_filename; + +/* Header.proto */ +#if !defined(CYTHON_CCOMPLEX) + #if defined(__cplusplus) + #define CYTHON_CCOMPLEX 1 + #elif (defined(_Complex_I) && !defined(_MSC_VER)) || ((defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) && !defined(__STDC_NO_COMPLEX__) && !defined(_MSC_VER)) + #define CYTHON_CCOMPLEX 1 + #else + #define CYTHON_CCOMPLEX 0 + #endif +#endif +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #include + #else + #include + #endif +#endif +#if CYTHON_CCOMPLEX && !defined(__cplusplus) && defined(__sun__) && defined(__GNUC__) + #undef _Complex_I + #define _Complex_I 1.0fj +#endif + +/* #### Code section: filename_table ### */ + +static const char *__pyx_f[] = { + "fairseq/data/token_block_utils_fast.pyx", + "", + "__init__.cython-30.pxd", + "type.pxd", +}; +/* #### Code section: utility_code_proto_before_types ### */ +/* ForceInitThreads.proto */ +#ifndef __PYX_FORCE_INIT_THREADS + #define __PYX_FORCE_INIT_THREADS 0 +#endif + +/* NoFastGil.proto */ +#define __Pyx_PyGILState_Ensure PyGILState_Ensure +#define __Pyx_PyGILState_Release PyGILState_Release +#define __Pyx_FastGIL_Remember() +#define __Pyx_FastGIL_Forget() +#define __Pyx_FastGilFuncInit() + +/* BufferFormatStructs.proto */ +struct __Pyx_StructField_; +#define __PYX_BUF_FLAGS_PACKED_STRUCT (1 << 0) +typedef struct { + const char* name; + struct __Pyx_StructField_* fields; + size_t size; + size_t arraysize[8]; + int ndim; + char typegroup; + char is_unsigned; + int flags; +} __Pyx_TypeInfo; +typedef struct __Pyx_StructField_ { + __Pyx_TypeInfo* type; + const char* name; + size_t offset; +} __Pyx_StructField; +typedef struct { + __Pyx_StructField* field; + size_t parent_offset; +} __Pyx_BufFmt_StackElem; +typedef struct { + __Pyx_StructField root; + __Pyx_BufFmt_StackElem* head; + size_t fmt_offset; + size_t new_count, enc_count; + size_t struct_alignment; + int is_complex; + char enc_type; + char new_packmode; + char enc_packmode; + char is_valid_array; +} __Pyx_BufFmt_Context; + +/* Atomics.proto */ +#include +#ifndef CYTHON_ATOMICS + #define CYTHON_ATOMICS 1 +#endif +#define __PYX_CYTHON_ATOMICS_ENABLED() CYTHON_ATOMICS +#define __pyx_atomic_int_type int +#define __pyx_nonatomic_int_type int +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__)) + #include +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ + (defined(_MSC_VER) && _MSC_VER >= 1700))) + #include +#endif +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type atomic_int + #define __pyx_atomic_incr_aligned(value) atomic_fetch_add_explicit(value, 1, memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) atomic_fetch_sub_explicit(value, 1, memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C atomics" + #endif +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ +\ + (defined(_MSC_VER) && _MSC_VER >= 1700)) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type std::atomic_int + #define __pyx_atomic_incr_aligned(value) std::atomic_fetch_add_explicit(value, 1, std::memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) std::atomic_fetch_sub_explicit(value, 1, std::memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C++ atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C++ atomics" + #endif +#elif CYTHON_ATOMICS && (__GNUC__ >= 5 || (__GNUC__ == 4 &&\ + (__GNUC_MINOR__ > 1 ||\ + (__GNUC_MINOR__ == 1 && __GNUC_PATCHLEVEL__ >= 2)))) + #define __pyx_atomic_incr_aligned(value) __sync_fetch_and_add(value, 1) + #define __pyx_atomic_decr_aligned(value) __sync_fetch_and_sub(value, 1) + #ifdef __PYX_DEBUG_ATOMICS + #warning "Using GNU atomics" + #endif +#elif CYTHON_ATOMICS && defined(_MSC_VER) + #include + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type long + #undef __pyx_nonatomic_int_type + #define __pyx_nonatomic_int_type long + #pragma intrinsic (_InterlockedExchangeAdd) + #define __pyx_atomic_incr_aligned(value) _InterlockedExchangeAdd(value, 1) + #define __pyx_atomic_decr_aligned(value) _InterlockedExchangeAdd(value, -1) + #ifdef __PYX_DEBUG_ATOMICS + #pragma message ("Using MSVC atomics") + #endif +#else + #undef CYTHON_ATOMICS + #define CYTHON_ATOMICS 0 + #ifdef __PYX_DEBUG_ATOMICS + #warning "Not using atomics" + #endif +#endif +#if CYTHON_ATOMICS + #define __pyx_add_acquisition_count(memview)\ + __pyx_atomic_incr_aligned(__pyx_get_slice_count_pointer(memview)) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_atomic_decr_aligned(__pyx_get_slice_count_pointer(memview)) +#else + #define __pyx_add_acquisition_count(memview)\ + __pyx_add_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_sub_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) +#endif + +/* MemviewSliceStruct.proto */ +struct __pyx_memoryview_obj; +typedef struct { + struct __pyx_memoryview_obj *memview; + char *data; + Py_ssize_t shape[8]; + Py_ssize_t strides[8]; + Py_ssize_t suboffsets[8]; +} __Pyx_memviewslice; +#define __Pyx_MemoryView_Len(m) (m.shape[0]) + +/* #### Code section: numeric_typedefs ### */ + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":730 + * # in Cython to enable them only on the right systems. + * + * ctypedef npy_int8 int8_t # <<<<<<<<<<<<<< + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + */ +typedef npy_int8 __pyx_t_5numpy_int8_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":731 + * + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t # <<<<<<<<<<<<<< + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t + */ +typedef npy_int16 __pyx_t_5numpy_int16_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":732 + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t # <<<<<<<<<<<<<< + * ctypedef npy_int64 int64_t + * #ctypedef npy_int96 int96_t + */ +typedef npy_int32 __pyx_t_5numpy_int32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":733 + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t # <<<<<<<<<<<<<< + * #ctypedef npy_int96 int96_t + * #ctypedef npy_int128 int128_t + */ +typedef npy_int64 __pyx_t_5numpy_int64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":737 + * #ctypedef npy_int128 int128_t + * + * ctypedef npy_uint8 uint8_t # <<<<<<<<<<<<<< + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + */ +typedef npy_uint8 __pyx_t_5numpy_uint8_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":738 + * + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t # <<<<<<<<<<<<<< + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t + */ +typedef npy_uint16 __pyx_t_5numpy_uint16_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":739 + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t # <<<<<<<<<<<<<< + * ctypedef npy_uint64 uint64_t + * #ctypedef npy_uint96 uint96_t + */ +typedef npy_uint32 __pyx_t_5numpy_uint32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":740 + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t # <<<<<<<<<<<<<< + * #ctypedef npy_uint96 uint96_t + * #ctypedef npy_uint128 uint128_t + */ +typedef npy_uint64 __pyx_t_5numpy_uint64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":744 + * #ctypedef npy_uint128 uint128_t + * + * ctypedef npy_float32 float32_t # <<<<<<<<<<<<<< + * ctypedef npy_float64 float64_t + * #ctypedef npy_float80 float80_t + */ +typedef npy_float32 __pyx_t_5numpy_float32_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":745 + * + * ctypedef npy_float32 float32_t + * ctypedef npy_float64 float64_t # <<<<<<<<<<<<<< + * #ctypedef npy_float80 float80_t + * #ctypedef npy_float128 float128_t + */ +typedef npy_float64 __pyx_t_5numpy_float64_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":754 + * # The int types are mapped a bit surprising -- + * # numpy.int corresponds to 'l' and numpy.long to 'q' + * ctypedef npy_long int_t # <<<<<<<<<<<<<< + * ctypedef npy_longlong longlong_t + * + */ +typedef npy_long __pyx_t_5numpy_int_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":755 + * # numpy.int corresponds to 'l' and numpy.long to 'q' + * ctypedef npy_long int_t + * ctypedef npy_longlong longlong_t # <<<<<<<<<<<<<< + * + * ctypedef npy_ulong uint_t + */ +typedef npy_longlong __pyx_t_5numpy_longlong_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":757 + * ctypedef npy_longlong longlong_t + * + * ctypedef npy_ulong uint_t # <<<<<<<<<<<<<< + * ctypedef npy_ulonglong ulonglong_t + * + */ +typedef npy_ulong __pyx_t_5numpy_uint_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":758 + * + * ctypedef npy_ulong uint_t + * ctypedef npy_ulonglong ulonglong_t # <<<<<<<<<<<<<< + * + * ctypedef npy_intp intp_t + */ +typedef npy_ulonglong __pyx_t_5numpy_ulonglong_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":760 + * ctypedef npy_ulonglong ulonglong_t + * + * ctypedef npy_intp intp_t # <<<<<<<<<<<<<< + * ctypedef npy_uintp uintp_t + * + */ +typedef npy_intp __pyx_t_5numpy_intp_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":761 + * + * ctypedef npy_intp intp_t + * ctypedef npy_uintp uintp_t # <<<<<<<<<<<<<< + * + * ctypedef npy_double float_t + */ +typedef npy_uintp __pyx_t_5numpy_uintp_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":763 + * ctypedef npy_uintp uintp_t + * + * ctypedef npy_double float_t # <<<<<<<<<<<<<< + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t + */ +typedef npy_double __pyx_t_5numpy_float_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":764 + * + * ctypedef npy_double float_t + * ctypedef npy_double double_t # <<<<<<<<<<<<<< + * ctypedef npy_longdouble longdouble_t + * + */ +typedef npy_double __pyx_t_5numpy_double_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":765 + * ctypedef npy_double float_t + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t # <<<<<<<<<<<<<< + * + * ctypedef npy_cfloat cfloat_t + */ +typedef npy_longdouble __pyx_t_5numpy_longdouble_t; + +/* "fairseq/data/token_block_utils_fast.pyx":18 + * + * DTYPE = np.int64 + * ctypedef int64_t DTYPE_t # <<<<<<<<<<<<<< + * + * + */ +typedef int64_t __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t; +/* #### Code section: complex_type_declarations ### */ +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< float > __pyx_t_float_complex; + #else + typedef float _Complex __pyx_t_float_complex; + #endif +#else + typedef struct { float real, imag; } __pyx_t_float_complex; +#endif +static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float, float); + +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< double > __pyx_t_double_complex; + #else + typedef double _Complex __pyx_t_double_complex; + #endif +#else + typedef struct { double real, imag; } __pyx_t_double_complex; +#endif +static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double, double); + +/* #### Code section: type_declarations ### */ + +/*--- Type declarations ---*/ +struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; +struct __pyx_array_obj; +struct __pyx_MemviewEnum_obj; +struct __pyx_memoryview_obj; +struct __pyx_memoryviewslice_obj; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":767 + * ctypedef npy_longdouble longdouble_t + * + * ctypedef npy_cfloat cfloat_t # <<<<<<<<<<<<<< + * ctypedef npy_cdouble cdouble_t + * ctypedef npy_clongdouble clongdouble_t + */ +typedef npy_cfloat __pyx_t_5numpy_cfloat_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":768 + * + * ctypedef npy_cfloat cfloat_t + * ctypedef npy_cdouble cdouble_t # <<<<<<<<<<<<<< + * ctypedef npy_clongdouble clongdouble_t + * + */ +typedef npy_cdouble __pyx_t_5numpy_cdouble_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":769 + * ctypedef npy_cfloat cfloat_t + * ctypedef npy_cdouble cdouble_t + * ctypedef npy_clongdouble clongdouble_t # <<<<<<<<<<<<<< + * + * ctypedef npy_cdouble complex_t + */ +typedef npy_clongdouble __pyx_t_5numpy_clongdouble_t; + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":771 + * ctypedef npy_clongdouble clongdouble_t + * + * ctypedef npy_cdouble complex_t # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew1(a): + */ +typedef npy_cdouble __pyx_t_5numpy_complex_t; + +/* "fairseq/data/token_block_utils_fast.pyx":141 + * + * + * cdef class DatasetSearcher(object): # <<<<<<<<<<<<<< + * """Helper for mapping "flat" indices to indices and offsets in an + * underlying dataset.""" + */ +struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher { + PyObject_HEAD + struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_vtab; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t current_i; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t current_offset; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t current_index; + __Pyx_memviewslice sizes; +}; + + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ +struct __pyx_array_obj { + PyObject_HEAD + struct __pyx_vtabstruct_array *__pyx_vtab; + char *data; + Py_ssize_t len; + char *format; + int ndim; + Py_ssize_t *_shape; + Py_ssize_t *_strides; + Py_ssize_t itemsize; + PyObject *mode; + PyObject *_format; + void (*callback_free_data)(void *); + int free_data; + int dtype_is_object; +}; + + +/* "View.MemoryView":302 + * + * @cname('__pyx_MemviewEnum') + * cdef class Enum(object): # <<<<<<<<<<<<<< + * cdef object name + * def __init__(self, name): + */ +struct __pyx_MemviewEnum_obj { + PyObject_HEAD + PyObject *name; +}; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ +struct __pyx_memoryview_obj { + PyObject_HEAD + struct __pyx_vtabstruct_memoryview *__pyx_vtab; + PyObject *obj; + PyObject *_size; + PyObject *_array_interface; + PyThread_type_lock lock; + __pyx_atomic_int_type acquisition_count; + Py_buffer view; + int flags; + int dtype_is_object; + __Pyx_TypeInfo *typeinfo; +}; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ +struct __pyx_memoryviewslice_obj { + struct __pyx_memoryview_obj __pyx_base; + __Pyx_memviewslice from_slice; + PyObject *from_object; + PyObject *(*to_object_func)(char *); + int (*to_dtype_func)(char *, PyObject *); +}; + + + +/* "fairseq/data/token_block_utils_fast.pyx":141 + * + * + * cdef class DatasetSearcher(object): # <<<<<<<<<<<<<< + * """Helper for mapping "flat" indices to indices and offsets in an + * underlying dataset.""" + */ + +struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher { + PyObject *(*reset)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *); + int (*step)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t); + PyObject *(*seek)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t); +}; +static struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_vtabptr_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ + +struct __pyx_vtabstruct_array { + PyObject *(*get_memview)(struct __pyx_array_obj *); +}; +static struct __pyx_vtabstruct_array *__pyx_vtabptr_array; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ + +struct __pyx_vtabstruct_memoryview { + char *(*get_item_pointer)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*is_slice)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_slice_assignment)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*setitem_slice_assign_scalar)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_indexed)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*convert_item_to_object)(struct __pyx_memoryview_obj *, char *); + PyObject *(*assign_item_from_object)(struct __pyx_memoryview_obj *, char *, PyObject *); + PyObject *(*_get_base)(struct __pyx_memoryview_obj *); +}; +static struct __pyx_vtabstruct_memoryview *__pyx_vtabptr_memoryview; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ + +struct __pyx_vtabstruct__memoryviewslice { + struct __pyx_vtabstruct_memoryview __pyx_base; +}; +static struct __pyx_vtabstruct__memoryviewslice *__pyx_vtabptr__memoryviewslice; +/* #### Code section: utility_code_proto ### */ + +/* --- Runtime support code (head) --- */ +/* Refnanny.proto */ +#ifndef CYTHON_REFNANNY + #define CYTHON_REFNANNY 0 +#endif +#if CYTHON_REFNANNY + typedef struct { + void (*INCREF)(void*, PyObject*, Py_ssize_t); + void (*DECREF)(void*, PyObject*, Py_ssize_t); + void (*GOTREF)(void*, PyObject*, Py_ssize_t); + void (*GIVEREF)(void*, PyObject*, Py_ssize_t); + void* (*SetupContext)(const char*, Py_ssize_t, const char*); + void (*FinishContext)(void**); + } __Pyx_RefNannyAPIStruct; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname); + #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL; +#ifdef WITH_THREAD + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + if (acquire_gil) {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + PyGILState_Release(__pyx_gilstate_save);\ + } else {\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + } + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } +#else + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__)) + #define __Pyx_RefNannyFinishContextNogil() __Pyx_RefNannyFinishContext() +#endif + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } + #define __Pyx_RefNannyFinishContext()\ + __Pyx_RefNanny->FinishContext(&__pyx_refnanny) + #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_XINCREF(r) do { if((r) == NULL); else {__Pyx_INCREF(r); }} while(0) + #define __Pyx_XDECREF(r) do { if((r) == NULL); else {__Pyx_DECREF(r); }} while(0) + #define __Pyx_XGOTREF(r) do { if((r) == NULL); else {__Pyx_GOTREF(r); }} while(0) + #define __Pyx_XGIVEREF(r) do { if((r) == NULL); else {__Pyx_GIVEREF(r);}} while(0) +#else + #define __Pyx_RefNannyDeclarations + #define __Pyx_RefNannySetupContext(name, acquire_gil) + #define __Pyx_RefNannyFinishContextNogil() + #define __Pyx_RefNannyFinishContext() + #define __Pyx_INCREF(r) Py_INCREF(r) + #define __Pyx_DECREF(r) Py_DECREF(r) + #define __Pyx_GOTREF(r) + #define __Pyx_GIVEREF(r) + #define __Pyx_XINCREF(r) Py_XINCREF(r) + #define __Pyx_XDECREF(r) Py_XDECREF(r) + #define __Pyx_XGOTREF(r) + #define __Pyx_XGIVEREF(r) +#endif +#define __Pyx_Py_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; Py_XDECREF(tmp);\ + } while (0) +#define __Pyx_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_XDECREF(tmp);\ + } while (0) +#define __Pyx_DECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_DECREF(tmp);\ + } while (0) +#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0) +#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0) + +/* PyErrExceptionMatches.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err) +static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err); +#else +#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err) +#endif + +/* PyThreadStateGet.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate; +#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current; +#if PY_VERSION_HEX >= 0x030C00A6 +#define __Pyx_PyErr_Occurred() (__pyx_tstate->current_exception != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->current_exception ? (PyObject*) Py_TYPE(__pyx_tstate->current_exception) : (PyObject*) NULL) +#else +#define __Pyx_PyErr_Occurred() (__pyx_tstate->curexc_type != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->curexc_type) +#endif +#else +#define __Pyx_PyThreadState_declare +#define __Pyx_PyThreadState_assign +#define __Pyx_PyErr_Occurred() (PyErr_Occurred() != NULL) +#define __Pyx_PyErr_CurrentExceptionType() PyErr_Occurred() +#endif + +/* PyErrFetchRestore.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL) +#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A6 +#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL)) +#else +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#endif +#else +#define __Pyx_PyErr_Clear() PyErr_Clear() +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb) +#endif + +/* PyObjectGetAttrStr.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n) +#endif + +/* PyObjectGetAttrStrNoError.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name); + +/* GetBuiltinName.proto */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name); + +/* TupleAndListFromArray.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n); +static CYTHON_INLINE PyObject* __Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n); +#endif + +/* IncludeStringH.proto */ +#include + +/* BytesEquals.proto */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals); + +/* UnicodeEquals.proto */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals); + +/* fastcall.proto */ +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_VARARGS(args, i) PySequence_GetItem(args, i) +#elif CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GET_ITEM(args, i) +#else + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GetItem(args, i) +#endif +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_NewRef_VARARGS(arg) __Pyx_NewRef(arg) + #define __Pyx_Arg_XDECREF_VARARGS(arg) Py_XDECREF(arg) +#else + #define __Pyx_Arg_NewRef_VARARGS(arg) arg + #define __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#define __Pyx_NumKwargs_VARARGS(kwds) PyDict_Size(kwds) +#define __Pyx_KwValues_VARARGS(args, nargs) NULL +#define __Pyx_GetKwValue_VARARGS(kw, kwvalues, s) __Pyx_PyDict_GetItemStrWithError(kw, s) +#define __Pyx_KwargsAsDict_VARARGS(kw, kwvalues) PyDict_Copy(kw) +#if CYTHON_METH_FASTCALL + #define __Pyx_Arg_FASTCALL(args, i) args[i] + #define __Pyx_NumKwargs_FASTCALL(kwds) PyTuple_GET_SIZE(kwds) + #define __Pyx_KwValues_FASTCALL(args, nargs) ((args) + (nargs)) + static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues); + #else + #define __Pyx_KwargsAsDict_FASTCALL(kw, kwvalues) _PyStack_AsDict(kwvalues, kw) + #endif + #define __Pyx_Arg_NewRef_FASTCALL(arg) arg /* no-op, __Pyx_Arg_FASTCALL is direct and this needs + to have the same reference counting */ + #define __Pyx_Arg_XDECREF_FASTCALL(arg) +#else + #define __Pyx_Arg_FASTCALL __Pyx_Arg_VARARGS + #define __Pyx_NumKwargs_FASTCALL __Pyx_NumKwargs_VARARGS + #define __Pyx_KwValues_FASTCALL __Pyx_KwValues_VARARGS + #define __Pyx_GetKwValue_FASTCALL __Pyx_GetKwValue_VARARGS + #define __Pyx_KwargsAsDict_FASTCALL __Pyx_KwargsAsDict_VARARGS + #define __Pyx_Arg_NewRef_FASTCALL(arg) __Pyx_Arg_NewRef_VARARGS(arg) + #define __Pyx_Arg_XDECREF_FASTCALL(arg) __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_VARARGS(args, start), stop - start) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_FASTCALL(args, start), stop - start) +#else +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) PyTuple_GetSlice(args, start, stop) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) PyTuple_GetSlice(args, start, stop) +#endif + +/* RaiseArgTupleInvalid.proto */ +static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact, + Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found); + +/* RaiseDoubleKeywords.proto */ +static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name); + +/* ParseKeywords.proto */ +static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args, + const char* function_name); + +/* ArgTypeTest.proto */ +#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\ + ((likely(__Pyx_IS_TYPE(obj, type) | (none_allowed && (obj == Py_None)))) ? 1 :\ + __Pyx__ArgTypeTest(obj, type, name, exact)) +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact); + +/* RaiseException.proto */ +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); + +/* PyFunctionFastCall.proto */ +#if CYTHON_FAST_PYCALL +#if !CYTHON_VECTORCALL +#define __Pyx_PyFunction_FastCall(func, args, nargs)\ + __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL) +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs); +#endif +#define __Pyx_BUILD_ASSERT_EXPR(cond)\ + (sizeof(char [1 - 2*!(cond)]) - 1) +#ifndef Py_MEMBER_SIZE +#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member) +#endif +#if !CYTHON_VECTORCALL +#if PY_VERSION_HEX >= 0x03080000 + #include "frameobject.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif + #define __Pxy_PyFrame_Initialize_Offsets() + #define __Pyx_PyFrame_GetLocalsplus(frame) ((frame)->f_localsplus) +#else + static size_t __pyx_pyframe_localsplus_offset = 0; + #include "frameobject.h" + #define __Pxy_PyFrame_Initialize_Offsets()\ + ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\ + (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus))) + #define __Pyx_PyFrame_GetLocalsplus(frame)\ + (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset)) +#endif +#endif +#endif + +/* PyObjectCall.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw); +#else +#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw) +#endif + +/* PyObjectCallMethO.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg); +#endif + +/* PyObjectFastCall.proto */ +#define __Pyx_PyObject_FastCall(func, args, nargs) __Pyx_PyObject_FastCallDict(func, args, (size_t)(nargs), NULL) +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs); + +/* RaiseUnexpectedTypeError.proto */ +static int __Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj); + +/* GCCDiagnostics.proto */ +#if !defined(__INTEL_COMPILER) && defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) +#define __Pyx_HAS_GCC_DIAGNOSTIC +#endif + +/* BuildPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char); + +/* JoinPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char); + +/* StrEquals.proto */ +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyString_Equals __Pyx_PyUnicode_Equals +#else +#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals +#endif + +/* PyObjectFormatSimple.proto */ +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#elif PY_MAJOR_VERSION < 3 + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyString_CheckExact(s)) ? PyUnicode_FromEncodedObject(s, NULL, "strict") :\ + PyObject_Format(s, f)) +#elif CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyLong_CheckExact(s)) ? PyLong_Type.tp_repr(s) :\ + likely(PyFloat_CheckExact(s)) ? PyFloat_Type.tp_repr(s) :\ + PyObject_Format(s, f)) +#else + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#endif + +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *); /*proto*/ +/* GetAttr.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *); + +/* GetItemInt.proto */ +#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\ + __Pyx_GetItemInt_Generic(o, to_py_func(i)))) +#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j); +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, + int is_list, int wraparound, int boundscheck); + +/* PyObjectCallOneArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg); + +/* ObjectGetItem.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key); +#else +#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key) +#endif + +/* KeywordStringCheck.proto */ +static int __Pyx_CheckKeywordStrings(PyObject *kw, const char* function_name, int kw_allowed); + +/* DivInt[Py_ssize_t].proto */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t, Py_ssize_t); + +/* UnaryNegOverflows.proto */ +#define __Pyx_UNARY_NEG_WOULD_OVERFLOW(x)\ + (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x))) + +/* GetAttr3.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); + +/* PyDictVersioning.proto */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1) +#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\ + (version_var) = __PYX_GET_DICT_VERSION(dict);\ + (cache_var) = (value); +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\ + (VAR) = __pyx_dict_cached_value;\ + } else {\ + (VAR) = __pyx_dict_cached_value = (LOOKUP);\ + __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\ + }\ +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj); +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj); +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version); +#else +#define __PYX_GET_DICT_VERSION(dict) (0) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var) +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP); +#endif + +/* GetModuleGlobalName.proto */ +#if CYTHON_USE_DICT_VERSIONS +#define __Pyx_GetModuleGlobalName(var, name) do {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\ + (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\ + __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +#define __Pyx_GetModuleGlobalNameUncached(var, name) do {\ + PY_UINT64_T __pyx_dict_version;\ + PyObject *__pyx_dict_cached_value;\ + (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value); +#else +#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name) +#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name) +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name); +#endif + +/* AssertionsEnabled.proto */ +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag) + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (1) +#elif CYTHON_COMPILING_IN_LIMITED_API || (CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030C0000) + static int __pyx_assertions_enabled_flag; + #define __pyx_assertions_enabled() (__pyx_assertions_enabled_flag) + static int __Pyx_init_assertions_enabled(void) { + PyObject *builtins, *debug, *debug_str; + int flag; + builtins = PyEval_GetBuiltins(); + if (!builtins) goto bad; + debug_str = PyUnicode_FromStringAndSize("__debug__", 9); + if (!debug_str) goto bad; + debug = PyObject_GetItem(builtins, debug_str); + Py_DECREF(debug_str); + if (!debug) goto bad; + flag = PyObject_IsTrue(debug); + Py_DECREF(debug); + if (flag == -1) goto bad; + __pyx_assertions_enabled_flag = flag; + return 0; + bad: + __pyx_assertions_enabled_flag = 1; + return -1; + } +#else + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (!Py_OptimizeFlag) +#endif + +/* RaiseTooManyValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected); + +/* RaiseNeedMoreValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index); + +/* RaiseNoneIterError.proto */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void); + +/* ExtTypeTest.proto */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); + +/* GetTopmostException.proto */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate); +#endif + +/* SaveResetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +#else +#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb) +#define __Pyx_ExceptionReset(type, value, tb) PyErr_SetExcInfo(type, value, tb) +#endif + +/* GetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb) +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* SwapException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSwap(type, value, tb) __Pyx__ExceptionSwap(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* Import.proto */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level); + +/* ImportDottedModule.proto */ +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple); +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple); +#endif + +/* FastTypeChecks.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) __Pyx_IsAnySubtype2(Py_TYPE(obj), (PyTypeObject *)type1, (PyTypeObject *)type2) +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2); +#else +#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) (PyObject_TypeCheck(obj, (PyTypeObject *)type1) || PyObject_TypeCheck(obj, (PyTypeObject *)type2)) +#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type) +#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2)) +#endif +#define __Pyx_PyErr_ExceptionMatches2(err1, err2) __Pyx_PyErr_GivenExceptionMatches2(__Pyx_PyErr_CurrentExceptionType(), err1, err2) +#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception) + +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +/* ListCompAppend.proto */ +#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS +static CYTHON_INLINE int __Pyx_ListComp_Append(PyObject* list, PyObject* x) { + PyListObject* L = (PyListObject*) list; + Py_ssize_t len = Py_SIZE(list); + if (likely(L->allocated > len)) { + Py_INCREF(x); + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + L->ob_item[len] = x; + #else + PyList_SET_ITEM(list, len, x); + #endif + __Pyx_SET_SIZE(list, len + 1); + return 0; + } + return PyList_Append(list, x); +} +#else +#define __Pyx_ListComp_Append(L,x) PyList_Append(L,x) +#endif + +/* PySequenceMultiply.proto */ +#define __Pyx_PySequence_Multiply_Left(mul, seq) __Pyx_PySequence_Multiply(seq, mul) +static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul); + +/* SetItemInt.proto */ +#define __Pyx_SetItemInt(o, i, v, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_SetItemInt_Fast(o, (Py_ssize_t)i, v, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list assignment index out of range"), -1) :\ + __Pyx_SetItemInt_Generic(o, to_py_func(i), v))) +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v); +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, + int is_list, int wraparound, int boundscheck); + +/* RaiseUnboundLocalError.proto */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname); + +/* DivInt[long].proto */ +static CYTHON_INLINE long __Pyx_div_long(long, long); + +/* PySequenceContains.proto */ +static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) { + int result = PySequence_Contains(seq, item); + return unlikely(result < 0) ? result : (result == (eq == Py_EQ)); +} + +/* ImportFrom.proto */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name); + +/* HasAttr.proto */ +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 +#define __Pyx_HasAttr(o, n) PyObject_HasAttrWithError(o, n) +#else +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *); +#endif + +/* IsLittleEndian.proto */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void); + +/* BufferFormatCheck.proto */ +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts); +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type); + +/* BufferGetAndValidate.proto */ +#define __Pyx_GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack)\ + ((obj == Py_None || obj == NULL) ?\ + (__Pyx_ZeroBuffer(buf), 0) :\ + __Pyx__GetBufferAndValidate(buf, obj, dtype, flags, nd, cast, stack)) +static int __Pyx__GetBufferAndValidate(Py_buffer* buf, PyObject* obj, + __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_BufFmt_StackElem* stack); +static void __Pyx_ZeroBuffer(Py_buffer* buf); +static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info); +static Py_ssize_t __Pyx_minusones[] = { -1, -1, -1, -1, -1, -1, -1, -1 }; +static Py_ssize_t __Pyx_zeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + +/* BufferFallbackError.proto */ +static void __Pyx_RaiseBufferFallbackError(void); + +/* ListAppend.proto */ +#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS +static CYTHON_INLINE int __Pyx_PyList_Append(PyObject* list, PyObject* x) { + PyListObject* L = (PyListObject*) list; + Py_ssize_t len = Py_SIZE(list); + if (likely(L->allocated > len) & likely(len > (L->allocated >> 1))) { + Py_INCREF(x); + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + L->ob_item[len] = x; + #else + PyList_SET_ITEM(list, len, x); + #endif + __Pyx_SET_SIZE(list, len + 1); + return 0; + } + return PyList_Append(list, x); +} +#else +#define __Pyx_PyList_Append(L,x) PyList_Append(L,x) +#endif + +/* PyIntBinop.proto */ +#if !CYTHON_COMPILING_IN_PYPY +static PyObject* __Pyx_PyInt_SubtractObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check); +#else +#define __Pyx_PyInt_SubtractObjC(op1, op2, intval, inplace, zerodivision_check)\ + (inplace ? PyNumber_InPlaceSubtract(op1, op2) : PyNumber_Subtract(op1, op2)) +#endif + +/* SliceObject.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetSlice( + PyObject* obj, Py_ssize_t cstart, Py_ssize_t cstop, + PyObject** py_start, PyObject** py_stop, PyObject** py_slice, + int has_cstart, int has_cstop, int wraparound); + +/* PyObject_GenericGetAttrNoDict.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr +#endif + +/* PyObject_GenericGetAttr.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr +#endif + +/* IncludeStructmemberH.proto */ +#include + +/* FixUpExtensionType.proto */ +#if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type); +#endif + +/* PyObjectCallNoArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func); + +/* PyObjectGetMethod.proto */ +static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method); + +/* PyObjectCallMethod0.proto */ +static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name); + +/* ValidateBasesTuple.proto */ +#if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases); +#endif + +/* PyType_Ready.proto */ +CYTHON_UNUSED static int __Pyx_PyType_Ready(PyTypeObject *t); + +/* SetVTable.proto */ +static int __Pyx_SetVtable(PyTypeObject* typeptr , void* vtable); + +/* GetVTable.proto */ +static void* __Pyx_GetVtable(PyTypeObject *type); + +/* MergeVTables.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type); +#endif + +/* SetupReduce.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce(PyObject* type_obj); +#endif + +/* TypeImport.proto */ +#ifndef __PYX_HAVE_RT_ImportType_proto_3_0_8 +#define __PYX_HAVE_RT_ImportType_proto_3_0_8 +#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L +#include +#endif +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || __cplusplus >= 201103L +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_8(s) alignof(s) +#else +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_8(s) sizeof(void*) +#endif +enum __Pyx_ImportType_CheckSize_3_0_8 { + __Pyx_ImportType_CheckSize_Error_3_0_8 = 0, + __Pyx_ImportType_CheckSize_Warn_3_0_8 = 1, + __Pyx_ImportType_CheckSize_Ignore_3_0_8 = 2 +}; +static PyTypeObject *__Pyx_ImportType_3_0_8(PyObject* module, const char *module_name, const char *class_name, size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_8 check_size); +#endif + +/* FetchSharedCythonModule.proto */ +static PyObject *__Pyx_FetchSharedCythonABIModule(void); + +/* FetchCommonType.proto */ +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type); +#else +static PyTypeObject* __Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases); +#endif + +/* PyMethodNew.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + PyObject *typesModule=NULL, *methodType=NULL, *result=NULL; + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + typesModule = PyImport_ImportModule("types"); + if (!typesModule) return NULL; + methodType = PyObject_GetAttrString(typesModule, "MethodType"); + Py_DECREF(typesModule); + if (!methodType) return NULL; + result = PyObject_CallFunctionObjArgs(methodType, func, self, NULL); + Py_DECREF(methodType); + return result; +} +#elif PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + return PyMethod_New(func, self); +} +#else + #define __Pyx_PyMethod_New PyMethod_New +#endif + +/* PyVectorcallFastCallDict.proto */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw); +#endif + +/* CythonFunctionShared.proto */ +#define __Pyx_CyFunction_USED +#define __Pyx_CYFUNCTION_STATICMETHOD 0x01 +#define __Pyx_CYFUNCTION_CLASSMETHOD 0x02 +#define __Pyx_CYFUNCTION_CCLASS 0x04 +#define __Pyx_CYFUNCTION_COROUTINE 0x08 +#define __Pyx_CyFunction_GetClosure(f)\ + (((__pyx_CyFunctionObject *) (f))->func_closure) +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_CyFunction_GetClassObj(f)\ + (((__pyx_CyFunctionObject *) (f))->func_classobj) +#else + #define __Pyx_CyFunction_GetClassObj(f)\ + ((PyObject*) ((PyCMethodObject *) (f))->mm_class) +#endif +#define __Pyx_CyFunction_SetClassObj(f, classobj)\ + __Pyx__CyFunction_SetClassObj((__pyx_CyFunctionObject *) (f), (classobj)) +#define __Pyx_CyFunction_Defaults(type, f)\ + ((type *)(((__pyx_CyFunctionObject *) (f))->defaults)) +#define __Pyx_CyFunction_SetDefaultsGetter(f, g)\ + ((__pyx_CyFunctionObject *) (f))->defaults_getter = (g) +typedef struct { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject_HEAD + PyObject *func; +#elif PY_VERSION_HEX < 0x030900B1 + PyCFunctionObject func; +#else + PyCMethodObject func; +#endif +#if CYTHON_BACKPORT_VECTORCALL + __pyx_vectorcallfunc func_vectorcall; +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_weakreflist; +#endif + PyObject *func_dict; + PyObject *func_name; + PyObject *func_qualname; + PyObject *func_doc; + PyObject *func_globals; + PyObject *func_code; + PyObject *func_closure; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_classobj; +#endif + void *defaults; + int defaults_pyobjects; + size_t defaults_size; + int flags; + PyObject *defaults_tuple; + PyObject *defaults_kwdict; + PyObject *(*defaults_getter)(PyObject *); + PyObject *func_annotations; + PyObject *func_is_coroutine; +} __pyx_CyFunctionObject; +#undef __Pyx_CyOrPyCFunction_Check +#define __Pyx_CyFunction_Check(obj) __Pyx_TypeCheck(obj, __pyx_CyFunctionType) +#define __Pyx_CyOrPyCFunction_Check(obj) __Pyx_TypeCheck2(obj, __pyx_CyFunctionType, &PyCFunction_Type) +#define __Pyx_CyFunction_CheckExact(obj) __Pyx_IS_TYPE(obj, __pyx_CyFunctionType) +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc); +#undef __Pyx_IsSameCFunction +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCyOrCFunction(func, cfunc) +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject* op, PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj); +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m, + size_t size, + int pyobjects); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m, + PyObject *tuple); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *m, + PyObject *dict); +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m, + PyObject *dict); +static int __pyx_CyFunction_init(PyObject *module); +#if CYTHON_METH_FASTCALL +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +#if CYTHON_BACKPORT_VECTORCALL +#define __Pyx_CyFunction_func_vectorcall(f) (((__pyx_CyFunctionObject*)f)->func_vectorcall) +#else +#define __Pyx_CyFunction_func_vectorcall(f) (((PyCFunctionObject*)f)->vectorcall) +#endif +#endif + +/* CythonFunction.proto */ +static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); + +/* CLineInTraceback.proto */ +#ifdef CYTHON_CLINE_IN_TRACEBACK +#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0) +#else +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line); +#endif + +/* CodeObjectCache.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +typedef struct { + PyCodeObject* code_object; + int code_line; +} __Pyx_CodeObjectCacheEntry; +struct __Pyx_CodeObjectCache { + int count; + int max_count; + __Pyx_CodeObjectCacheEntry* entries; +}; +static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line); +static PyCodeObject *__pyx_find_code_object(int code_line); +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object); +#endif + +/* AddTraceback.proto */ +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename); + +#if PY_MAJOR_VERSION < 3 + static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags); + static void __Pyx_ReleaseBuffer(Py_buffer *view); +#else + #define __Pyx_GetBuffer PyObject_GetBuffer + #define __Pyx_ReleaseBuffer PyBuffer_Release +#endif + + +/* BufferStructDeclare.proto */ +typedef struct { + Py_ssize_t shape, strides, suboffsets; +} __Pyx_Buf_DimInfo; +typedef struct { + size_t refcount; + Py_buffer pybuffer; +} __Pyx_Buffer; +typedef struct { + __Pyx_Buffer *rcbuffer; + char *data; + __Pyx_Buf_DimInfo diminfo[8]; +} __Pyx_LocalBuf_ND; + +/* MemviewSliceIsContig.proto */ +static int __pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim); + +/* OverlappingSlices.proto */ +static int __pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize); + +/* MemviewDtypeToObject.proto */ +static CYTHON_INLINE PyObject *__pyx_memview_get_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(const char *itemp); +static CYTHON_INLINE int __pyx_memview_set_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(const char *itemp, PyObject *obj); + +/* TypeInfoCompare.proto */ +static int __pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b); + +/* MemviewSliceValidateAndInit.proto */ +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(PyObject *, int writable_flag); + +/* RealImag.proto */ +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #define __Pyx_CREAL(z) ((z).real()) + #define __Pyx_CIMAG(z) ((z).imag()) + #else + #define __Pyx_CREAL(z) (__real__(z)) + #define __Pyx_CIMAG(z) (__imag__(z)) + #endif +#else + #define __Pyx_CREAL(z) ((z).real) + #define __Pyx_CIMAG(z) ((z).imag) +#endif +#if defined(__cplusplus) && CYTHON_CCOMPLEX\ + && (defined(_WIN32) || defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5 || __GNUC__ == 4 && __GNUC_MINOR__ >= 4 )) || __cplusplus >= 201103) + #define __Pyx_SET_CREAL(z,x) ((z).real(x)) + #define __Pyx_SET_CIMAG(z,y) ((z).imag(y)) +#else + #define __Pyx_SET_CREAL(z,x) __Pyx_CREAL(z) = (x) + #define __Pyx_SET_CIMAG(z,y) __Pyx_CIMAG(z) = (y) +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_float(a, b) ((a)==(b)) + #define __Pyx_c_sum_float(a, b) ((a)+(b)) + #define __Pyx_c_diff_float(a, b) ((a)-(b)) + #define __Pyx_c_prod_float(a, b) ((a)*(b)) + #define __Pyx_c_quot_float(a, b) ((a)/(b)) + #define __Pyx_c_neg_float(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_float(z) ((z)==(float)0) + #define __Pyx_c_conj_float(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_float(z) (::std::abs(z)) + #define __Pyx_c_pow_float(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_float(z) ((z)==0) + #define __Pyx_c_conj_float(z) (conjf(z)) + #if 1 + #define __Pyx_c_abs_float(z) (cabsf(z)) + #define __Pyx_c_pow_float(a, b) (cpowf(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex); + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex, __pyx_t_float_complex); + #endif +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_double(a, b) ((a)==(b)) + #define __Pyx_c_sum_double(a, b) ((a)+(b)) + #define __Pyx_c_diff_double(a, b) ((a)-(b)) + #define __Pyx_c_prod_double(a, b) ((a)*(b)) + #define __Pyx_c_quot_double(a, b) ((a)/(b)) + #define __Pyx_c_neg_double(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_double(z) ((z)==(double)0) + #define __Pyx_c_conj_double(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (::std::abs(z)) + #define __Pyx_c_pow_double(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_double(z) ((z)==0) + #define __Pyx_c_conj_double(z) (conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (cabs(z)) + #define __Pyx_c_pow_double(a, b) (cpow(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex); + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex, __pyx_t_double_complex); + #endif +#endif + +/* MemviewSliceCopyTemplate.proto */ +static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object); + +/* MemviewSliceInit.proto */ +#define __Pyx_BUF_MAX_NDIMS %(BUF_MAX_NDIMS)d +#define __Pyx_MEMVIEW_DIRECT 1 +#define __Pyx_MEMVIEW_PTR 2 +#define __Pyx_MEMVIEW_FULL 4 +#define __Pyx_MEMVIEW_CONTIG 8 +#define __Pyx_MEMVIEW_STRIDED 16 +#define __Pyx_MEMVIEW_FOLLOW 32 +#define __Pyx_IS_C_CONTIG 1 +#define __Pyx_IS_F_CONTIG 2 +static int __Pyx_init_memviewslice( + struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference); +static CYTHON_INLINE int __pyx_add_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +static CYTHON_INLINE int __pyx_sub_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +#define __pyx_get_slice_count_pointer(memview) (&memview->acquisition_count) +#define __PYX_INC_MEMVIEW(slice, have_gil) __Pyx_INC_MEMVIEW(slice, have_gil, __LINE__) +#define __PYX_XCLEAR_MEMVIEW(slice, have_gil) __Pyx_XCLEAR_MEMVIEW(slice, have_gil, __LINE__) +static CYTHON_INLINE void __Pyx_INC_MEMVIEW(__Pyx_memviewslice *, int, int); +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *, int, int); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int64_t(int64_t value); + +/* CIntFromPy.proto */ +static CYTHON_INLINE int64_t __Pyx_PyInt_As_int64_t(PyObject *); + +/* CIntFromPy.proto */ +static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *); + +/* CIntFromPy.proto */ +static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value); + +/* None.proto */ +#include + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value); + +/* CIntFromPy.proto */ +static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *); + +/* FormatTypeName.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +typedef PyObject *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%U" +static __Pyx_TypeName __Pyx_PyType_GetName(PyTypeObject* tp); +#define __Pyx_DECREF_TypeName(obj) Py_XDECREF(obj) +#else +typedef const char *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%.200s" +#define __Pyx_PyType_GetName(tp) ((tp)->tp_name) +#define __Pyx_DECREF_TypeName(obj) +#endif + +/* CheckBinaryVersion.proto */ +static unsigned long __Pyx_get_runtime_version(void); +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer); + +/* InitStrings.proto */ +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); + +/* #### Code section: module_declarations ### */ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self); /* proto*/ +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto*/ +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self); /* proto*/ +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self); /* proto*/ +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_reset(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self); /* proto*/ +static int __pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_step(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i); /* proto*/ +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_seek(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i); /* proto*/ + +/* Module declarations from "libc.math" */ + +/* Module declarations from "cython.view" */ + +/* Module declarations from "cython.dataclasses" */ + +/* Module declarations from "cython" */ + +/* Module declarations from "libc.string" */ + +/* Module declarations from "libc.stdio" */ + +/* Module declarations from "__builtin__" */ + +/* Module declarations from "cpython.type" */ + +/* Module declarations from "cpython" */ + +/* Module declarations from "cpython.object" */ + +/* Module declarations from "cpython.ref" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "libc.stdint" */ + +/* Module declarations from "fairseq.data.token_block_utils_fast" */ +static PyObject *__pyx_collections_abc_Sequence = 0; +static PyObject *generic = 0; +static PyObject *strided = 0; +static PyObject *indirect = 0; +static PyObject *contiguous = 0; +static PyObject *indirect_contiguous = 0; +static int __pyx_memoryview_thread_locks_used; +static PyThread_type_lock __pyx_memoryview_thread_locks[8]; +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_none_mode(PyArrayObject *, int); /*proto*/ +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__fast_convert_to_np_array(PyObject *); /*proto*/ +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(PyArrayObject *, PyObject *, int, int, int __pyx_skip_dispatch); /*proto*/ +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_block_to_dataset_index_fast(PyArrayObject *, PyArrayObject *, int __pyx_skip_dispatch); /*proto*/ +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast___pyx_unpickle_DatasetSearcher__set_state(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *, PyObject *); /*proto*/ +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *); /*proto*/ +static struct __pyx_array_obj *__pyx_array_new(PyObject *, Py_ssize_t, char *, char *, char *); /*proto*/ +static PyObject *__pyx_memoryview_new(PyObject *, int, int, __Pyx_TypeInfo *); /*proto*/ +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *); /*proto*/ +static PyObject *_unellipsify(PyObject *, int); /*proto*/ +static int assert_direct_dimensions(Py_ssize_t *, int); /*proto*/ +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *, PyObject *); /*proto*/ +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int, int); /*proto*/ +static char *__pyx_pybuffer_index(Py_buffer *, char *, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memslice_transpose(__Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice, int, PyObject *(*)(char *), int (*)(char *, PyObject *), int); /*proto*/ +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static Py_ssize_t abs_py_ssize_t(Py_ssize_t); /*proto*/ +static char __pyx_get_best_slice_order(__Pyx_memviewslice *, int); /*proto*/ +static void _copy_strided_to_strided(char *, Py_ssize_t *, char *, Py_ssize_t *, Py_ssize_t *, Py_ssize_t *, int, size_t); /*proto*/ +static void copy_strided_to_strided(__Pyx_memviewslice *, __Pyx_memviewslice *, int, size_t); /*proto*/ +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *, int); /*proto*/ +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *, Py_ssize_t *, Py_ssize_t, int, char); /*proto*/ +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *, __Pyx_memviewslice *, char, int); /*proto*/ +static int __pyx_memoryview_err_extents(int, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memoryview_err_dim(PyObject *, PyObject *, int); /*proto*/ +static int __pyx_memoryview_err(PyObject *, PyObject *); /*proto*/ +static int __pyx_memoryview_err_no_memory(void); /*proto*/ +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice, __Pyx_memviewslice, int, int, int); /*proto*/ +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *, int, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *, int, size_t, void *, int); /*proto*/ +static void __pyx_memoryview__slice_assign_scalar(char *, Py_ssize_t *, Py_ssize_t *, int, size_t, void *); /*proto*/ +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *, PyObject *); /*proto*/ +/* #### Code section: typeinfo ### */ +static __Pyx_TypeInfo __Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t = { "DTYPE_t", NULL, sizeof(__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t), { 0 }, 0, __PYX_IS_UNSIGNED(__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t) ? 'U' : 'I', __PYX_IS_UNSIGNED(__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t), 0 }; +/* #### Code section: before_global_var ### */ +#define __Pyx_MODULE_NAME "fairseq.data.token_block_utils_fast" +extern int __pyx_module_is_main_fairseq__data__token_block_utils_fast; +int __pyx_module_is_main_fairseq__data__token_block_utils_fast = 0; + +/* Implementation of "fairseq.data.token_block_utils_fast" */ +/* #### Code section: global_var ### */ +static PyObject *__pyx_builtin_range; +static PyObject *__pyx_builtin_ValueError; +static PyObject *__pyx_builtin_AssertionError; +static PyObject *__pyx_builtin___import__; +static PyObject *__pyx_builtin_MemoryError; +static PyObject *__pyx_builtin_enumerate; +static PyObject *__pyx_builtin_TypeError; +static PyObject *__pyx_builtin_Ellipsis; +static PyObject *__pyx_builtin_id; +static PyObject *__pyx_builtin_IndexError; +static PyObject *__pyx_builtin_ImportError; +/* #### Code section: string_decls ### */ +static const char __pyx_k_[] = ": "; +static const char __pyx_k_O[] = "O"; +static const char __pyx_k_c[] = "c"; +static const char __pyx_k__2[] = "."; +static const char __pyx_k__3[] = "*"; +static const char __pyx_k__6[] = "'"; +static const char __pyx_k__7[] = ")"; +static const char __pyx_k_gc[] = "gc"; +static const char __pyx_k_id[] = "id"; +static const char __pyx_k_np[] = "np"; +static const char __pyx_k__35[] = "?"; +static const char __pyx_k_abc[] = "abc"; +static const char __pyx_k_and[] = " and "; +static const char __pyx_k_eos[] = "eos"; +static const char __pyx_k_got[] = " (got "; +static const char __pyx_k_new[] = "__new__"; +static const char __pyx_k_obj[] = "obj"; +static const char __pyx_k_sum[] = "sum"; +static const char __pyx_k_sys[] = "sys"; +static const char __pyx_k_axis[] = "axis"; +static const char __pyx_k_base[] = "base"; +static const char __pyx_k_dict[] = "__dict__"; +static const char __pyx_k_main[] = "__main__"; +static const char __pyx_k_mode[] = "mode"; +static const char __pyx_k_name[] = "name"; +static const char __pyx_k_ndim[] = "ndim"; +static const char __pyx_k_none[] = "none"; +static const char __pyx_k_pack[] = "pack"; +static const char __pyx_k_self[] = "self"; +static const char __pyx_k_size[] = "size"; +static const char __pyx_k_spec[] = "__spec__"; +static const char __pyx_k_step[] = "step"; +static const char __pyx_k_stop[] = "stop"; +static const char __pyx_k_test[] = "__test__"; +static const char __pyx_k_ASCII[] = "ASCII"; +static const char __pyx_k_DTYPE[] = "DTYPE"; +static const char __pyx_k_chain[] = "chain"; +static const char __pyx_k_class[] = "__class__"; +static const char __pyx_k_count[] = "count"; +static const char __pyx_k_dtype[] = "dtype"; +static const char __pyx_k_error[] = "error"; +static const char __pyx_k_flags[] = "flags"; +static const char __pyx_k_index[] = "index"; +static const char __pyx_k_int64[] = "int64"; +static const char __pyx_k_numpy[] = "numpy"; +static const char __pyx_k_range[] = "range"; +static const char __pyx_k_shape[] = "shape"; +static const char __pyx_k_sizes[] = "sizes"; +static const char __pyx_k_start[] = "start"; +static const char __pyx_k_state[] = "state"; +static const char __pyx_k_torch[] = "torch"; +static const char __pyx_k_zeros[] = "zeros"; +static const char __pyx_k_cumsum[] = "cumsum"; +static const char __pyx_k_dict_2[] = "_dict"; +static const char __pyx_k_enable[] = "enable"; +static const char __pyx_k_encode[] = "encode"; +static const char __pyx_k_format[] = "format"; +static const char __pyx_k_import[] = "__import__"; +static const char __pyx_k_name_2[] = "__name__"; +static const char __pyx_k_pickle[] = "pickle"; +static const char __pyx_k_reduce[] = "__reduce__"; +static const char __pyx_k_struct[] = "struct"; +static const char __pyx_k_unpack[] = "unpack"; +static const char __pyx_k_update[] = "update"; +static const char __pyx_k_disable[] = "disable"; +static const char __pyx_k_fortran[] = "fortran"; +static const char __pyx_k_memview[] = "memview"; +static const char __pyx_k_reshape[] = "reshape"; +static const char __pyx_k_Ellipsis[] = "Ellipsis"; +static const char __pyx_k_Sequence[] = "Sequence"; +static const char __pyx_k_complete[] = "complete"; +static const char __pyx_k_fromiter[] = "fromiter"; +static const char __pyx_k_getstate[] = "__getstate__"; +static const char __pyx_k_itemsize[] = "itemsize"; +static const char __pyx_k_pyx_type[] = "__pyx_type"; +static const char __pyx_k_register[] = "register"; +static const char __pyx_k_setstate[] = "__setstate__"; +static const char __pyx_k_TypeError[] = "TypeError"; +static const char __pyx_k_enumerate[] = "enumerate"; +static const char __pyx_k_isenabled[] = "isenabled"; +static const char __pyx_k_itertools[] = "itertools"; +static const char __pyx_k_pyx_state[] = "__pyx_state"; +static const char __pyx_k_reduce_ex[] = "__reduce_ex__"; +static const char __pyx_k_IndexError[] = "IndexError"; +static const char __pyx_k_ValueError[] = "ValueError"; +static const char __pyx_k_block_size[] = "block_size"; +static const char __pyx_k_break_mode[] = "break_mode"; +static const char __pyx_k_pyx_result[] = "__pyx_result"; +static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__"; +static const char __pyx_k_ImportError[] = "ImportError"; +static const char __pyx_k_MemoryError[] = "MemoryError"; +static const char __pyx_k_PickleError[] = "PickleError"; +static const char __pyx_k_collections[] = "collections"; +static const char __pyx_k_complete_doc[] = "complete_doc"; +static const char __pyx_k_initializing[] = "_initializing"; +static const char __pyx_k_is_coroutine[] = "_is_coroutine"; +static const char __pyx_k_pyx_checksum[] = "__pyx_checksum"; +static const char __pyx_k_stringsource[] = ""; +static const char __pyx_k_use_setstate[] = "use_setstate"; +static const char __pyx_k_version_info[] = "version_info"; +static const char __pyx_k_class_getitem[] = "__class_getitem__"; +static const char __pyx_k_from_iterable[] = "from_iterable"; +static const char __pyx_k_reduce_cython[] = "__reduce_cython__"; +static const char __pyx_k_slice_indices[] = "slice_indices"; +static const char __pyx_k_AssertionError[] = "AssertionError"; +static const char __pyx_k_DatasetSearcher[] = "DatasetSearcher"; +static const char __pyx_k_View_MemoryView[] = "View.MemoryView"; +static const char __pyx_k_allocate_buffer[] = "allocate_buffer"; +static const char __pyx_k_collections_abc[] = "collections.abc"; +static const char __pyx_k_dtype_is_object[] = "dtype_is_object"; +static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError"; +static const char __pyx_k_setstate_cython[] = "__setstate_cython__"; +static const char __pyx_k_document_sep_len[] = "document_sep_len"; +static const char __pyx_k_pyx_unpickle_Enum[] = "__pyx_unpickle_Enum"; +static const char __pyx_k_Invalid_break_mode[] = "Invalid break_mode: "; +static const char __pyx_k_asyncio_coroutines[] = "asyncio.coroutines"; +static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback"; +static const char __pyx_k_strided_and_direct[] = ""; +static const char __pyx_k_strided_and_indirect[] = ""; +static const char __pyx_k_Invalid_shape_in_axis[] = "Invalid shape in axis "; +static const char __pyx_k_contiguous_and_direct[] = ""; +static const char __pyx_k_Cannot_index_with_type[] = "Cannot index with type '"; +static const char __pyx_k_MemoryView_of_r_object[] = ""; +static const char __pyx_k_get_slice_indices_fast[] = "_get_slice_indices_fast"; +static const char __pyx_k_MemoryView_of_r_at_0x_x[] = ""; +static const char __pyx_k_contiguous_and_indirect[] = ""; +static const char __pyx_k_Dimension_d_is_not_direct[] = "Dimension %d is not direct"; +static const char __pyx_k_Index_out_of_bounds_axis_d[] = "Index out of bounds (axis %d)"; +static const char __pyx_k_Step_may_not_be_zero_axis_d[] = "Step may not be zero (axis %d)"; +static const char __pyx_k_itemsize_0_for_cython_array[] = "itemsize <= 0 for cython.array"; +static const char __pyx_k_pyx_unpickle_DatasetSearcher[] = "__pyx_unpickle_DatasetSearcher"; +static const char __pyx_k_unable_to_allocate_array_data[] = "unable to allocate array data."; +static const char __pyx_k_strided_and_direct_or_indirect[] = ""; +static const char __pyx_k_DatasetSearcher___reduce_cython[] = "DatasetSearcher.__reduce_cython__"; +static const char __pyx_k_get_block_to_dataset_index_fast[] = "_get_block_to_dataset_index_fast"; +static const char __pyx_k_numpy_core_multiarray_failed_to[] = "numpy.core.multiarray failed to import"; +static const char __pyx_k_All_dimensions_preceding_dimensi[] = "All dimensions preceding dimension %d must be indexed and not sliced"; +static const char __pyx_k_Buffer_view_does_not_expose_stri[] = "Buffer view does not expose strides"; +static const char __pyx_k_Can_only_create_a_buffer_that_is[] = "Can only create a buffer that is contiguous in memory."; +static const char __pyx_k_Cannot_assign_to_read_only_memor[] = "Cannot assign to read-only memoryview"; +static const char __pyx_k_Cannot_create_writable_memory_vi[] = "Cannot create writable memory view from read-only memoryview"; +static const char __pyx_k_Cannot_transpose_memoryview_with[] = "Cannot transpose memoryview with indirect dimensions"; +static const char __pyx_k_DatasetSearcher___setstate_cytho[] = "DatasetSearcher.__setstate_cython__"; +static const char __pyx_k_Empty_shape_tuple_for_cython_arr[] = "Empty shape tuple for cython.array"; +static const char __pyx_k_Incompatible_checksums_0x_x_vs_0[] = "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))"; +static const char __pyx_k_Indirect_dimensions_not_supporte[] = "Indirect dimensions not supported"; +static const char __pyx_k_Invalid_mode_expected_c_or_fortr[] = "Invalid mode, expected 'c' or 'fortran', got "; +static const char __pyx_k_Out_of_bounds_on_buffer_access_a[] = "Out of bounds on buffer access (axis "; +static const char __pyx_k_Unable_to_convert_item_to_object[] = "Unable to convert item to object"; +static const char __pyx_k_fairseq_data_token_block_utils_f[] = "fairseq/data/token_block_utils_fast.pyx"; +static const char __pyx_k_got_differing_extents_in_dimensi[] = "got differing extents in dimension "; +static const char __pyx_k_no_default___reduce___due_to_non[] = "no default __reduce__ due to non-trivial __cinit__"; +static const char __pyx_k_numpy_core_umath_failed_to_impor[] = "numpy.core.umath failed to import"; +static const char __pyx_k_unable_to_allocate_shape_and_str[] = "unable to allocate shape and strides."; +static const char __pyx_k_Incompatible_checksums_0x_x_vs_0_2[] = "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))"; +static const char __pyx_k_fairseq_data_token_block_utils_f_2[] = "fairseq.data.token_block_utils_fast"; +/* #### Code section: decls ### */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /* proto */ +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name); /* proto */ +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object); /* proto */ +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_sizes, PyObject *__pyx_v_break_mode, int __pyx_v_block_size, int __pyx_v_document_sep_len); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_2_get_block_to_dataset_index_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_sizes, PyArrayObject *__pyx_v_slice_indices); /* proto */ +static int __pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher___init__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __Pyx_memviewslice __pyx_v_sizes); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_2__reduce_cython__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_4__setstate_cython__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_4__pyx_unpickle_DatasetSearcher(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_tp_new_7fairseq_4data_22token_block_utils_fast_DatasetSearcher(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +/* #### Code section: late_includes ### */ +/* #### Code section: module_state ### */ +typedef struct { + PyObject *__pyx_d; + PyObject *__pyx_b; + PyObject *__pyx_cython_runtime; + PyObject *__pyx_empty_tuple; + PyObject *__pyx_empty_bytes; + PyObject *__pyx_empty_unicode; + #ifdef __Pyx_CyFunction_USED + PyTypeObject *__pyx_CyFunctionType; + #endif + #ifdef __Pyx_FusedFunction_USED + PyTypeObject *__pyx_FusedFunctionType; + #endif + #ifdef __Pyx_Generator_USED + PyTypeObject *__pyx_GeneratorType; + #endif + #ifdef __Pyx_IterableCoroutine_USED + PyTypeObject *__pyx_IterableCoroutineType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineAwaitType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineType; + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_7cpython_4type_type; + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_5numpy_dtype; + PyTypeObject *__pyx_ptype_5numpy_flatiter; + PyTypeObject *__pyx_ptype_5numpy_broadcast; + PyTypeObject *__pyx_ptype_5numpy_ndarray; + PyTypeObject *__pyx_ptype_5numpy_generic; + PyTypeObject *__pyx_ptype_5numpy_number; + PyTypeObject *__pyx_ptype_5numpy_integer; + PyTypeObject *__pyx_ptype_5numpy_signedinteger; + PyTypeObject *__pyx_ptype_5numpy_unsignedinteger; + PyTypeObject *__pyx_ptype_5numpy_inexact; + PyTypeObject *__pyx_ptype_5numpy_floating; + PyTypeObject *__pyx_ptype_5numpy_complexfloating; + PyTypeObject *__pyx_ptype_5numpy_flexible; + PyTypeObject *__pyx_ptype_5numpy_character; + PyTypeObject *__pyx_ptype_5numpy_ufunc; + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + PyObject *__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + PyObject *__pyx_type___pyx_array; + PyObject *__pyx_type___pyx_MemviewEnum; + PyObject *__pyx_type___pyx_memoryview; + PyObject *__pyx_type___pyx_memoryviewslice; + #endif + PyTypeObject *__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + PyTypeObject *__pyx_array_type; + PyTypeObject *__pyx_MemviewEnum_type; + PyTypeObject *__pyx_memoryview_type; + PyTypeObject *__pyx_memoryviewslice_type; + PyObject *__pyx_kp_u_; + PyObject *__pyx_n_s_ASCII; + PyObject *__pyx_kp_s_All_dimensions_preceding_dimensi; + PyObject *__pyx_n_s_AssertionError; + PyObject *__pyx_kp_s_Buffer_view_does_not_expose_stri; + PyObject *__pyx_kp_s_Can_only_create_a_buffer_that_is; + PyObject *__pyx_kp_s_Cannot_assign_to_read_only_memor; + PyObject *__pyx_kp_s_Cannot_create_writable_memory_vi; + PyObject *__pyx_kp_u_Cannot_index_with_type; + PyObject *__pyx_kp_s_Cannot_transpose_memoryview_with; + PyObject *__pyx_n_s_DTYPE; + PyObject *__pyx_n_s_DatasetSearcher; + PyObject *__pyx_n_s_DatasetSearcher___reduce_cython; + PyObject *__pyx_n_s_DatasetSearcher___setstate_cytho; + PyObject *__pyx_kp_s_Dimension_d_is_not_direct; + PyObject *__pyx_n_s_Ellipsis; + PyObject *__pyx_kp_s_Empty_shape_tuple_for_cython_arr; + PyObject *__pyx_n_s_ImportError; + PyObject *__pyx_kp_s_Incompatible_checksums_0x_x_vs_0; + PyObject *__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2; + PyObject *__pyx_n_s_IndexError; + PyObject *__pyx_kp_s_Index_out_of_bounds_axis_d; + PyObject *__pyx_kp_s_Indirect_dimensions_not_supporte; + PyObject *__pyx_kp_u_Invalid_break_mode; + PyObject *__pyx_kp_u_Invalid_mode_expected_c_or_fortr; + PyObject *__pyx_kp_u_Invalid_shape_in_axis; + PyObject *__pyx_n_s_MemoryError; + PyObject *__pyx_kp_s_MemoryView_of_r_at_0x_x; + PyObject *__pyx_kp_s_MemoryView_of_r_object; + PyObject *__pyx_n_b_O; + PyObject *__pyx_kp_u_Out_of_bounds_on_buffer_access_a; + PyObject *__pyx_n_s_PickleError; + PyObject *__pyx_n_s_Sequence; + PyObject *__pyx_kp_s_Step_may_not_be_zero_axis_d; + PyObject *__pyx_n_s_TypeError; + PyObject *__pyx_kp_s_Unable_to_convert_item_to_object; + PyObject *__pyx_n_s_ValueError; + PyObject *__pyx_n_s_View_MemoryView; + PyObject *__pyx_kp_u__2; + PyObject *__pyx_n_s__3; + PyObject *__pyx_n_s__35; + PyObject *__pyx_kp_u__6; + PyObject *__pyx_kp_u__7; + PyObject *__pyx_n_s_abc; + PyObject *__pyx_n_s_allocate_buffer; + PyObject *__pyx_kp_u_and; + PyObject *__pyx_n_s_asyncio_coroutines; + PyObject *__pyx_n_s_axis; + PyObject *__pyx_n_s_base; + PyObject *__pyx_n_s_block_size; + PyObject *__pyx_n_s_break_mode; + PyObject *__pyx_n_s_c; + PyObject *__pyx_n_u_c; + PyObject *__pyx_n_s_chain; + PyObject *__pyx_n_s_class; + PyObject *__pyx_n_s_class_getitem; + PyObject *__pyx_n_s_cline_in_traceback; + PyObject *__pyx_n_s_collections; + PyObject *__pyx_kp_s_collections_abc; + PyObject *__pyx_n_u_complete; + PyObject *__pyx_n_u_complete_doc; + PyObject *__pyx_kp_s_contiguous_and_direct; + PyObject *__pyx_kp_s_contiguous_and_indirect; + PyObject *__pyx_n_s_count; + PyObject *__pyx_n_s_cumsum; + PyObject *__pyx_n_s_dict; + PyObject *__pyx_n_s_dict_2; + PyObject *__pyx_kp_u_disable; + PyObject *__pyx_n_s_document_sep_len; + PyObject *__pyx_n_s_dtype; + PyObject *__pyx_n_s_dtype_is_object; + PyObject *__pyx_kp_u_enable; + PyObject *__pyx_n_s_encode; + PyObject *__pyx_n_s_enumerate; + PyObject *__pyx_n_u_eos; + PyObject *__pyx_n_s_error; + PyObject *__pyx_kp_s_fairseq_data_token_block_utils_f; + PyObject *__pyx_n_s_fairseq_data_token_block_utils_f_2; + PyObject *__pyx_n_s_flags; + PyObject *__pyx_n_s_format; + PyObject *__pyx_n_s_fortran; + PyObject *__pyx_n_u_fortran; + PyObject *__pyx_n_s_from_iterable; + PyObject *__pyx_n_s_fromiter; + PyObject *__pyx_kp_u_gc; + PyObject *__pyx_n_s_get_block_to_dataset_index_fast; + PyObject *__pyx_n_s_get_slice_indices_fast; + PyObject *__pyx_n_s_getstate; + PyObject *__pyx_kp_u_got; + PyObject *__pyx_kp_u_got_differing_extents_in_dimensi; + PyObject *__pyx_n_s_id; + PyObject *__pyx_n_s_import; + PyObject *__pyx_n_s_index; + PyObject *__pyx_n_s_initializing; + PyObject *__pyx_n_s_int64; + PyObject *__pyx_n_s_is_coroutine; + PyObject *__pyx_kp_u_isenabled; + PyObject *__pyx_n_s_itemsize; + PyObject *__pyx_kp_s_itemsize_0_for_cython_array; + PyObject *__pyx_n_s_itertools; + PyObject *__pyx_n_s_main; + PyObject *__pyx_n_s_memview; + PyObject *__pyx_n_s_mode; + PyObject *__pyx_n_s_name; + PyObject *__pyx_n_s_name_2; + PyObject *__pyx_n_s_ndim; + PyObject *__pyx_n_s_new; + PyObject *__pyx_kp_s_no_default___reduce___due_to_non; + PyObject *__pyx_n_u_none; + PyObject *__pyx_n_s_np; + PyObject *__pyx_n_s_numpy; + PyObject *__pyx_kp_u_numpy_core_multiarray_failed_to; + PyObject *__pyx_kp_u_numpy_core_umath_failed_to_impor; + PyObject *__pyx_n_s_obj; + PyObject *__pyx_n_s_pack; + PyObject *__pyx_n_s_pickle; + PyObject *__pyx_n_s_pyx_PickleError; + PyObject *__pyx_n_s_pyx_checksum; + PyObject *__pyx_n_s_pyx_result; + PyObject *__pyx_n_s_pyx_state; + PyObject *__pyx_n_s_pyx_type; + PyObject *__pyx_n_s_pyx_unpickle_DatasetSearcher; + PyObject *__pyx_n_s_pyx_unpickle_Enum; + PyObject *__pyx_n_s_pyx_vtable; + PyObject *__pyx_n_s_range; + PyObject *__pyx_n_s_reduce; + PyObject *__pyx_n_s_reduce_cython; + PyObject *__pyx_n_s_reduce_ex; + PyObject *__pyx_n_s_register; + PyObject *__pyx_n_s_reshape; + PyObject *__pyx_n_s_self; + PyObject *__pyx_n_s_setstate; + PyObject *__pyx_n_s_setstate_cython; + PyObject *__pyx_n_s_shape; + PyObject *__pyx_n_s_size; + PyObject *__pyx_n_s_sizes; + PyObject *__pyx_n_s_slice_indices; + PyObject *__pyx_n_s_spec; + PyObject *__pyx_n_s_start; + PyObject *__pyx_n_s_state; + PyObject *__pyx_n_s_step; + PyObject *__pyx_n_s_stop; + PyObject *__pyx_kp_s_strided_and_direct; + PyObject *__pyx_kp_s_strided_and_direct_or_indirect; + PyObject *__pyx_kp_s_strided_and_indirect; + PyObject *__pyx_kp_s_stringsource; + PyObject *__pyx_n_s_struct; + PyObject *__pyx_n_s_sum; + PyObject *__pyx_n_s_sys; + PyObject *__pyx_n_s_test; + PyObject *__pyx_n_s_torch; + PyObject *__pyx_kp_s_unable_to_allocate_array_data; + PyObject *__pyx_kp_s_unable_to_allocate_shape_and_str; + PyObject *__pyx_n_s_unpack; + PyObject *__pyx_n_s_update; + PyObject *__pyx_n_s_use_setstate; + PyObject *__pyx_n_s_version_info; + PyObject *__pyx_n_s_zeros; + PyObject *__pyx_int_0; + PyObject *__pyx_int_1; + PyObject *__pyx_int_2; + PyObject *__pyx_int_3; + PyObject *__pyx_int_48422178; + PyObject *__pyx_int_107161605; + PyObject *__pyx_int_112105877; + PyObject *__pyx_int_136983863; + PyObject *__pyx_int_147225413; + PyObject *__pyx_int_184977713; + PyObject *__pyx_int_neg_1; + PyObject *__pyx_slice__5; + PyObject *__pyx_tuple__4; + PyObject *__pyx_tuple__8; + PyObject *__pyx_tuple__9; + PyObject *__pyx_slice__11; + PyObject *__pyx_tuple__10; + PyObject *__pyx_tuple__12; + PyObject *__pyx_tuple__13; + PyObject *__pyx_tuple__14; + PyObject *__pyx_tuple__15; + PyObject *__pyx_tuple__16; + PyObject *__pyx_tuple__17; + PyObject *__pyx_tuple__18; + PyObject *__pyx_tuple__19; + PyObject *__pyx_tuple__20; + PyObject *__pyx_tuple__21; + PyObject *__pyx_tuple__22; + PyObject *__pyx_tuple__23; + PyObject *__pyx_tuple__24; + PyObject *__pyx_tuple__26; + PyObject *__pyx_tuple__28; + PyObject *__pyx_tuple__30; + PyObject *__pyx_tuple__32; + PyObject *__pyx_codeobj__25; + PyObject *__pyx_codeobj__27; + PyObject *__pyx_codeobj__29; + PyObject *__pyx_codeobj__31; + PyObject *__pyx_codeobj__33; + PyObject *__pyx_codeobj__34; +} __pyx_mstate; + +#if CYTHON_USE_MODULE_STATE +#ifdef __cplusplus +namespace { + extern struct PyModuleDef __pyx_moduledef; +} /* anonymous namespace */ +#else +static struct PyModuleDef __pyx_moduledef; +#endif + +#define __pyx_mstate(o) ((__pyx_mstate *)__Pyx_PyModule_GetState(o)) + +#define __pyx_mstate_global (__pyx_mstate(PyState_FindModule(&__pyx_moduledef))) + +#define __pyx_m (PyState_FindModule(&__pyx_moduledef)) +#else +static __pyx_mstate __pyx_mstate_global_static = +#ifdef __cplusplus + {}; +#else + {0}; +#endif +static __pyx_mstate *__pyx_mstate_global = &__pyx_mstate_global_static; +#endif +/* #### Code section: module_state_clear ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_clear(PyObject *m) { + __pyx_mstate *clear_module_state = __pyx_mstate(m); + if (!clear_module_state) return 0; + Py_CLEAR(clear_module_state->__pyx_d); + Py_CLEAR(clear_module_state->__pyx_b); + Py_CLEAR(clear_module_state->__pyx_cython_runtime); + Py_CLEAR(clear_module_state->__pyx_empty_tuple); + Py_CLEAR(clear_module_state->__pyx_empty_bytes); + Py_CLEAR(clear_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_CLEAR(clear_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_CLEAR(clear_module_state->__pyx_FusedFunctionType); + #endif + Py_CLEAR(clear_module_state->__pyx_ptype_7cpython_4type_type); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_dtype); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flatiter); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_broadcast); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ndarray); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_generic); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_number); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_integer); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_signedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_inexact); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_floating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_complexfloating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flexible); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_character); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ufunc); + Py_CLEAR(clear_module_state->__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + Py_CLEAR(clear_module_state->__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + Py_CLEAR(clear_module_state->__pyx_array_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_array); + Py_CLEAR(clear_module_state->__pyx_MemviewEnum_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_MemviewEnum); + Py_CLEAR(clear_module_state->__pyx_memoryview_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryview); + Py_CLEAR(clear_module_state->__pyx_memoryviewslice_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryviewslice); + Py_CLEAR(clear_module_state->__pyx_kp_u_); + Py_CLEAR(clear_module_state->__pyx_n_s_ASCII); + Py_CLEAR(clear_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_AssertionError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_CLEAR(clear_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_CLEAR(clear_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_CLEAR(clear_module_state->__pyx_n_s_DTYPE); + Py_CLEAR(clear_module_state->__pyx_n_s_DatasetSearcher); + Py_CLEAR(clear_module_state->__pyx_n_s_DatasetSearcher___reduce_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_DatasetSearcher___setstate_cytho); + Py_CLEAR(clear_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_CLEAR(clear_module_state->__pyx_n_s_Ellipsis); + Py_CLEAR(clear_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_CLEAR(clear_module_state->__pyx_n_s_ImportError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_CLEAR(clear_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2); + Py_CLEAR(clear_module_state->__pyx_n_s_IndexError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_CLEAR(clear_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_break_mode); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_CLEAR(clear_module_state->__pyx_n_s_MemoryError); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_CLEAR(clear_module_state->__pyx_n_b_O); + Py_CLEAR(clear_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_CLEAR(clear_module_state->__pyx_n_s_PickleError); + Py_CLEAR(clear_module_state->__pyx_n_s_Sequence); + Py_CLEAR(clear_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_CLEAR(clear_module_state->__pyx_n_s_TypeError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_CLEAR(clear_module_state->__pyx_n_s_ValueError); + Py_CLEAR(clear_module_state->__pyx_n_s_View_MemoryView); + Py_CLEAR(clear_module_state->__pyx_kp_u__2); + Py_CLEAR(clear_module_state->__pyx_n_s__3); + Py_CLEAR(clear_module_state->__pyx_n_s__35); + Py_CLEAR(clear_module_state->__pyx_kp_u__6); + Py_CLEAR(clear_module_state->__pyx_kp_u__7); + Py_CLEAR(clear_module_state->__pyx_n_s_abc); + Py_CLEAR(clear_module_state->__pyx_n_s_allocate_buffer); + Py_CLEAR(clear_module_state->__pyx_kp_u_and); + Py_CLEAR(clear_module_state->__pyx_n_s_asyncio_coroutines); + Py_CLEAR(clear_module_state->__pyx_n_s_axis); + Py_CLEAR(clear_module_state->__pyx_n_s_base); + Py_CLEAR(clear_module_state->__pyx_n_s_block_size); + Py_CLEAR(clear_module_state->__pyx_n_s_break_mode); + Py_CLEAR(clear_module_state->__pyx_n_s_c); + Py_CLEAR(clear_module_state->__pyx_n_u_c); + Py_CLEAR(clear_module_state->__pyx_n_s_chain); + Py_CLEAR(clear_module_state->__pyx_n_s_class); + Py_CLEAR(clear_module_state->__pyx_n_s_class_getitem); + Py_CLEAR(clear_module_state->__pyx_n_s_cline_in_traceback); + Py_CLEAR(clear_module_state->__pyx_n_s_collections); + Py_CLEAR(clear_module_state->__pyx_kp_s_collections_abc); + Py_CLEAR(clear_module_state->__pyx_n_u_complete); + Py_CLEAR(clear_module_state->__pyx_n_u_complete_doc); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_CLEAR(clear_module_state->__pyx_n_s_count); + Py_CLEAR(clear_module_state->__pyx_n_s_cumsum); + Py_CLEAR(clear_module_state->__pyx_n_s_dict); + Py_CLEAR(clear_module_state->__pyx_n_s_dict_2); + Py_CLEAR(clear_module_state->__pyx_kp_u_disable); + Py_CLEAR(clear_module_state->__pyx_n_s_document_sep_len); + Py_CLEAR(clear_module_state->__pyx_n_s_dtype); + Py_CLEAR(clear_module_state->__pyx_n_s_dtype_is_object); + Py_CLEAR(clear_module_state->__pyx_kp_u_enable); + Py_CLEAR(clear_module_state->__pyx_n_s_encode); + Py_CLEAR(clear_module_state->__pyx_n_s_enumerate); + Py_CLEAR(clear_module_state->__pyx_n_u_eos); + Py_CLEAR(clear_module_state->__pyx_n_s_error); + Py_CLEAR(clear_module_state->__pyx_kp_s_fairseq_data_token_block_utils_f); + Py_CLEAR(clear_module_state->__pyx_n_s_fairseq_data_token_block_utils_f_2); + Py_CLEAR(clear_module_state->__pyx_n_s_flags); + Py_CLEAR(clear_module_state->__pyx_n_s_format); + Py_CLEAR(clear_module_state->__pyx_n_s_fortran); + Py_CLEAR(clear_module_state->__pyx_n_u_fortran); + Py_CLEAR(clear_module_state->__pyx_n_s_from_iterable); + Py_CLEAR(clear_module_state->__pyx_n_s_fromiter); + Py_CLEAR(clear_module_state->__pyx_kp_u_gc); + Py_CLEAR(clear_module_state->__pyx_n_s_get_block_to_dataset_index_fast); + Py_CLEAR(clear_module_state->__pyx_n_s_get_slice_indices_fast); + Py_CLEAR(clear_module_state->__pyx_n_s_getstate); + Py_CLEAR(clear_module_state->__pyx_kp_u_got); + Py_CLEAR(clear_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_id); + Py_CLEAR(clear_module_state->__pyx_n_s_import); + Py_CLEAR(clear_module_state->__pyx_n_s_index); + Py_CLEAR(clear_module_state->__pyx_n_s_initializing); + Py_CLEAR(clear_module_state->__pyx_n_s_int64); + Py_CLEAR(clear_module_state->__pyx_n_s_is_coroutine); + Py_CLEAR(clear_module_state->__pyx_kp_u_isenabled); + Py_CLEAR(clear_module_state->__pyx_n_s_itemsize); + Py_CLEAR(clear_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_CLEAR(clear_module_state->__pyx_n_s_itertools); + Py_CLEAR(clear_module_state->__pyx_n_s_main); + Py_CLEAR(clear_module_state->__pyx_n_s_memview); + Py_CLEAR(clear_module_state->__pyx_n_s_mode); + Py_CLEAR(clear_module_state->__pyx_n_s_name); + Py_CLEAR(clear_module_state->__pyx_n_s_name_2); + Py_CLEAR(clear_module_state->__pyx_n_s_ndim); + Py_CLEAR(clear_module_state->__pyx_n_s_new); + Py_CLEAR(clear_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_CLEAR(clear_module_state->__pyx_n_u_none); + Py_CLEAR(clear_module_state->__pyx_n_s_np); + Py_CLEAR(clear_module_state->__pyx_n_s_numpy); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy_core_multiarray_failed_to); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy_core_umath_failed_to_impor); + Py_CLEAR(clear_module_state->__pyx_n_s_obj); + Py_CLEAR(clear_module_state->__pyx_n_s_pack); + Py_CLEAR(clear_module_state->__pyx_n_s_pickle); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_PickleError); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_checksum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_result); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_state); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_type); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_unpickle_DatasetSearcher); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_vtable); + Py_CLEAR(clear_module_state->__pyx_n_s_range); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_ex); + Py_CLEAR(clear_module_state->__pyx_n_s_register); + Py_CLEAR(clear_module_state->__pyx_n_s_reshape); + Py_CLEAR(clear_module_state->__pyx_n_s_self); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_shape); + Py_CLEAR(clear_module_state->__pyx_n_s_size); + Py_CLEAR(clear_module_state->__pyx_n_s_sizes); + Py_CLEAR(clear_module_state->__pyx_n_s_slice_indices); + Py_CLEAR(clear_module_state->__pyx_n_s_spec); + Py_CLEAR(clear_module_state->__pyx_n_s_start); + Py_CLEAR(clear_module_state->__pyx_n_s_state); + Py_CLEAR(clear_module_state->__pyx_n_s_step); + Py_CLEAR(clear_module_state->__pyx_n_s_stop); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_stringsource); + Py_CLEAR(clear_module_state->__pyx_n_s_struct); + Py_CLEAR(clear_module_state->__pyx_n_s_sum); + Py_CLEAR(clear_module_state->__pyx_n_s_sys); + Py_CLEAR(clear_module_state->__pyx_n_s_test); + Py_CLEAR(clear_module_state->__pyx_n_s_torch); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_CLEAR(clear_module_state->__pyx_n_s_unpack); + Py_CLEAR(clear_module_state->__pyx_n_s_update); + Py_CLEAR(clear_module_state->__pyx_n_s_use_setstate); + Py_CLEAR(clear_module_state->__pyx_n_s_version_info); + Py_CLEAR(clear_module_state->__pyx_n_s_zeros); + Py_CLEAR(clear_module_state->__pyx_int_0); + Py_CLEAR(clear_module_state->__pyx_int_1); + Py_CLEAR(clear_module_state->__pyx_int_2); + Py_CLEAR(clear_module_state->__pyx_int_3); + Py_CLEAR(clear_module_state->__pyx_int_48422178); + Py_CLEAR(clear_module_state->__pyx_int_107161605); + Py_CLEAR(clear_module_state->__pyx_int_112105877); + Py_CLEAR(clear_module_state->__pyx_int_136983863); + Py_CLEAR(clear_module_state->__pyx_int_147225413); + Py_CLEAR(clear_module_state->__pyx_int_184977713); + Py_CLEAR(clear_module_state->__pyx_int_neg_1); + Py_CLEAR(clear_module_state->__pyx_slice__5); + Py_CLEAR(clear_module_state->__pyx_tuple__4); + Py_CLEAR(clear_module_state->__pyx_tuple__8); + Py_CLEAR(clear_module_state->__pyx_tuple__9); + Py_CLEAR(clear_module_state->__pyx_slice__11); + Py_CLEAR(clear_module_state->__pyx_tuple__10); + Py_CLEAR(clear_module_state->__pyx_tuple__12); + Py_CLEAR(clear_module_state->__pyx_tuple__13); + Py_CLEAR(clear_module_state->__pyx_tuple__14); + Py_CLEAR(clear_module_state->__pyx_tuple__15); + Py_CLEAR(clear_module_state->__pyx_tuple__16); + Py_CLEAR(clear_module_state->__pyx_tuple__17); + Py_CLEAR(clear_module_state->__pyx_tuple__18); + Py_CLEAR(clear_module_state->__pyx_tuple__19); + Py_CLEAR(clear_module_state->__pyx_tuple__20); + Py_CLEAR(clear_module_state->__pyx_tuple__21); + Py_CLEAR(clear_module_state->__pyx_tuple__22); + Py_CLEAR(clear_module_state->__pyx_tuple__23); + Py_CLEAR(clear_module_state->__pyx_tuple__24); + Py_CLEAR(clear_module_state->__pyx_tuple__26); + Py_CLEAR(clear_module_state->__pyx_tuple__28); + Py_CLEAR(clear_module_state->__pyx_tuple__30); + Py_CLEAR(clear_module_state->__pyx_tuple__32); + Py_CLEAR(clear_module_state->__pyx_codeobj__25); + Py_CLEAR(clear_module_state->__pyx_codeobj__27); + Py_CLEAR(clear_module_state->__pyx_codeobj__29); + Py_CLEAR(clear_module_state->__pyx_codeobj__31); + Py_CLEAR(clear_module_state->__pyx_codeobj__33); + Py_CLEAR(clear_module_state->__pyx_codeobj__34); + return 0; +} +#endif +/* #### Code section: module_state_traverse ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_traverse(PyObject *m, visitproc visit, void *arg) { + __pyx_mstate *traverse_module_state = __pyx_mstate(m); + if (!traverse_module_state) return 0; + Py_VISIT(traverse_module_state->__pyx_d); + Py_VISIT(traverse_module_state->__pyx_b); + Py_VISIT(traverse_module_state->__pyx_cython_runtime); + Py_VISIT(traverse_module_state->__pyx_empty_tuple); + Py_VISIT(traverse_module_state->__pyx_empty_bytes); + Py_VISIT(traverse_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_VISIT(traverse_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_VISIT(traverse_module_state->__pyx_FusedFunctionType); + #endif + Py_VISIT(traverse_module_state->__pyx_ptype_7cpython_4type_type); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_dtype); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flatiter); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_broadcast); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ndarray); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_generic); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_number); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_integer); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_signedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_inexact); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_floating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_complexfloating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flexible); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_character); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ufunc); + Py_VISIT(traverse_module_state->__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + Py_VISIT(traverse_module_state->__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + Py_VISIT(traverse_module_state->__pyx_array_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_array); + Py_VISIT(traverse_module_state->__pyx_MemviewEnum_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_MemviewEnum); + Py_VISIT(traverse_module_state->__pyx_memoryview_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryview); + Py_VISIT(traverse_module_state->__pyx_memoryviewslice_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryviewslice); + Py_VISIT(traverse_module_state->__pyx_kp_u_); + Py_VISIT(traverse_module_state->__pyx_n_s_ASCII); + Py_VISIT(traverse_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_AssertionError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_VISIT(traverse_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_VISIT(traverse_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_VISIT(traverse_module_state->__pyx_n_s_DTYPE); + Py_VISIT(traverse_module_state->__pyx_n_s_DatasetSearcher); + Py_VISIT(traverse_module_state->__pyx_n_s_DatasetSearcher___reduce_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_DatasetSearcher___setstate_cytho); + Py_VISIT(traverse_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_VISIT(traverse_module_state->__pyx_n_s_Ellipsis); + Py_VISIT(traverse_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_VISIT(traverse_module_state->__pyx_n_s_ImportError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_VISIT(traverse_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2); + Py_VISIT(traverse_module_state->__pyx_n_s_IndexError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_VISIT(traverse_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_break_mode); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_VISIT(traverse_module_state->__pyx_n_s_MemoryError); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_VISIT(traverse_module_state->__pyx_n_b_O); + Py_VISIT(traverse_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_VISIT(traverse_module_state->__pyx_n_s_PickleError); + Py_VISIT(traverse_module_state->__pyx_n_s_Sequence); + Py_VISIT(traverse_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_VISIT(traverse_module_state->__pyx_n_s_TypeError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_VISIT(traverse_module_state->__pyx_n_s_ValueError); + Py_VISIT(traverse_module_state->__pyx_n_s_View_MemoryView); + Py_VISIT(traverse_module_state->__pyx_kp_u__2); + Py_VISIT(traverse_module_state->__pyx_n_s__3); + Py_VISIT(traverse_module_state->__pyx_n_s__35); + Py_VISIT(traverse_module_state->__pyx_kp_u__6); + Py_VISIT(traverse_module_state->__pyx_kp_u__7); + Py_VISIT(traverse_module_state->__pyx_n_s_abc); + Py_VISIT(traverse_module_state->__pyx_n_s_allocate_buffer); + Py_VISIT(traverse_module_state->__pyx_kp_u_and); + Py_VISIT(traverse_module_state->__pyx_n_s_asyncio_coroutines); + Py_VISIT(traverse_module_state->__pyx_n_s_axis); + Py_VISIT(traverse_module_state->__pyx_n_s_base); + Py_VISIT(traverse_module_state->__pyx_n_s_block_size); + Py_VISIT(traverse_module_state->__pyx_n_s_break_mode); + Py_VISIT(traverse_module_state->__pyx_n_s_c); + Py_VISIT(traverse_module_state->__pyx_n_u_c); + Py_VISIT(traverse_module_state->__pyx_n_s_chain); + Py_VISIT(traverse_module_state->__pyx_n_s_class); + Py_VISIT(traverse_module_state->__pyx_n_s_class_getitem); + Py_VISIT(traverse_module_state->__pyx_n_s_cline_in_traceback); + Py_VISIT(traverse_module_state->__pyx_n_s_collections); + Py_VISIT(traverse_module_state->__pyx_kp_s_collections_abc); + Py_VISIT(traverse_module_state->__pyx_n_u_complete); + Py_VISIT(traverse_module_state->__pyx_n_u_complete_doc); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_VISIT(traverse_module_state->__pyx_n_s_count); + Py_VISIT(traverse_module_state->__pyx_n_s_cumsum); + Py_VISIT(traverse_module_state->__pyx_n_s_dict); + Py_VISIT(traverse_module_state->__pyx_n_s_dict_2); + Py_VISIT(traverse_module_state->__pyx_kp_u_disable); + Py_VISIT(traverse_module_state->__pyx_n_s_document_sep_len); + Py_VISIT(traverse_module_state->__pyx_n_s_dtype); + Py_VISIT(traverse_module_state->__pyx_n_s_dtype_is_object); + Py_VISIT(traverse_module_state->__pyx_kp_u_enable); + Py_VISIT(traverse_module_state->__pyx_n_s_encode); + Py_VISIT(traverse_module_state->__pyx_n_s_enumerate); + Py_VISIT(traverse_module_state->__pyx_n_u_eos); + Py_VISIT(traverse_module_state->__pyx_n_s_error); + Py_VISIT(traverse_module_state->__pyx_kp_s_fairseq_data_token_block_utils_f); + Py_VISIT(traverse_module_state->__pyx_n_s_fairseq_data_token_block_utils_f_2); + Py_VISIT(traverse_module_state->__pyx_n_s_flags); + Py_VISIT(traverse_module_state->__pyx_n_s_format); + Py_VISIT(traverse_module_state->__pyx_n_s_fortran); + Py_VISIT(traverse_module_state->__pyx_n_u_fortran); + Py_VISIT(traverse_module_state->__pyx_n_s_from_iterable); + Py_VISIT(traverse_module_state->__pyx_n_s_fromiter); + Py_VISIT(traverse_module_state->__pyx_kp_u_gc); + Py_VISIT(traverse_module_state->__pyx_n_s_get_block_to_dataset_index_fast); + Py_VISIT(traverse_module_state->__pyx_n_s_get_slice_indices_fast); + Py_VISIT(traverse_module_state->__pyx_n_s_getstate); + Py_VISIT(traverse_module_state->__pyx_kp_u_got); + Py_VISIT(traverse_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_id); + Py_VISIT(traverse_module_state->__pyx_n_s_import); + Py_VISIT(traverse_module_state->__pyx_n_s_index); + Py_VISIT(traverse_module_state->__pyx_n_s_initializing); + Py_VISIT(traverse_module_state->__pyx_n_s_int64); + Py_VISIT(traverse_module_state->__pyx_n_s_is_coroutine); + Py_VISIT(traverse_module_state->__pyx_kp_u_isenabled); + Py_VISIT(traverse_module_state->__pyx_n_s_itemsize); + Py_VISIT(traverse_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_VISIT(traverse_module_state->__pyx_n_s_itertools); + Py_VISIT(traverse_module_state->__pyx_n_s_main); + Py_VISIT(traverse_module_state->__pyx_n_s_memview); + Py_VISIT(traverse_module_state->__pyx_n_s_mode); + Py_VISIT(traverse_module_state->__pyx_n_s_name); + Py_VISIT(traverse_module_state->__pyx_n_s_name_2); + Py_VISIT(traverse_module_state->__pyx_n_s_ndim); + Py_VISIT(traverse_module_state->__pyx_n_s_new); + Py_VISIT(traverse_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_VISIT(traverse_module_state->__pyx_n_u_none); + Py_VISIT(traverse_module_state->__pyx_n_s_np); + Py_VISIT(traverse_module_state->__pyx_n_s_numpy); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy_core_multiarray_failed_to); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy_core_umath_failed_to_impor); + Py_VISIT(traverse_module_state->__pyx_n_s_obj); + Py_VISIT(traverse_module_state->__pyx_n_s_pack); + Py_VISIT(traverse_module_state->__pyx_n_s_pickle); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_PickleError); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_checksum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_result); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_state); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_type); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_unpickle_DatasetSearcher); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_vtable); + Py_VISIT(traverse_module_state->__pyx_n_s_range); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_ex); + Py_VISIT(traverse_module_state->__pyx_n_s_register); + Py_VISIT(traverse_module_state->__pyx_n_s_reshape); + Py_VISIT(traverse_module_state->__pyx_n_s_self); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_shape); + Py_VISIT(traverse_module_state->__pyx_n_s_size); + Py_VISIT(traverse_module_state->__pyx_n_s_sizes); + Py_VISIT(traverse_module_state->__pyx_n_s_slice_indices); + Py_VISIT(traverse_module_state->__pyx_n_s_spec); + Py_VISIT(traverse_module_state->__pyx_n_s_start); + Py_VISIT(traverse_module_state->__pyx_n_s_state); + Py_VISIT(traverse_module_state->__pyx_n_s_step); + Py_VISIT(traverse_module_state->__pyx_n_s_stop); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_stringsource); + Py_VISIT(traverse_module_state->__pyx_n_s_struct); + Py_VISIT(traverse_module_state->__pyx_n_s_sum); + Py_VISIT(traverse_module_state->__pyx_n_s_sys); + Py_VISIT(traverse_module_state->__pyx_n_s_test); + Py_VISIT(traverse_module_state->__pyx_n_s_torch); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_VISIT(traverse_module_state->__pyx_n_s_unpack); + Py_VISIT(traverse_module_state->__pyx_n_s_update); + Py_VISIT(traverse_module_state->__pyx_n_s_use_setstate); + Py_VISIT(traverse_module_state->__pyx_n_s_version_info); + Py_VISIT(traverse_module_state->__pyx_n_s_zeros); + Py_VISIT(traverse_module_state->__pyx_int_0); + Py_VISIT(traverse_module_state->__pyx_int_1); + Py_VISIT(traverse_module_state->__pyx_int_2); + Py_VISIT(traverse_module_state->__pyx_int_3); + Py_VISIT(traverse_module_state->__pyx_int_48422178); + Py_VISIT(traverse_module_state->__pyx_int_107161605); + Py_VISIT(traverse_module_state->__pyx_int_112105877); + Py_VISIT(traverse_module_state->__pyx_int_136983863); + Py_VISIT(traverse_module_state->__pyx_int_147225413); + Py_VISIT(traverse_module_state->__pyx_int_184977713); + Py_VISIT(traverse_module_state->__pyx_int_neg_1); + Py_VISIT(traverse_module_state->__pyx_slice__5); + Py_VISIT(traverse_module_state->__pyx_tuple__4); + Py_VISIT(traverse_module_state->__pyx_tuple__8); + Py_VISIT(traverse_module_state->__pyx_tuple__9); + Py_VISIT(traverse_module_state->__pyx_slice__11); + Py_VISIT(traverse_module_state->__pyx_tuple__10); + Py_VISIT(traverse_module_state->__pyx_tuple__12); + Py_VISIT(traverse_module_state->__pyx_tuple__13); + Py_VISIT(traverse_module_state->__pyx_tuple__14); + Py_VISIT(traverse_module_state->__pyx_tuple__15); + Py_VISIT(traverse_module_state->__pyx_tuple__16); + Py_VISIT(traverse_module_state->__pyx_tuple__17); + Py_VISIT(traverse_module_state->__pyx_tuple__18); + Py_VISIT(traverse_module_state->__pyx_tuple__19); + Py_VISIT(traverse_module_state->__pyx_tuple__20); + Py_VISIT(traverse_module_state->__pyx_tuple__21); + Py_VISIT(traverse_module_state->__pyx_tuple__22); + Py_VISIT(traverse_module_state->__pyx_tuple__23); + Py_VISIT(traverse_module_state->__pyx_tuple__24); + Py_VISIT(traverse_module_state->__pyx_tuple__26); + Py_VISIT(traverse_module_state->__pyx_tuple__28); + Py_VISIT(traverse_module_state->__pyx_tuple__30); + Py_VISIT(traverse_module_state->__pyx_tuple__32); + Py_VISIT(traverse_module_state->__pyx_codeobj__25); + Py_VISIT(traverse_module_state->__pyx_codeobj__27); + Py_VISIT(traverse_module_state->__pyx_codeobj__29); + Py_VISIT(traverse_module_state->__pyx_codeobj__31); + Py_VISIT(traverse_module_state->__pyx_codeobj__33); + Py_VISIT(traverse_module_state->__pyx_codeobj__34); + return 0; +} +#endif +/* #### Code section: module_state_defines ### */ +#define __pyx_d __pyx_mstate_global->__pyx_d +#define __pyx_b __pyx_mstate_global->__pyx_b +#define __pyx_cython_runtime __pyx_mstate_global->__pyx_cython_runtime +#define __pyx_empty_tuple __pyx_mstate_global->__pyx_empty_tuple +#define __pyx_empty_bytes __pyx_mstate_global->__pyx_empty_bytes +#define __pyx_empty_unicode __pyx_mstate_global->__pyx_empty_unicode +#ifdef __Pyx_CyFunction_USED +#define __pyx_CyFunctionType __pyx_mstate_global->__pyx_CyFunctionType +#endif +#ifdef __Pyx_FusedFunction_USED +#define __pyx_FusedFunctionType __pyx_mstate_global->__pyx_FusedFunctionType +#endif +#ifdef __Pyx_Generator_USED +#define __pyx_GeneratorType __pyx_mstate_global->__pyx_GeneratorType +#endif +#ifdef __Pyx_IterableCoroutine_USED +#define __pyx_IterableCoroutineType __pyx_mstate_global->__pyx_IterableCoroutineType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineAwaitType __pyx_mstate_global->__pyx_CoroutineAwaitType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineType __pyx_mstate_global->__pyx_CoroutineType +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_7cpython_4type_type __pyx_mstate_global->__pyx_ptype_7cpython_4type_type +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_5numpy_dtype __pyx_mstate_global->__pyx_ptype_5numpy_dtype +#define __pyx_ptype_5numpy_flatiter __pyx_mstate_global->__pyx_ptype_5numpy_flatiter +#define __pyx_ptype_5numpy_broadcast __pyx_mstate_global->__pyx_ptype_5numpy_broadcast +#define __pyx_ptype_5numpy_ndarray __pyx_mstate_global->__pyx_ptype_5numpy_ndarray +#define __pyx_ptype_5numpy_generic __pyx_mstate_global->__pyx_ptype_5numpy_generic +#define __pyx_ptype_5numpy_number __pyx_mstate_global->__pyx_ptype_5numpy_number +#define __pyx_ptype_5numpy_integer __pyx_mstate_global->__pyx_ptype_5numpy_integer +#define __pyx_ptype_5numpy_signedinteger __pyx_mstate_global->__pyx_ptype_5numpy_signedinteger +#define __pyx_ptype_5numpy_unsignedinteger __pyx_mstate_global->__pyx_ptype_5numpy_unsignedinteger +#define __pyx_ptype_5numpy_inexact __pyx_mstate_global->__pyx_ptype_5numpy_inexact +#define __pyx_ptype_5numpy_floating __pyx_mstate_global->__pyx_ptype_5numpy_floating +#define __pyx_ptype_5numpy_complexfloating __pyx_mstate_global->__pyx_ptype_5numpy_complexfloating +#define __pyx_ptype_5numpy_flexible __pyx_mstate_global->__pyx_ptype_5numpy_flexible +#define __pyx_ptype_5numpy_character __pyx_mstate_global->__pyx_ptype_5numpy_character +#define __pyx_ptype_5numpy_ufunc __pyx_mstate_global->__pyx_ptype_5numpy_ufunc +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#define __pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher __pyx_mstate_global->__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher +#define __pyx_type___pyx_array __pyx_mstate_global->__pyx_type___pyx_array +#define __pyx_type___pyx_MemviewEnum __pyx_mstate_global->__pyx_type___pyx_MemviewEnum +#define __pyx_type___pyx_memoryview __pyx_mstate_global->__pyx_type___pyx_memoryview +#define __pyx_type___pyx_memoryviewslice __pyx_mstate_global->__pyx_type___pyx_memoryviewslice +#endif +#define __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher __pyx_mstate_global->__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher +#define __pyx_array_type __pyx_mstate_global->__pyx_array_type +#define __pyx_MemviewEnum_type __pyx_mstate_global->__pyx_MemviewEnum_type +#define __pyx_memoryview_type __pyx_mstate_global->__pyx_memoryview_type +#define __pyx_memoryviewslice_type __pyx_mstate_global->__pyx_memoryviewslice_type +#define __pyx_kp_u_ __pyx_mstate_global->__pyx_kp_u_ +#define __pyx_n_s_ASCII __pyx_mstate_global->__pyx_n_s_ASCII +#define __pyx_kp_s_All_dimensions_preceding_dimensi __pyx_mstate_global->__pyx_kp_s_All_dimensions_preceding_dimensi +#define __pyx_n_s_AssertionError __pyx_mstate_global->__pyx_n_s_AssertionError +#define __pyx_kp_s_Buffer_view_does_not_expose_stri __pyx_mstate_global->__pyx_kp_s_Buffer_view_does_not_expose_stri +#define __pyx_kp_s_Can_only_create_a_buffer_that_is __pyx_mstate_global->__pyx_kp_s_Can_only_create_a_buffer_that_is +#define __pyx_kp_s_Cannot_assign_to_read_only_memor __pyx_mstate_global->__pyx_kp_s_Cannot_assign_to_read_only_memor +#define __pyx_kp_s_Cannot_create_writable_memory_vi __pyx_mstate_global->__pyx_kp_s_Cannot_create_writable_memory_vi +#define __pyx_kp_u_Cannot_index_with_type __pyx_mstate_global->__pyx_kp_u_Cannot_index_with_type +#define __pyx_kp_s_Cannot_transpose_memoryview_with __pyx_mstate_global->__pyx_kp_s_Cannot_transpose_memoryview_with +#define __pyx_n_s_DTYPE __pyx_mstate_global->__pyx_n_s_DTYPE +#define __pyx_n_s_DatasetSearcher __pyx_mstate_global->__pyx_n_s_DatasetSearcher +#define __pyx_n_s_DatasetSearcher___reduce_cython __pyx_mstate_global->__pyx_n_s_DatasetSearcher___reduce_cython +#define __pyx_n_s_DatasetSearcher___setstate_cytho __pyx_mstate_global->__pyx_n_s_DatasetSearcher___setstate_cytho +#define __pyx_kp_s_Dimension_d_is_not_direct __pyx_mstate_global->__pyx_kp_s_Dimension_d_is_not_direct +#define __pyx_n_s_Ellipsis __pyx_mstate_global->__pyx_n_s_Ellipsis +#define __pyx_kp_s_Empty_shape_tuple_for_cython_arr __pyx_mstate_global->__pyx_kp_s_Empty_shape_tuple_for_cython_arr +#define __pyx_n_s_ImportError __pyx_mstate_global->__pyx_n_s_ImportError +#define __pyx_kp_s_Incompatible_checksums_0x_x_vs_0 __pyx_mstate_global->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0 +#define __pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2 __pyx_mstate_global->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2 +#define __pyx_n_s_IndexError __pyx_mstate_global->__pyx_n_s_IndexError +#define __pyx_kp_s_Index_out_of_bounds_axis_d __pyx_mstate_global->__pyx_kp_s_Index_out_of_bounds_axis_d +#define __pyx_kp_s_Indirect_dimensions_not_supporte __pyx_mstate_global->__pyx_kp_s_Indirect_dimensions_not_supporte +#define __pyx_kp_u_Invalid_break_mode __pyx_mstate_global->__pyx_kp_u_Invalid_break_mode +#define __pyx_kp_u_Invalid_mode_expected_c_or_fortr __pyx_mstate_global->__pyx_kp_u_Invalid_mode_expected_c_or_fortr +#define __pyx_kp_u_Invalid_shape_in_axis __pyx_mstate_global->__pyx_kp_u_Invalid_shape_in_axis +#define __pyx_n_s_MemoryError __pyx_mstate_global->__pyx_n_s_MemoryError +#define __pyx_kp_s_MemoryView_of_r_at_0x_x __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_at_0x_x +#define __pyx_kp_s_MemoryView_of_r_object __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_object +#define __pyx_n_b_O __pyx_mstate_global->__pyx_n_b_O +#define __pyx_kp_u_Out_of_bounds_on_buffer_access_a __pyx_mstate_global->__pyx_kp_u_Out_of_bounds_on_buffer_access_a +#define __pyx_n_s_PickleError __pyx_mstate_global->__pyx_n_s_PickleError +#define __pyx_n_s_Sequence __pyx_mstate_global->__pyx_n_s_Sequence +#define __pyx_kp_s_Step_may_not_be_zero_axis_d __pyx_mstate_global->__pyx_kp_s_Step_may_not_be_zero_axis_d +#define __pyx_n_s_TypeError __pyx_mstate_global->__pyx_n_s_TypeError +#define __pyx_kp_s_Unable_to_convert_item_to_object __pyx_mstate_global->__pyx_kp_s_Unable_to_convert_item_to_object +#define __pyx_n_s_ValueError __pyx_mstate_global->__pyx_n_s_ValueError +#define __pyx_n_s_View_MemoryView __pyx_mstate_global->__pyx_n_s_View_MemoryView +#define __pyx_kp_u__2 __pyx_mstate_global->__pyx_kp_u__2 +#define __pyx_n_s__3 __pyx_mstate_global->__pyx_n_s__3 +#define __pyx_n_s__35 __pyx_mstate_global->__pyx_n_s__35 +#define __pyx_kp_u__6 __pyx_mstate_global->__pyx_kp_u__6 +#define __pyx_kp_u__7 __pyx_mstate_global->__pyx_kp_u__7 +#define __pyx_n_s_abc __pyx_mstate_global->__pyx_n_s_abc +#define __pyx_n_s_allocate_buffer __pyx_mstate_global->__pyx_n_s_allocate_buffer +#define __pyx_kp_u_and __pyx_mstate_global->__pyx_kp_u_and +#define __pyx_n_s_asyncio_coroutines __pyx_mstate_global->__pyx_n_s_asyncio_coroutines +#define __pyx_n_s_axis __pyx_mstate_global->__pyx_n_s_axis +#define __pyx_n_s_base __pyx_mstate_global->__pyx_n_s_base +#define __pyx_n_s_block_size __pyx_mstate_global->__pyx_n_s_block_size +#define __pyx_n_s_break_mode __pyx_mstate_global->__pyx_n_s_break_mode +#define __pyx_n_s_c __pyx_mstate_global->__pyx_n_s_c +#define __pyx_n_u_c __pyx_mstate_global->__pyx_n_u_c +#define __pyx_n_s_chain __pyx_mstate_global->__pyx_n_s_chain +#define __pyx_n_s_class __pyx_mstate_global->__pyx_n_s_class +#define __pyx_n_s_class_getitem __pyx_mstate_global->__pyx_n_s_class_getitem +#define __pyx_n_s_cline_in_traceback __pyx_mstate_global->__pyx_n_s_cline_in_traceback +#define __pyx_n_s_collections __pyx_mstate_global->__pyx_n_s_collections +#define __pyx_kp_s_collections_abc __pyx_mstate_global->__pyx_kp_s_collections_abc +#define __pyx_n_u_complete __pyx_mstate_global->__pyx_n_u_complete +#define __pyx_n_u_complete_doc __pyx_mstate_global->__pyx_n_u_complete_doc +#define __pyx_kp_s_contiguous_and_direct __pyx_mstate_global->__pyx_kp_s_contiguous_and_direct +#define __pyx_kp_s_contiguous_and_indirect __pyx_mstate_global->__pyx_kp_s_contiguous_and_indirect +#define __pyx_n_s_count __pyx_mstate_global->__pyx_n_s_count +#define __pyx_n_s_cumsum __pyx_mstate_global->__pyx_n_s_cumsum +#define __pyx_n_s_dict __pyx_mstate_global->__pyx_n_s_dict +#define __pyx_n_s_dict_2 __pyx_mstate_global->__pyx_n_s_dict_2 +#define __pyx_kp_u_disable __pyx_mstate_global->__pyx_kp_u_disable +#define __pyx_n_s_document_sep_len __pyx_mstate_global->__pyx_n_s_document_sep_len +#define __pyx_n_s_dtype __pyx_mstate_global->__pyx_n_s_dtype +#define __pyx_n_s_dtype_is_object __pyx_mstate_global->__pyx_n_s_dtype_is_object +#define __pyx_kp_u_enable __pyx_mstate_global->__pyx_kp_u_enable +#define __pyx_n_s_encode __pyx_mstate_global->__pyx_n_s_encode +#define __pyx_n_s_enumerate __pyx_mstate_global->__pyx_n_s_enumerate +#define __pyx_n_u_eos __pyx_mstate_global->__pyx_n_u_eos +#define __pyx_n_s_error __pyx_mstate_global->__pyx_n_s_error +#define __pyx_kp_s_fairseq_data_token_block_utils_f __pyx_mstate_global->__pyx_kp_s_fairseq_data_token_block_utils_f +#define __pyx_n_s_fairseq_data_token_block_utils_f_2 __pyx_mstate_global->__pyx_n_s_fairseq_data_token_block_utils_f_2 +#define __pyx_n_s_flags __pyx_mstate_global->__pyx_n_s_flags +#define __pyx_n_s_format __pyx_mstate_global->__pyx_n_s_format +#define __pyx_n_s_fortran __pyx_mstate_global->__pyx_n_s_fortran +#define __pyx_n_u_fortran __pyx_mstate_global->__pyx_n_u_fortran +#define __pyx_n_s_from_iterable __pyx_mstate_global->__pyx_n_s_from_iterable +#define __pyx_n_s_fromiter __pyx_mstate_global->__pyx_n_s_fromiter +#define __pyx_kp_u_gc __pyx_mstate_global->__pyx_kp_u_gc +#define __pyx_n_s_get_block_to_dataset_index_fast __pyx_mstate_global->__pyx_n_s_get_block_to_dataset_index_fast +#define __pyx_n_s_get_slice_indices_fast __pyx_mstate_global->__pyx_n_s_get_slice_indices_fast +#define __pyx_n_s_getstate __pyx_mstate_global->__pyx_n_s_getstate +#define __pyx_kp_u_got __pyx_mstate_global->__pyx_kp_u_got +#define __pyx_kp_u_got_differing_extents_in_dimensi __pyx_mstate_global->__pyx_kp_u_got_differing_extents_in_dimensi +#define __pyx_n_s_id __pyx_mstate_global->__pyx_n_s_id +#define __pyx_n_s_import __pyx_mstate_global->__pyx_n_s_import +#define __pyx_n_s_index __pyx_mstate_global->__pyx_n_s_index +#define __pyx_n_s_initializing __pyx_mstate_global->__pyx_n_s_initializing +#define __pyx_n_s_int64 __pyx_mstate_global->__pyx_n_s_int64 +#define __pyx_n_s_is_coroutine __pyx_mstate_global->__pyx_n_s_is_coroutine +#define __pyx_kp_u_isenabled __pyx_mstate_global->__pyx_kp_u_isenabled +#define __pyx_n_s_itemsize __pyx_mstate_global->__pyx_n_s_itemsize +#define __pyx_kp_s_itemsize_0_for_cython_array __pyx_mstate_global->__pyx_kp_s_itemsize_0_for_cython_array +#define __pyx_n_s_itertools __pyx_mstate_global->__pyx_n_s_itertools +#define __pyx_n_s_main __pyx_mstate_global->__pyx_n_s_main +#define __pyx_n_s_memview __pyx_mstate_global->__pyx_n_s_memview +#define __pyx_n_s_mode __pyx_mstate_global->__pyx_n_s_mode +#define __pyx_n_s_name __pyx_mstate_global->__pyx_n_s_name +#define __pyx_n_s_name_2 __pyx_mstate_global->__pyx_n_s_name_2 +#define __pyx_n_s_ndim __pyx_mstate_global->__pyx_n_s_ndim +#define __pyx_n_s_new __pyx_mstate_global->__pyx_n_s_new +#define __pyx_kp_s_no_default___reduce___due_to_non __pyx_mstate_global->__pyx_kp_s_no_default___reduce___due_to_non +#define __pyx_n_u_none __pyx_mstate_global->__pyx_n_u_none +#define __pyx_n_s_np __pyx_mstate_global->__pyx_n_s_np +#define __pyx_n_s_numpy __pyx_mstate_global->__pyx_n_s_numpy +#define __pyx_kp_u_numpy_core_multiarray_failed_to __pyx_mstate_global->__pyx_kp_u_numpy_core_multiarray_failed_to +#define __pyx_kp_u_numpy_core_umath_failed_to_impor __pyx_mstate_global->__pyx_kp_u_numpy_core_umath_failed_to_impor +#define __pyx_n_s_obj __pyx_mstate_global->__pyx_n_s_obj +#define __pyx_n_s_pack __pyx_mstate_global->__pyx_n_s_pack +#define __pyx_n_s_pickle __pyx_mstate_global->__pyx_n_s_pickle +#define __pyx_n_s_pyx_PickleError __pyx_mstate_global->__pyx_n_s_pyx_PickleError +#define __pyx_n_s_pyx_checksum __pyx_mstate_global->__pyx_n_s_pyx_checksum +#define __pyx_n_s_pyx_result __pyx_mstate_global->__pyx_n_s_pyx_result +#define __pyx_n_s_pyx_state __pyx_mstate_global->__pyx_n_s_pyx_state +#define __pyx_n_s_pyx_type __pyx_mstate_global->__pyx_n_s_pyx_type +#define __pyx_n_s_pyx_unpickle_DatasetSearcher __pyx_mstate_global->__pyx_n_s_pyx_unpickle_DatasetSearcher +#define __pyx_n_s_pyx_unpickle_Enum __pyx_mstate_global->__pyx_n_s_pyx_unpickle_Enum +#define __pyx_n_s_pyx_vtable __pyx_mstate_global->__pyx_n_s_pyx_vtable +#define __pyx_n_s_range __pyx_mstate_global->__pyx_n_s_range +#define __pyx_n_s_reduce __pyx_mstate_global->__pyx_n_s_reduce +#define __pyx_n_s_reduce_cython __pyx_mstate_global->__pyx_n_s_reduce_cython +#define __pyx_n_s_reduce_ex __pyx_mstate_global->__pyx_n_s_reduce_ex +#define __pyx_n_s_register __pyx_mstate_global->__pyx_n_s_register +#define __pyx_n_s_reshape __pyx_mstate_global->__pyx_n_s_reshape +#define __pyx_n_s_self __pyx_mstate_global->__pyx_n_s_self +#define __pyx_n_s_setstate __pyx_mstate_global->__pyx_n_s_setstate +#define __pyx_n_s_setstate_cython __pyx_mstate_global->__pyx_n_s_setstate_cython +#define __pyx_n_s_shape __pyx_mstate_global->__pyx_n_s_shape +#define __pyx_n_s_size __pyx_mstate_global->__pyx_n_s_size +#define __pyx_n_s_sizes __pyx_mstate_global->__pyx_n_s_sizes +#define __pyx_n_s_slice_indices __pyx_mstate_global->__pyx_n_s_slice_indices +#define __pyx_n_s_spec __pyx_mstate_global->__pyx_n_s_spec +#define __pyx_n_s_start __pyx_mstate_global->__pyx_n_s_start +#define __pyx_n_s_state __pyx_mstate_global->__pyx_n_s_state +#define __pyx_n_s_step __pyx_mstate_global->__pyx_n_s_step +#define __pyx_n_s_stop __pyx_mstate_global->__pyx_n_s_stop +#define __pyx_kp_s_strided_and_direct __pyx_mstate_global->__pyx_kp_s_strided_and_direct +#define __pyx_kp_s_strided_and_direct_or_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_direct_or_indirect +#define __pyx_kp_s_strided_and_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_indirect +#define __pyx_kp_s_stringsource __pyx_mstate_global->__pyx_kp_s_stringsource +#define __pyx_n_s_struct __pyx_mstate_global->__pyx_n_s_struct +#define __pyx_n_s_sum __pyx_mstate_global->__pyx_n_s_sum +#define __pyx_n_s_sys __pyx_mstate_global->__pyx_n_s_sys +#define __pyx_n_s_test __pyx_mstate_global->__pyx_n_s_test +#define __pyx_n_s_torch __pyx_mstate_global->__pyx_n_s_torch +#define __pyx_kp_s_unable_to_allocate_array_data __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_array_data +#define __pyx_kp_s_unable_to_allocate_shape_and_str __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_shape_and_str +#define __pyx_n_s_unpack __pyx_mstate_global->__pyx_n_s_unpack +#define __pyx_n_s_update __pyx_mstate_global->__pyx_n_s_update +#define __pyx_n_s_use_setstate __pyx_mstate_global->__pyx_n_s_use_setstate +#define __pyx_n_s_version_info __pyx_mstate_global->__pyx_n_s_version_info +#define __pyx_n_s_zeros __pyx_mstate_global->__pyx_n_s_zeros +#define __pyx_int_0 __pyx_mstate_global->__pyx_int_0 +#define __pyx_int_1 __pyx_mstate_global->__pyx_int_1 +#define __pyx_int_2 __pyx_mstate_global->__pyx_int_2 +#define __pyx_int_3 __pyx_mstate_global->__pyx_int_3 +#define __pyx_int_48422178 __pyx_mstate_global->__pyx_int_48422178 +#define __pyx_int_107161605 __pyx_mstate_global->__pyx_int_107161605 +#define __pyx_int_112105877 __pyx_mstate_global->__pyx_int_112105877 +#define __pyx_int_136983863 __pyx_mstate_global->__pyx_int_136983863 +#define __pyx_int_147225413 __pyx_mstate_global->__pyx_int_147225413 +#define __pyx_int_184977713 __pyx_mstate_global->__pyx_int_184977713 +#define __pyx_int_neg_1 __pyx_mstate_global->__pyx_int_neg_1 +#define __pyx_slice__5 __pyx_mstate_global->__pyx_slice__5 +#define __pyx_tuple__4 __pyx_mstate_global->__pyx_tuple__4 +#define __pyx_tuple__8 __pyx_mstate_global->__pyx_tuple__8 +#define __pyx_tuple__9 __pyx_mstate_global->__pyx_tuple__9 +#define __pyx_slice__11 __pyx_mstate_global->__pyx_slice__11 +#define __pyx_tuple__10 __pyx_mstate_global->__pyx_tuple__10 +#define __pyx_tuple__12 __pyx_mstate_global->__pyx_tuple__12 +#define __pyx_tuple__13 __pyx_mstate_global->__pyx_tuple__13 +#define __pyx_tuple__14 __pyx_mstate_global->__pyx_tuple__14 +#define __pyx_tuple__15 __pyx_mstate_global->__pyx_tuple__15 +#define __pyx_tuple__16 __pyx_mstate_global->__pyx_tuple__16 +#define __pyx_tuple__17 __pyx_mstate_global->__pyx_tuple__17 +#define __pyx_tuple__18 __pyx_mstate_global->__pyx_tuple__18 +#define __pyx_tuple__19 __pyx_mstate_global->__pyx_tuple__19 +#define __pyx_tuple__20 __pyx_mstate_global->__pyx_tuple__20 +#define __pyx_tuple__21 __pyx_mstate_global->__pyx_tuple__21 +#define __pyx_tuple__22 __pyx_mstate_global->__pyx_tuple__22 +#define __pyx_tuple__23 __pyx_mstate_global->__pyx_tuple__23 +#define __pyx_tuple__24 __pyx_mstate_global->__pyx_tuple__24 +#define __pyx_tuple__26 __pyx_mstate_global->__pyx_tuple__26 +#define __pyx_tuple__28 __pyx_mstate_global->__pyx_tuple__28 +#define __pyx_tuple__30 __pyx_mstate_global->__pyx_tuple__30 +#define __pyx_tuple__32 __pyx_mstate_global->__pyx_tuple__32 +#define __pyx_codeobj__25 __pyx_mstate_global->__pyx_codeobj__25 +#define __pyx_codeobj__27 __pyx_mstate_global->__pyx_codeobj__27 +#define __pyx_codeobj__29 __pyx_mstate_global->__pyx_codeobj__29 +#define __pyx_codeobj__31 __pyx_mstate_global->__pyx_codeobj__31 +#define __pyx_codeobj__33 __pyx_mstate_global->__pyx_codeobj__33 +#define __pyx_codeobj__34 __pyx_mstate_global->__pyx_codeobj__34 +/* #### Code section: module_code ### */ + +/* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + +/* Python wrapper */ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_shape = 0; + Py_ssize_t __pyx_v_itemsize; + PyObject *__pyx_v_format = 0; + PyObject *__pyx_v_mode = 0; + int __pyx_v_allocate_buffer; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_shape,&__pyx_n_s_itemsize,&__pyx_n_s_format,&__pyx_n_s_mode,&__pyx_n_s_allocate_buffer,0}; + values[3] = __Pyx_Arg_NewRef_VARARGS(((PyObject *)__pyx_n_s_c)); + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_shape)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_itemsize)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 1); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_format)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 2); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_mode); + if (value) { values[3] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_allocate_buffer); + if (value) { values[4] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 131, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_shape = ((PyObject*)values[0]); + __pyx_v_itemsize = __Pyx_PyIndex_AsSsize_t(values[1]); if (unlikely((__pyx_v_itemsize == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_v_format = values[2]; + __pyx_v_mode = values[3]; + if (values[4]) { + __pyx_v_allocate_buffer = __Pyx_PyObject_IsTrue(values[4]); if (unlikely((__pyx_v_allocate_buffer == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 132, __pyx_L3_error) + } else { + + /* "View.MemoryView":132 + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, + * mode="c", bint allocate_buffer=True): # <<<<<<<<<<<<<< + * + * cdef int idx + */ + __pyx_v_allocate_buffer = ((int)1); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, __pyx_nargs); __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_shape), (&PyTuple_Type), 1, "shape", 1))) __PYX_ERR(1, 131, __pyx_L1_error) + if (unlikely(((PyObject *)__pyx_v_format) == Py_None)) { + PyErr_Format(PyExc_TypeError, "Argument '%.200s' must not be None", "format"); __PYX_ERR(1, 131, __pyx_L1_error) + } + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v_shape, __pyx_v_itemsize, __pyx_v_format, __pyx_v_mode, __pyx_v_allocate_buffer); + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = -1; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer) { + int __pyx_v_idx; + Py_ssize_t __pyx_v_dim; + char __pyx_v_order; + int __pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + int __pyx_t_7; + char *__pyx_t_8; + Py_ssize_t __pyx_t_9; + Py_UCS4 __pyx_t_10; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 0); + __Pyx_INCREF(__pyx_v_format); + + /* "View.MemoryView":137 + * cdef Py_ssize_t dim + * + * self.ndim = len(shape) # <<<<<<<<<<<<<< + * self.itemsize = itemsize + * + */ + if (unlikely(__pyx_v_shape == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 137, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_PyTuple_GET_SIZE(__pyx_v_shape); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(1, 137, __pyx_L1_error) + __pyx_v_self->ndim = ((int)__pyx_t_1); + + /* "View.MemoryView":138 + * + * self.ndim = len(shape) + * self.itemsize = itemsize # <<<<<<<<<<<<<< + * + * if not self.ndim: + */ + __pyx_v_self->itemsize = __pyx_v_itemsize; + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + __pyx_t_2 = (!(__pyx_v_self->ndim != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":141 + * + * if not self.ndim: + * raise ValueError, "Empty shape tuple for cython.array" # <<<<<<<<<<<<<< + * + * if itemsize <= 0: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Empty_shape_tuple_for_cython_arr, 0, 0); + __PYX_ERR(1, 141, __pyx_L1_error) + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + } + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + __pyx_t_2 = (__pyx_v_itemsize <= 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":144 + * + * if itemsize <= 0: + * raise ValueError, "itemsize <= 0 for cython.array" # <<<<<<<<<<<<<< + * + * if not isinstance(format, bytes): + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_itemsize_0_for_cython_array, 0, 0); + __PYX_ERR(1, 144, __pyx_L1_error) + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + } + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + __pyx_t_2 = PyBytes_Check(__pyx_v_format); + __pyx_t_3 = (!__pyx_t_2); + if (__pyx_t_3) { + + /* "View.MemoryView":147 + * + * if not isinstance(format, bytes): + * format = format.encode('ASCII') # <<<<<<<<<<<<<< + * self._format = format # keep a reference to the byte string + * self.format = self._format + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_format, __pyx_n_s_encode); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = NULL; + __pyx_t_7 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_6)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_6); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_7 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_6, __pyx_n_s_ASCII}; + __pyx_t_4 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_7, 1+__pyx_t_7); + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __Pyx_DECREF_SET(__pyx_v_format, __pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + } + + /* "View.MemoryView":148 + * if not isinstance(format, bytes): + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string # <<<<<<<<<<<<<< + * self.format = self._format + * + */ + if (!(likely(PyBytes_CheckExact(__pyx_v_format))||((__pyx_v_format) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_v_format))) __PYX_ERR(1, 148, __pyx_L1_error) + __pyx_t_4 = __pyx_v_format; + __Pyx_INCREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __Pyx_GOTREF(__pyx_v_self->_format); + __Pyx_DECREF(__pyx_v_self->_format); + __pyx_v_self->_format = ((PyObject*)__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":149 + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + * self.format = self._format # <<<<<<<<<<<<<< + * + * + */ + if (unlikely(__pyx_v_self->_format == Py_None)) { + PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); + __PYX_ERR(1, 149, __pyx_L1_error) + } + __pyx_t_8 = __Pyx_PyBytes_AsWritableString(__pyx_v_self->_format); if (unlikely((!__pyx_t_8) && PyErr_Occurred())) __PYX_ERR(1, 149, __pyx_L1_error) + __pyx_v_self->format = __pyx_t_8; + + /* "View.MemoryView":152 + * + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) # <<<<<<<<<<<<<< + * self._strides = self._shape + self.ndim + * + */ + __pyx_v_self->_shape = ((Py_ssize_t *)PyObject_Malloc((((sizeof(Py_ssize_t)) * __pyx_v_self->ndim) * 2))); + + /* "View.MemoryView":153 + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) + * self._strides = self._shape + self.ndim # <<<<<<<<<<<<<< + * + * if not self._shape: + */ + __pyx_v_self->_strides = (__pyx_v_self->_shape + __pyx_v_self->ndim); + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + __pyx_t_3 = (!(__pyx_v_self->_shape != 0)); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":156 + * + * if not self._shape: + * raise MemoryError, "unable to allocate shape and strides." # <<<<<<<<<<<<<< + * + * + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_shape_and_str, 0, 0); + __PYX_ERR(1, 156, __pyx_L1_error) + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + } + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + __pyx_t_7 = 0; + __pyx_t_4 = __pyx_v_shape; __Pyx_INCREF(__pyx_t_4); + __pyx_t_1 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_4); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #endif + if (__pyx_t_1 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_4, __pyx_t_1); __Pyx_INCREF(__pyx_t_5); __pyx_t_1++; if (unlikely((0 < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_4, __pyx_t_1); __pyx_t_1++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_t_5); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_9; + __pyx_v_idx = __pyx_t_7; + __pyx_t_7 = (__pyx_t_7 + 1); + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + __pyx_t_3 = (__pyx_v_dim <= 0); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":161 + * for idx, dim in enumerate(shape): + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." # <<<<<<<<<<<<<< + * self._shape[idx] = dim + * + */ + __pyx_t_5 = PyTuple_New(5); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_9 = 0; + __pyx_t_10 = 127; + __Pyx_INCREF(__pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_9 += 22; + __Pyx_GIVEREF(__pyx_kp_u_Invalid_shape_in_axis); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_6 = __Pyx_PyUnicode_From_int(__pyx_v_idx, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_9 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u_); + __pyx_t_9 += 2; + __Pyx_GIVEREF(__pyx_kp_u_); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u_); + __pyx_t_6 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_9 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 3, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u__2); + __pyx_t_9 += 1; + __Pyx_GIVEREF(__pyx_kp_u__2); + PyTuple_SET_ITEM(__pyx_t_5, 4, __pyx_kp_u__2); + __pyx_t_6 = __Pyx_PyUnicode_Join(__pyx_t_5, 5, __pyx_t_9, __pyx_t_10); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 161, __pyx_L1_error) + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + } + + /* "View.MemoryView":162 + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim # <<<<<<<<<<<<<< + * + * cdef char order + */ + (__pyx_v_self->_shape[__pyx_v_idx]) = __pyx_v_dim; + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + } + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_c, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 165, __pyx_L1_error) + if (__pyx_t_3) { + + /* "View.MemoryView":166 + * cdef char order + * if mode == 'c': + * order = b'C' # <<<<<<<<<<<<<< + * self.mode = u'c' + * elif mode == 'fortran': + */ + __pyx_v_order = 'C'; + + /* "View.MemoryView":167 + * if mode == 'c': + * order = b'C' + * self.mode = u'c' # <<<<<<<<<<<<<< + * elif mode == 'fortran': + * order = b'F' + */ + __Pyx_INCREF(__pyx_n_u_c); + __Pyx_GIVEREF(__pyx_n_u_c); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_c; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_fortran, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 168, __pyx_L1_error) + if (likely(__pyx_t_3)) { + + /* "View.MemoryView":169 + * self.mode = u'c' + * elif mode == 'fortran': + * order = b'F' # <<<<<<<<<<<<<< + * self.mode = u'fortran' + * else: + */ + __pyx_v_order = 'F'; + + /* "View.MemoryView":170 + * elif mode == 'fortran': + * order = b'F' + * self.mode = u'fortran' # <<<<<<<<<<<<<< + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + */ + __Pyx_INCREF(__pyx_n_u_fortran); + __Pyx_GIVEREF(__pyx_n_u_fortran); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_fortran; + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":172 + * self.mode = u'fortran' + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" # <<<<<<<<<<<<<< + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_FormatSimple(__pyx_v_mode, __pyx_empty_unicode); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_6 = __Pyx_PyUnicode_Concat(__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_t_4); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 172, __pyx_L1_error) + } + __pyx_L11:; + + /* "View.MemoryView":174 + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) # <<<<<<<<<<<<<< + * + * self.free_data = allocate_buffer + */ + __pyx_v_self->len = __pyx_fill_contig_strides_array(__pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_itemsize, __pyx_v_self->ndim, __pyx_v_order); + + /* "View.MemoryView":176 + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + * + * self.free_data = allocate_buffer # <<<<<<<<<<<<<< + * self.dtype_is_object = format == b'O' + * + */ + __pyx_v_self->free_data = __pyx_v_allocate_buffer; + + /* "View.MemoryView":177 + * + * self.free_data = allocate_buffer + * self.dtype_is_object = format == b'O' # <<<<<<<<<<<<<< + * + * if allocate_buffer: + */ + __pyx_t_6 = PyObject_RichCompare(__pyx_v_format, __pyx_n_b_O, Py_EQ); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 177, __pyx_L1_error) + __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 177, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_v_self->dtype_is_object = __pyx_t_3; + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + if (__pyx_v_allocate_buffer) { + + /* "View.MemoryView":180 + * + * if allocate_buffer: + * _allocate_buffer(self) # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_t_7 = __pyx_array_allocate_buffer(__pyx_v_self); if (unlikely(__pyx_t_7 == ((int)-1))) __PYX_ERR(1, 180, __pyx_L1_error) + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + } + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_format); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(((struct __pyx_array_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_v_bufmode; + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + char *__pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + Py_ssize_t *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":184 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 # <<<<<<<<<<<<<< + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + */ + __pyx_v_bufmode = -1; + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_t_1 = ((__pyx_v_flags & ((PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS) | PyBUF_ANY_CONTIGUOUS)) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_c, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 186, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":187 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_v_bufmode = (PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + goto __pyx_L4; + } + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_fortran, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 188, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":189 + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + */ + __pyx_v_bufmode = (PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + } + __pyx_L4:; + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + __pyx_t_1 = (!((__pyx_v_flags & __pyx_v_bufmode) != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":191 + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." # <<<<<<<<<<<<<< + * info.buf = self.data + * info.len = self.len + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Can_only_create_a_buffer_that_is, 0, 0); + __PYX_ERR(1, 191, __pyx_L1_error) + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + } + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + } + + /* "View.MemoryView":192 + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data # <<<<<<<<<<<<<< + * info.len = self.len + * + */ + __pyx_t_2 = __pyx_v_self->data; + __pyx_v_info->buf = __pyx_t_2; + + /* "View.MemoryView":193 + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + * info.len = self.len # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + __pyx_t_3 = __pyx_v_self->len; + __pyx_v_info->len = __pyx_t_3; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":196 + * + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim # <<<<<<<<<<<<<< + * info.shape = self._shape + * info.strides = self._strides + */ + __pyx_t_4 = __pyx_v_self->ndim; + __pyx_v_info->ndim = __pyx_t_4; + + /* "View.MemoryView":197 + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim + * info.shape = self._shape # <<<<<<<<<<<<<< + * info.strides = self._strides + * else: + */ + __pyx_t_5 = __pyx_v_self->_shape; + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":198 + * info.ndim = self.ndim + * info.shape = self._shape + * info.strides = self._strides # <<<<<<<<<<<<<< + * else: + * info.ndim = 1 + */ + __pyx_t_5 = __pyx_v_self->_strides; + __pyx_v_info->strides = __pyx_t_5; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + goto __pyx_L6; + } + + /* "View.MemoryView":200 + * info.strides = self._strides + * else: + * info.ndim = 1 # <<<<<<<<<<<<<< + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL + */ + /*else*/ { + __pyx_v_info->ndim = 1; + + /* "View.MemoryView":201 + * else: + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL # <<<<<<<<<<<<<< + * info.strides = NULL + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + __pyx_t_5 = (&__pyx_v_self->len); + } else { + __pyx_t_5 = NULL; + } + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":202 + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL # <<<<<<<<<<<<<< + * + * info.suboffsets = NULL + */ + __pyx_v_info->strides = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":204 + * info.strides = NULL + * + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * info.itemsize = self.itemsize + * info.readonly = 0 + */ + __pyx_v_info->suboffsets = NULL; + + /* "View.MemoryView":205 + * + * info.suboffsets = NULL + * info.itemsize = self.itemsize # <<<<<<<<<<<<<< + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + */ + __pyx_t_3 = __pyx_v_self->itemsize; + __pyx_v_info->itemsize = __pyx_t_3; + + /* "View.MemoryView":206 + * info.suboffsets = NULL + * info.itemsize = self.itemsize + * info.readonly = 0 # <<<<<<<<<<<<<< + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self + */ + __pyx_v_info->readonly = 0; + + /* "View.MemoryView":207 + * info.itemsize = self.itemsize + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + __pyx_t_2 = __pyx_v_self->format; + } else { + __pyx_t_2 = NULL; + } + __pyx_v_info->format = __pyx_t_2; + + /* "View.MemoryView":208 + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self # <<<<<<<<<<<<<< + * + * def __dealloc__(array self): + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + +/* Python wrapper */ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + __pyx_t_1 = (__pyx_v_self->callback_free_data != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":212 + * def __dealloc__(array self): + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) # <<<<<<<<<<<<<< + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + */ + __pyx_v_self->callback_free_data(__pyx_v_self->data); + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + if (__pyx_v_self->free_data) { + } else { + __pyx_t_1 = __pyx_v_self->free_data; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_self->data != NULL); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":215 + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) # <<<<<<<<<<<<<< + * free(self.data) + * PyObject_Free(self._shape) + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_self->data, __pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_self->ndim, 0); + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + } + + /* "View.MemoryView":216 + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) # <<<<<<<<<<<<<< + * PyObject_Free(self._shape) + * + */ + free(__pyx_v_self->data); + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + } + __pyx_L3:; + + /* "View.MemoryView":217 + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + * PyObject_Free(self._shape) # <<<<<<<<<<<<<< + * + * @property + */ + PyObject_Free(__pyx_v_self->_shape); + + /* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + + /* function exit code */ +} + +/* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_5array_7memview___get__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":221 + * @property + * def memview(self): + * return self.get_memview() # <<<<<<<<<<<<<< + * + * @cname('get_memview') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_array *)__pyx_v_self->__pyx_vtab)->get_memview(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 221, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.memview.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_memview", 1); + + /* "View.MemoryView":225 + * @cname('get_memview') + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE # <<<<<<<<<<<<<< + * return memoryview(self, flags, self.dtype_is_object) + * + */ + __pyx_v_flags = ((PyBUF_ANY_CONTIGUOUS | PyBUF_FORMAT) | PyBUF_WRITABLE); + + /* "View.MemoryView":226 + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, ((PyObject *)__pyx_v_self))) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.array.get_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + + /* "View.MemoryView":229 + * + * def __len__(self): + * return self._shape[0] # <<<<<<<<<<<<<< + * + * def __getattr__(self, attr): + */ + __pyx_r = (__pyx_v_self->_shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr); /*proto*/ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getattr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_attr)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getattr__", 1); + + /* "View.MemoryView":232 + * + * def __getattr__(self, attr): + * return getattr(self.memview, attr) # <<<<<<<<<<<<<< + * + * def __getitem__(self, item): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_GetAttr(__pyx_t_1, __pyx_v_attr); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getattr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item); /*proto*/ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":235 + * + * def __getitem__(self, item): + * return self.memview[item] # <<<<<<<<<<<<<< + * + * def __setitem__(self, item, value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetItem(__pyx_t_1, __pyx_v_item); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + +/* Python wrapper */ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 1); + + /* "View.MemoryView":238 + * + * def __setitem__(self, item, value): + * self.memview[item] = value # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (unlikely((PyObject_SetItem(__pyx_t_1, __pyx_v_item, __pyx_v_value) < 0))) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_array___reduce_cython__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_array_2__setstate_cython__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_v_i; + PyObject **__pyx_v_p; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":254 + * cdef PyObject **p + * + * self.free_data = True # <<<<<<<<<<<<<< + * self.data = malloc(self.len) + * if not self.data: + */ + __pyx_v_self->free_data = 1; + + /* "View.MemoryView":255 + * + * self.free_data = True + * self.data = malloc(self.len) # <<<<<<<<<<<<<< + * if not self.data: + * raise MemoryError, "unable to allocate array data." + */ + __pyx_v_self->data = ((char *)malloc(__pyx_v_self->len)); + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + __pyx_t_1 = (!(__pyx_v_self->data != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":257 + * self.data = malloc(self.len) + * if not self.data: + * raise MemoryError, "unable to allocate array data." # <<<<<<<<<<<<<< + * + * if self.dtype_is_object: + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_array_data, 0, 0); + __PYX_ERR(1, 257, __pyx_L1_error) + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":260 + * + * if self.dtype_is_object: + * p = self.data # <<<<<<<<<<<<<< + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + */ + __pyx_v_p = ((PyObject **)__pyx_v_self->data); + + /* "View.MemoryView":261 + * if self.dtype_is_object: + * p = self.data + * for i in range(self.len // self.itemsize): # <<<<<<<<<<<<<< + * p[i] = Py_None + * Py_INCREF(Py_None) + */ + if (unlikely(__pyx_v_self->itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_self->itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_self->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + __pyx_t_2 = __Pyx_div_Py_ssize_t(__pyx_v_self->len, __pyx_v_self->itemsize); + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":262 + * p = self.data + * for i in range(self.len // self.itemsize): + * p[i] = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * return 0 + */ + (__pyx_v_p[__pyx_v_i]) = Py_None; + + /* "View.MemoryView":263 + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * return 0 + * + */ + Py_INCREF(Py_None); + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + } + + /* "View.MemoryView":264 + * p[i] = Py_None + * Py_INCREF(Py_None) + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._allocate_buffer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + +static struct __pyx_array_obj *__pyx_array_new(PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, char *__pyx_v_format, char *__pyx_v_c_mode, char *__pyx_v_buf) { + struct __pyx_array_obj *__pyx_v_result = 0; + PyObject *__pyx_v_mode = 0; + struct __pyx_array_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("array_cwrapper", 1); + + /* "View.MemoryView":270 + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. # <<<<<<<<<<<<<< + * + * if buf is NULL: + */ + __pyx_t_2 = ((__pyx_v_c_mode[0]) == 'f'); + if (__pyx_t_2) { + __Pyx_INCREF(__pyx_n_s_fortran); + __pyx_t_1 = __pyx_n_s_fortran; + } else { + __Pyx_INCREF(__pyx_n_s_c); + __pyx_t_1 = __pyx_n_s_c; + } + __pyx_v_mode = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + __pyx_t_2 = (__pyx_v_buf == NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":273 + * + * if buf is NULL: + * result = array.__new__(array, shape, itemsize, format, mode) # <<<<<<<<<<<<<< + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + */ + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_v_shape)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 3, __pyx_v_mode)) __PYX_ERR(1, 273, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_3 = 0; + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_4, NULL)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":275 + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) # <<<<<<<<<<<<<< + * result.data = buf + * + */ + /*else*/ { + __pyx_t_3 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(4); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_shape)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_v_mode)) __PYX_ERR(1, 275, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_allocate_buffer, Py_False) < 0) __PYX_ERR(1, 275, __pyx_L1_error) + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_1, __pyx_t_4)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":276 + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + * result.data = buf # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->data = __pyx_v_buf; + } + __pyx_L3:; + + /* "View.MemoryView":278 + * result.data = buf + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.array_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_mode); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + +/* Python wrapper */ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_name = 0; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_name,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_name)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 304, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__init__") < 0)) __PYX_ERR(1, 304, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + } + __pyx_v_name = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__init__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 304, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v_name); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name) { + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__", 1); + + /* "View.MemoryView":305 + * cdef object name + * def __init__(self, name): + * self.name = name # <<<<<<<<<<<<<< + * def __repr__(self): + * return self.name + */ + __Pyx_INCREF(__pyx_v_name); + __Pyx_GIVEREF(__pyx_v_name); + __Pyx_GOTREF(__pyx_v_self->name); + __Pyx_DECREF(__pyx_v_self->name); + __pyx_v_self->name = __pyx_v_name; + + /* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + + /* function exit code */ + __pyx_r = 0; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + +/* Python wrapper */ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":307 + * self.name = name + * def __repr__(self): + * return self.name # <<<<<<<<<<<<<< + * + * cdef generic = Enum("") + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->name); + __pyx_r = __pyx_v_self->name; + goto __pyx_L0; + + /* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_MemviewEnum___reduce_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_v_state = 0; + PyObject *__pyx_v__dict = 0; + int __pyx_v_use_setstate; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":5 + * cdef object _dict + * cdef bint use_setstate + * state = (self.name,) # <<<<<<<<<<<<<< + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_self->name); + __Pyx_GIVEREF(__pyx_v_self->name); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_self->name)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_v_state = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "(tree fragment)":6 + * cdef bint use_setstate + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< + * if _dict is not None: + * state += (_dict,) + */ + __pyx_t_1 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v__dict = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + __pyx_t_2 = (__pyx_v__dict != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":8 + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + * state += (_dict,) # <<<<<<<<<<<<<< + * use_setstate = True + * else: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v__dict); + __Pyx_GIVEREF(__pyx_v__dict); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v__dict)) __PYX_ERR(1, 8, __pyx_L1_error); + __pyx_t_3 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_3)); + __pyx_t_3 = 0; + + /* "(tree fragment)":9 + * if _dict is not None: + * state += (_dict,) + * use_setstate = True # <<<<<<<<<<<<<< + * else: + * use_setstate = self.name is not None + */ + __pyx_v_use_setstate = 1; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + goto __pyx_L3; + } + + /* "(tree fragment)":11 + * use_setstate = True + * else: + * use_setstate = self.name is not None # <<<<<<<<<<<<<< + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_self->name != Py_None); + __pyx_v_use_setstate = __pyx_t_2; + } + __pyx_L3:; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + if (__pyx_v_use_setstate) { + + /* "(tree fragment)":13 + * use_setstate = self.name is not None + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state # <<<<<<<<<<<<<< + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, Py_None)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_4 = PyTuple_New(3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_v_state)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + } + + /* "(tree fragment)":15 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_v_state)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_4 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + } + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.Enum.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_state); + __Pyx_XDECREF(__pyx_v__dict); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 16, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 16, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 16, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_MemviewEnum_2__setstate_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":17 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) # <<<<<<<<<<<<<< + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 17, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + +/* Python wrapper */ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_obj = 0; + int __pyx_v_flags; + int __pyx_v_dtype_is_object; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_obj,&__pyx_n_s_flags,&__pyx_n_s_dtype_is_object,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_obj)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_flags)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, 1); __PYX_ERR(1, 349, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_dtype_is_object); + if (value) { values[2] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 349, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_obj = values[0]; + __pyx_v_flags = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_flags == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + if (values[2]) { + __pyx_v_dtype_is_object = __Pyx_PyObject_IsTrue(values[2]); if (unlikely((__pyx_v_dtype_is_object == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } else { + __pyx_v_dtype_is_object = ((int)0); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, __pyx_nargs); __PYX_ERR(1, 349, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_obj, __pyx_v_flags, __pyx_v_dtype_is_object); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + Py_intptr_t __pyx_t_4; + size_t __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 1); + + /* "View.MemoryView":350 + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj # <<<<<<<<<<<<<< + * self.flags = flags + * if type(self) is memoryview or obj is not None: + */ + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + __Pyx_GOTREF(__pyx_v_self->obj); + __Pyx_DECREF(__pyx_v_self->obj); + __pyx_v_self->obj = __pyx_v_obj; + + /* "View.MemoryView":351 + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj + * self.flags = flags # <<<<<<<<<<<<<< + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + */ + __pyx_v_self->flags = __pyx_v_flags; + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + __pyx_t_2 = (((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))) == ((PyObject *)__pyx_memoryview_type)); + if (!__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_obj != Py_None); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":353 + * self.flags = flags + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) # <<<<<<<<<<<<<< + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + */ + __pyx_t_3 = __Pyx_GetBuffer(__pyx_v_obj, (&__pyx_v_self->view), __pyx_v_flags); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 353, __pyx_L1_error) + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_t_1 = (((PyObject *)__pyx_v_self->view.obj) == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":355 + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = Py_None; + + /* "View.MemoryView":356 + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + } + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + __pyx_t_1 = (!__PYX_CYTHON_ATOMICS_ENABLED()); + if (__pyx_t_1) { + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + __pyx_t_1 = (__pyx_memoryview_thread_locks_used < 8); + if (__pyx_t_1) { + + /* "View.MemoryView":361 + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + */ + __pyx_v_self->lock = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + + /* "View.MemoryView":362 + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 # <<<<<<<<<<<<<< + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used + 1); + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":364 + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() # <<<<<<<<<<<<<< + * if self.lock is NULL: + * raise MemoryError + */ + __pyx_v_self->lock = PyThread_allocate_lock(); + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":366 + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + PyErr_NoMemory(); __PYX_ERR(1, 366, __pyx_L1_error) + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + } + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":369 + * + * if flags & PyBUF_FORMAT: + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') # <<<<<<<<<<<<<< + * else: + * self.dtype_is_object = dtype_is_object + */ + __pyx_t_2 = ((__pyx_v_self->view.format[0]) == 'O'); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L12_bool_binop_done; + } + __pyx_t_2 = ((__pyx_v_self->view.format[1]) == '\x00'); + __pyx_t_1 = __pyx_t_2; + __pyx_L12_bool_binop_done:; + __pyx_v_self->dtype_is_object = __pyx_t_1; + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":371 + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + * self.dtype_is_object = dtype_is_object # <<<<<<<<<<<<<< + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + */ + /*else*/ { + __pyx_v_self->dtype_is_object = __pyx_v_dtype_is_object; + } + __pyx_L11:; + + /* "View.MemoryView":373 + * self.dtype_is_object = dtype_is_object + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 # <<<<<<<<<<<<<< + * self.typeinfo = NULL + * + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_4 = ((Py_intptr_t)((void *)(&__pyx_v_self->acquisition_count))); + __pyx_t_5 = (sizeof(__pyx_atomic_int_type)); + if (unlikely(__pyx_t_5 == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 373, __pyx_L1_error) + } + __pyx_t_1 = ((__pyx_t_4 % __pyx_t_5) == 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 373, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 373, __pyx_L1_error) + #endif + + /* "View.MemoryView":374 + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + * self.typeinfo = NULL # <<<<<<<<<<<<<< + * + * def __dealloc__(memoryview self): + */ + __pyx_v_self->typeinfo = NULL; + + /* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + +/* Python wrapper */ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self) { + int __pyx_v_i; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + PyThread_type_lock __pyx_t_5; + PyThread_type_lock __pyx_t_6; + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + __pyx_t_1 = (__pyx_v_self->obj != Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":378 + * def __dealloc__(memoryview self): + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) # <<<<<<<<<<<<<< + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + */ + __Pyx_ReleaseBuffer((&__pyx_v_self->view)); + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + __pyx_t_1 = (((Py_buffer *)(&__pyx_v_self->view))->obj == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":381 + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + * (<__pyx_buffer *> &self.view).obj = NULL # <<<<<<<<<<<<<< + * Py_DECREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = NULL; + + /* "View.MemoryView":382 + * + * (<__pyx_buffer *> &self.view).obj = NULL + * Py_DECREF(Py_None) # <<<<<<<<<<<<<< + * + * cdef int i + */ + Py_DECREF(Py_None); + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + } + __pyx_L3:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + __pyx_t_1 = (__pyx_v_self->lock != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":387 + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): # <<<<<<<<<<<<<< + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + */ + __pyx_t_2 = __pyx_memoryview_thread_locks_used; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + __pyx_t_1 = ((__pyx_memoryview_thread_locks[__pyx_v_i]) == __pyx_v_self->lock); + if (__pyx_t_1) { + + /* "View.MemoryView":389 + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 # <<<<<<<<<<<<<< + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used - 1); + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + __pyx_t_1 = (__pyx_v_i != __pyx_memoryview_thread_locks_used); + if (__pyx_t_1) { + + /* "View.MemoryView":392 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) # <<<<<<<<<<<<<< + * break + * else: + */ + __pyx_t_5 = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + __pyx_t_6 = (__pyx_memoryview_thread_locks[__pyx_v_i]); + + /* "View.MemoryView":391 + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break + */ + (__pyx_memoryview_thread_locks[__pyx_v_i]) = __pyx_t_5; + (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]) = __pyx_t_6; + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + } + + /* "View.MemoryView":393 + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break # <<<<<<<<<<<<<< + * else: + * PyThread_free_lock(self.lock) + */ + goto __pyx_L6_break; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + } + } + /*else*/ { + + /* "View.MemoryView":395 + * break + * else: + * PyThread_free_lock(self.lock) # <<<<<<<<<<<<<< + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + */ + PyThread_free_lock(__pyx_v_self->lock); + } + __pyx_L6_break:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + } + + /* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + + /* function exit code */ +} + +/* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + Py_ssize_t __pyx_v_dim; + char *__pyx_v_itemp; + PyObject *__pyx_v_idx = NULL; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t __pyx_t_3; + PyObject *(*__pyx_t_4)(PyObject *); + PyObject *__pyx_t_5 = NULL; + Py_ssize_t __pyx_t_6; + char *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_item_pointer", 1); + + /* "View.MemoryView":399 + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf # <<<<<<<<<<<<<< + * + * for dim, idx in enumerate(index): + */ + __pyx_v_itemp = ((char *)__pyx_v_self->view.buf); + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + __pyx_t_1 = 0; + if (likely(PyList_CheckExact(__pyx_v_index)) || PyTuple_CheckExact(__pyx_v_index)) { + __pyx_t_2 = __pyx_v_index; __Pyx_INCREF(__pyx_t_2); + __pyx_t_3 = 0; + __pyx_t_4 = NULL; + } else { + __pyx_t_3 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_index); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 401, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_4)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } + } else { + __pyx_t_5 = __pyx_t_4(__pyx_t_2); + if (unlikely(!__pyx_t_5)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 401, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_5); + } + __Pyx_XDECREF_SET(__pyx_v_idx, __pyx_t_5); + __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_1; + __pyx_t_1 = (__pyx_t_1 + 1); + + /* "View.MemoryView":402 + * + * for dim, idx in enumerate(index): + * itemp = pybuffer_index(&self.view, itemp, idx, dim) # <<<<<<<<<<<<<< + * + * return itemp + */ + __pyx_t_6 = __Pyx_PyIndex_AsSsize_t(__pyx_v_idx); if (unlikely((__pyx_t_6 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_t_7 = __pyx_pybuffer_index((&__pyx_v_self->view), __pyx_v_itemp, __pyx_t_6, __pyx_v_dim); if (unlikely(__pyx_t_7 == ((char *)NULL))) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_7; + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":404 + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + * return itemp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_itemp; + goto __pyx_L0; + + /* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.get_item_pointer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_idx); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index); /*proto*/ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_indices = NULL; + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + char *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + __pyx_t_1 = (__pyx_v_index == __pyx_builtin_Ellipsis); + if (__pyx_t_1) { + + /* "View.MemoryView":409 + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: + * return self # <<<<<<<<<<<<<< + * + * have_slices, indices = _unellipsify(index, self.view.ndim) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __pyx_r = ((PyObject *)__pyx_v_self); + goto __pyx_L0; + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + } + + /* "View.MemoryView":411 + * return self + * + * have_slices, indices = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * cdef char *itemp + */ + __pyx_t_2 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (likely(__pyx_t_2 != Py_None)) { + PyObject* sequence = __pyx_t_2; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 411, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + #else + __pyx_t_3 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + #endif + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 411, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_3; + __pyx_t_3 = 0; + __pyx_v_indices = __pyx_t_4; + __pyx_t_4 = 0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 414, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":415 + * cdef char *itemp + * if have_slices: + * return memview_slice(self, indices) # <<<<<<<<<<<<<< + * else: + * itemp = self.get_item_pointer(indices) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((PyObject *)__pyx_memview_slice(__pyx_v_self, __pyx_v_indices)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 415, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + } + + /* "View.MemoryView":417 + * return memview_slice(self, indices) + * else: + * itemp = self.get_item_pointer(indices) # <<<<<<<<<<<<<< + * return self.convert_item_to_object(itemp) + * + */ + /*else*/ { + __pyx_t_5 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_indices); if (unlikely(__pyx_t_5 == ((char *)NULL))) __PYX_ERR(1, 417, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_5; + + /* "View.MemoryView":418 + * else: + * itemp = self.get_item_pointer(indices) + * return self.convert_item_to_object(itemp) # <<<<<<<<<<<<<< + * + * def __setitem__(memoryview self, object index, object value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->convert_item_to_object(__pyx_v_self, __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 418, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_indices); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + +/* Python wrapper */ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_obj = NULL; + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 0); + __Pyx_INCREF(__pyx_v_index); + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + if (unlikely(__pyx_v_self->view.readonly)) { + + /* "View.MemoryView":422 + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" # <<<<<<<<<<<<<< + * + * have_slices, index = _unellipsify(index, self.view.ndim) + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_Cannot_assign_to_read_only_memor, 0, 0); + __PYX_ERR(1, 422, __pyx_L1_error) + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + } + + /* "View.MemoryView":424 + * raise TypeError, "Cannot assign to read-only memoryview" + * + * have_slices, index = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * if have_slices: + */ + __pyx_t_1 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (likely(__pyx_t_1 != Py_None)) { + PyObject* sequence = __pyx_t_1; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 424, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_2 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_2); + __Pyx_INCREF(__pyx_t_3); + #else + __pyx_t_2 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 424, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_2; + __pyx_t_2 = 0; + __Pyx_DECREF_SET(__pyx_v_index, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj: + */ + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(1, 426, __pyx_L1_error) + if (__pyx_t_4) { + + /* "View.MemoryView":427 + * + * if have_slices: + * obj = self.is_slice(value) # <<<<<<<<<<<<<< + * if obj: + * self.setitem_slice_assignment(self[index], obj) + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->is_slice(__pyx_v_self, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 427, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_obj = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_v_obj); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(1, 428, __pyx_L1_error) + if (__pyx_t_4) { + + /* "View.MemoryView":429 + * obj = self.is_slice(value) + * if obj: + * self.setitem_slice_assignment(self[index], obj) # <<<<<<<<<<<<<< + * else: + * self.setitem_slice_assign_scalar(self[index], value) + */ + __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assignment(__pyx_v_self, __pyx_t_1, __pyx_v_obj); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + goto __pyx_L5; + } + + /* "View.MemoryView":431 + * self.setitem_slice_assignment(self[index], obj) + * else: + * self.setitem_slice_assign_scalar(self[index], value) # <<<<<<<<<<<<<< + * else: + * self.setitem_indexed(index, value) + */ + /*else*/ { + __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_memoryview_type))))) __PYX_ERR(1, 431, __pyx_L1_error) + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assign_scalar(__pyx_v_self, ((struct __pyx_memoryview_obj *)__pyx_t_3), __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L5:; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":433 + * self.setitem_slice_assign_scalar(self[index], value) + * else: + * self.setitem_indexed(index, value) # <<<<<<<<<<<<<< + * + * cdef is_slice(self, obj): + */ + /*else*/ { + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_indexed(__pyx_v_self, __pyx_v_index, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 433, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L4:; + + /* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_slice", 0); + __Pyx_INCREF(__pyx_v_obj); + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_obj, __pyx_memoryview_type); + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_4, &__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_5); + /*try:*/ { + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_6 = __Pyx_PyInt_From_int(((__pyx_v_self->flags & (~PyBUF_WRITABLE)) | PyBUF_ANY_CONTIGUOUS)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_6); + + /* "View.MemoryView":439 + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) # <<<<<<<<<<<<<< + * except TypeError: + * return None + */ + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 439, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_8 = PyTuple_New(3); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 0, __pyx_v_obj)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_6); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 1, __pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 2, __pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error); + __pyx_t_6 = 0; + __pyx_t_7 = 0; + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_8, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __Pyx_DECREF_SET(__pyx_v_obj, __pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + goto __pyx_L9_try_end; + __pyx_L4_error:; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":440 + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + * except TypeError: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_9 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_TypeError); + if (__pyx_t_9) { + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_6) < 0) __PYX_ERR(1, 440, __pyx_L6_except_error) + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_6); + + /* "View.MemoryView":441 + * self.dtype_is_object) + * except TypeError: + * return None # <<<<<<<<<<<<<< + * + * return obj + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_except_return; + } + goto __pyx_L6_except_error; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + __pyx_L6_except_error:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L1_error; + __pyx_L7_except_return:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L0; + __pyx_L9_try_end:; + } + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + } + + /* "View.MemoryView":443 + * return None + * + * return obj # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assignment(self, dst, src): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_obj); + __pyx_r = __pyx_v_obj; + goto __pyx_L0; + + /* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src) { + __Pyx_memviewslice __pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_src_slice; + __Pyx_memviewslice __pyx_v_msrc; + __Pyx_memviewslice __pyx_v_mdst; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assignment", 1); + + /* "View.MemoryView":448 + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + */ + if (!(likely(((__pyx_v_src) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_src, __pyx_memoryview_type))))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_src), (&__pyx_v_src_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_v_msrc = (__pyx_t_1[0]); + + /* "View.MemoryView":449 + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] # <<<<<<<<<<<<<< + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + */ + if (!(likely(((__pyx_v_dst) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_dst, __pyx_memoryview_type))))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_dst), (&__pyx_v_dst_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_v_mdst = (__pyx_t_1[0]); + + /* "View.MemoryView":451 + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + */ + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_src, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_dst, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_4 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_5 = __pyx_memoryview_copy_contents(__pyx_v_msrc, __pyx_v_mdst, __pyx_t_3, __pyx_t_4, __pyx_v_self->dtype_is_object); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 451, __pyx_L1_error) + + /* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assignment", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value) { + int __pyx_v_array[0x80]; + void *__pyx_v_tmp; + void *__pyx_v_item; + __Pyx_memviewslice *__pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_tmp_slice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_t_5; + char const *__pyx_t_6; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + PyObject *__pyx_t_9 = NULL; + PyObject *__pyx_t_10 = NULL; + PyObject *__pyx_t_11 = NULL; + PyObject *__pyx_t_12 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assign_scalar", 1); + + /* "View.MemoryView":455 + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + * cdef int array[128] + * cdef void *tmp = NULL # <<<<<<<<<<<<<< + * cdef void *item + * + */ + __pyx_v_tmp = NULL; + + /* "View.MemoryView":460 + * cdef __Pyx_memviewslice *dst_slice + * cdef __Pyx_memviewslice tmp_slice + * dst_slice = get_slice_from_memview(dst, &tmp_slice) # <<<<<<<<<<<<<< + * + * if self.view.itemsize > sizeof(array): + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_dst, (&__pyx_v_tmp_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 460, __pyx_L1_error) + __pyx_v_dst_slice = __pyx_t_1; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + __pyx_t_2 = (((size_t)__pyx_v_self->view.itemsize) > (sizeof(__pyx_v_array))); + if (__pyx_t_2) { + + /* "View.MemoryView":463 + * + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) # <<<<<<<<<<<<<< + * if tmp == NULL: + * raise MemoryError + */ + __pyx_v_tmp = PyMem_Malloc(__pyx_v_self->view.itemsize); + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + __pyx_t_2 = (__pyx_v_tmp == NULL); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":465 + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * item = tmp + * else: + */ + PyErr_NoMemory(); __PYX_ERR(1, 465, __pyx_L1_error) + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + } + + /* "View.MemoryView":466 + * if tmp == NULL: + * raise MemoryError + * item = tmp # <<<<<<<<<<<<<< + * else: + * item = array + */ + __pyx_v_item = __pyx_v_tmp; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":468 + * item = tmp + * else: + * item = array # <<<<<<<<<<<<<< + * + * try: + */ + /*else*/ { + __pyx_v_item = ((void *)__pyx_v_array); + } + __pyx_L3:; + + /* "View.MemoryView":470 + * item = array + * + * try: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * ( item)[0] = value + */ + /*try:*/ { + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":472 + * try: + * if self.dtype_is_object: + * ( item)[0] = value # <<<<<<<<<<<<<< + * else: + * self.assign_item_from_object( item, value) + */ + (((PyObject **)__pyx_v_item)[0]) = ((PyObject *)__pyx_v_value); + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":474 + * ( item)[0] = value + * else: + * self.assign_item_from_object( item, value) # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, ((char *)__pyx_v_item), __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 474, __pyx_L6_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + __pyx_t_2 = (__pyx_v_self->view.suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":479 + * + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) # <<<<<<<<<<<<<< + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + * item, self.dtype_is_object) + */ + __pyx_t_4 = assert_direct_dimensions(__pyx_v_self->view.suboffsets, __pyx_v_self->view.ndim); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 479, __pyx_L6_error) + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + } + + /* "View.MemoryView":480 + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, # <<<<<<<<<<<<<< + * item, self.dtype_is_object) + * finally: + */ + __pyx_memoryview_slice_assign_scalar(__pyx_v_dst_slice, __pyx_v_dst->view.ndim, __pyx_v_self->view.itemsize, __pyx_v_item, __pyx_v_self->dtype_is_object); + } + + /* "View.MemoryView":483 + * item, self.dtype_is_object) + * finally: + * PyMem_Free(tmp) # <<<<<<<<<<<<<< + * + * cdef setitem_indexed(self, index, value): + */ + /*finally:*/ { + /*normal exit:*/{ + PyMem_Free(__pyx_v_tmp); + goto __pyx_L7; + } + __pyx_L6_error:; + /*exception exit:*/{ + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + if (PY_MAJOR_VERSION >= 3) __Pyx_ExceptionSwap(&__pyx_t_10, &__pyx_t_11, &__pyx_t_12); + if ((PY_MAJOR_VERSION < 3) || unlikely(__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9) < 0)) __Pyx_ErrFetch(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_10); + __Pyx_XGOTREF(__pyx_t_11); + __Pyx_XGOTREF(__pyx_t_12); + __pyx_t_4 = __pyx_lineno; __pyx_t_5 = __pyx_clineno; __pyx_t_6 = __pyx_filename; + { + PyMem_Free(__pyx_v_tmp); + } + if (PY_MAJOR_VERSION >= 3) { + __Pyx_XGIVEREF(__pyx_t_10); + __Pyx_XGIVEREF(__pyx_t_11); + __Pyx_XGIVEREF(__pyx_t_12); + __Pyx_ExceptionReset(__pyx_t_10, __pyx_t_11, __pyx_t_12); + } + __Pyx_XGIVEREF(__pyx_t_7); + __Pyx_XGIVEREF(__pyx_t_8); + __Pyx_XGIVEREF(__pyx_t_9); + __Pyx_ErrRestore(__pyx_t_7, __pyx_t_8, __pyx_t_9); + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __pyx_lineno = __pyx_t_4; __pyx_clineno = __pyx_t_5; __pyx_filename = __pyx_t_6; + goto __pyx_L1_error; + } + __pyx_L7:; + } + + /* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assign_scalar", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + char *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_indexed", 1); + + /* "View.MemoryView":486 + * + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) # <<<<<<<<<<<<<< + * self.assign_item_from_object(itemp, value) + * + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_index); if (unlikely(__pyx_t_1 == ((char *)NULL))) __PYX_ERR(1, 486, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_1; + + /* "View.MemoryView":487 + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 487, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_indexed", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_v_struct = NULL; + PyObject *__pyx_v_bytesitem = 0; + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + int __pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":492 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef bytes bytesitem + * + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 492, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":495 + * cdef bytes bytesitem + * + * bytesitem = itemp[:self.view.itemsize] # <<<<<<<<<<<<<< + * try: + * result = struct.unpack(self.view.format, bytesitem) + */ + __pyx_t_1 = __Pyx_PyBytes_FromStringAndSize(__pyx_v_itemp + 0, __pyx_v_self->view.itemsize - 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 495, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_bytesitem = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_2, &__pyx_t_3, &__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + /*try:*/ { + + /* "View.MemoryView":497 + * bytesitem = itemp[:self.view.itemsize] + * try: + * result = struct.unpack(self.view.format, bytesitem) # <<<<<<<<<<<<<< + * except struct.error: + * raise ValueError, "Unable to convert item to object" + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_unpack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_7, __pyx_t_6, __pyx_v_bytesitem}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_8, 2+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __pyx_v_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + } + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + /*else:*/ { + __pyx_t_9 = __Pyx_ssize_strlen(__pyx_v_self->view.format); if (unlikely(__pyx_t_9 == ((Py_ssize_t)-1))) __PYX_ERR(1, 501, __pyx_L5_except_error) + __pyx_t_10 = (__pyx_t_9 == 1); + if (__pyx_t_10) { + + /* "View.MemoryView":502 + * else: + * if len(self.view.format) == 1: + * return result[0] # <<<<<<<<<<<<<< + * return result + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_GetItemInt(__pyx_v_result, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 502, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L6_except_return; + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + } + + /* "View.MemoryView":503 + * if len(self.view.format) == 1: + * return result[0] + * return result # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L6_except_return; + } + __pyx_L3_error:; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":498 + * try: + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: # <<<<<<<<<<<<<< + * raise ValueError, "Unable to convert item to object" + * else: + */ + __Pyx_ErrFetch(&__pyx_t_1, &__pyx_t_5, &__pyx_t_6); + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_error); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_8 = __Pyx_PyErr_GivenExceptionMatches(__pyx_t_1, __pyx_t_7); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_ErrRestore(__pyx_t_1, __pyx_t_5, __pyx_t_6); + __pyx_t_1 = 0; __pyx_t_5 = 0; __pyx_t_6 = 0; + if (__pyx_t_8) { + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_6, &__pyx_t_5, &__pyx_t_1) < 0) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_1); + + /* "View.MemoryView":499 + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + * raise ValueError, "Unable to convert item to object" # <<<<<<<<<<<<<< + * else: + * if len(self.view.format) == 1: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Unable_to_convert_item_to_object, 0, 0); + __PYX_ERR(1, 499, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L1_error; + __pyx_L6_except_return:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L0; + } + + /* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesitem); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_v_struct = NULL; + char __pyx_v_c; + PyObject *__pyx_v_bytesvalue = 0; + Py_ssize_t __pyx_v_i; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + Py_ssize_t __pyx_t_7; + PyObject *__pyx_t_8 = NULL; + char *__pyx_t_9; + char *__pyx_t_10; + char *__pyx_t_11; + char *__pyx_t_12; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":508 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef char c + * cdef bytes bytesvalue + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 508, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_value); + if (__pyx_t_2) { + + /* "View.MemoryView":514 + * + * if isinstance(value, tuple): + * bytesvalue = struct.pack(self.view.format, *value) # <<<<<<<<<<<<<< + * else: + * bytesvalue = struct.pack(self.view.format, value) + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PySequence_Tuple(__pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = PyNumber_Add(__pyx_t_4, __pyx_t_3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_5, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 514, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":516 + * bytesvalue = struct.pack(self.view.format, *value) + * else: + * bytesvalue = struct.pack(self.view.format, value) # <<<<<<<<<<<<<< + * + * for i, c in enumerate(bytesvalue): + */ + /*else*/ { + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_4 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_4, __pyx_t_1, __pyx_v_value}; + __pyx_t_3 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_6, 2+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 516, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = 0; + if (unlikely(__pyx_v_bytesvalue == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' is not iterable"); + __PYX_ERR(1, 518, __pyx_L1_error) + } + __Pyx_INCREF(__pyx_v_bytesvalue); + __pyx_t_8 = __pyx_v_bytesvalue; + __pyx_t_10 = PyBytes_AS_STRING(__pyx_t_8); + __pyx_t_11 = (__pyx_t_10 + PyBytes_GET_SIZE(__pyx_t_8)); + for (__pyx_t_12 = __pyx_t_10; __pyx_t_12 < __pyx_t_11; __pyx_t_12++) { + __pyx_t_9 = __pyx_t_12; + __pyx_v_c = (__pyx_t_9[0]); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_v_i = __pyx_t_7; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = (__pyx_t_7 + 1); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + (__pyx_v_itemp[__pyx_v_i]) = __pyx_v_c; + } + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesvalue); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t *__pyx_t_3; + char *__pyx_t_4; + void *__pyx_t_5; + int __pyx_t_6; + Py_ssize_t __pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + __pyx_t_2 = ((__pyx_v_flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_L4_bool_binop_done:; + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":524 + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + * raise ValueError, "Cannot create writable memory view from read-only memoryview" # <<<<<<<<<<<<<< + * + * if flags & PyBUF_ND: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Cannot_create_writable_memory_vi, 0, 0); + __PYX_ERR(1, 524, __pyx_L1_error) + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + } + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":527 + * + * if flags & PyBUF_ND: + * info.shape = self.view.shape # <<<<<<<<<<<<<< + * else: + * info.shape = NULL + */ + __pyx_t_3 = __pyx_v_self->view.shape; + __pyx_v_info->shape = __pyx_t_3; + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":529 + * info.shape = self.view.shape + * else: + * info.shape = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + /*else*/ { + __pyx_v_info->shape = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":532 + * + * if flags & PyBUF_STRIDES: + * info.strides = self.view.strides # <<<<<<<<<<<<<< + * else: + * info.strides = NULL + */ + __pyx_t_3 = __pyx_v_self->view.strides; + __pyx_v_info->strides = __pyx_t_3; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + goto __pyx_L7; + } + + /* "View.MemoryView":534 + * info.strides = self.view.strides + * else: + * info.strides = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_INDIRECT: + */ + /*else*/ { + __pyx_v_info->strides = NULL; + } + __pyx_L7:; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_INDIRECT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":537 + * + * if flags & PyBUF_INDIRECT: + * info.suboffsets = self.view.suboffsets # <<<<<<<<<<<<<< + * else: + * info.suboffsets = NULL + */ + __pyx_t_3 = __pyx_v_self->view.suboffsets; + __pyx_v_info->suboffsets = __pyx_t_3; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":539 + * info.suboffsets = self.view.suboffsets + * else: + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + /*else*/ { + __pyx_v_info->suboffsets = NULL; + } + __pyx_L8:; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":542 + * + * if flags & PyBUF_FORMAT: + * info.format = self.view.format # <<<<<<<<<<<<<< + * else: + * info.format = NULL + */ + __pyx_t_4 = __pyx_v_self->view.format; + __pyx_v_info->format = __pyx_t_4; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":544 + * info.format = self.view.format + * else: + * info.format = NULL # <<<<<<<<<<<<<< + * + * info.buf = self.view.buf + */ + /*else*/ { + __pyx_v_info->format = NULL; + } + __pyx_L9:; + + /* "View.MemoryView":546 + * info.format = NULL + * + * info.buf = self.view.buf # <<<<<<<<<<<<<< + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + */ + __pyx_t_5 = __pyx_v_self->view.buf; + __pyx_v_info->buf = __pyx_t_5; + + /* "View.MemoryView":547 + * + * info.buf = self.view.buf + * info.ndim = self.view.ndim # <<<<<<<<<<<<<< + * info.itemsize = self.view.itemsize + * info.len = self.view.len + */ + __pyx_t_6 = __pyx_v_self->view.ndim; + __pyx_v_info->ndim = __pyx_t_6; + + /* "View.MemoryView":548 + * info.buf = self.view.buf + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize # <<<<<<<<<<<<<< + * info.len = self.view.len + * info.readonly = self.view.readonly + */ + __pyx_t_7 = __pyx_v_self->view.itemsize; + __pyx_v_info->itemsize = __pyx_t_7; + + /* "View.MemoryView":549 + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + * info.len = self.view.len # <<<<<<<<<<<<<< + * info.readonly = self.view.readonly + * info.obj = self + */ + __pyx_t_7 = __pyx_v_self->view.len; + __pyx_v_info->len = __pyx_t_7; + + /* "View.MemoryView":550 + * info.itemsize = self.view.itemsize + * info.len = self.view.len + * info.readonly = self.view.readonly # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_v_info->readonly = __pyx_t_1; + + /* "View.MemoryView":551 + * info.len = self.view.len + * info.readonly = self.view.readonly + * info.obj = self # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":556 + * @property + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) # <<<<<<<<<<<<<< + * transpose_memslice(&result.from_slice) + * return result + */ + __pyx_t_1 = __pyx_memoryview_copy_object(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 556, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_memoryviewslice_type))))) __PYX_ERR(1, 556, __pyx_L1_error) + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":557 + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_t_2 = __pyx_memslice_transpose((&__pyx_v_result->from_slice)); if (unlikely(__pyx_t_2 == ((int)-1))) __PYX_ERR(1, 557, __pyx_L1_error) + + /* "View.MemoryView":558 + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) + * return result # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.T.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":562 + * @property + * def base(self): + * return self._get_base() # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->_get_base(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 562, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.base.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":565 + * + * cdef _get_base(self): + * return self.obj # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->obj); + __pyx_r = __pyx_v_self->obj; + goto __pyx_L0; + + /* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_7genexpr__pyx_v_length; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":569 + * @property + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_7genexpr__pyx_v_length = (__pyx_t_2[0]); + __pyx_t_5 = PyInt_FromSsize_t(__pyx_7genexpr__pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_1, (PyObject*)__pyx_t_5))) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + } /* exit inner scope */ + __pyx_t_5 = PyList_AsTuple(((PyObject*)__pyx_t_1)); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_5; + __pyx_t_5 = 0; + goto __pyx_L0; + + /* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.shape.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr1__pyx_v_stride; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + __pyx_t_1 = (__pyx_v_self->view.strides == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":575 + * if self.view.strides == NULL: + * + * raise ValueError, "Buffer view does not expose strides" # <<<<<<<<<<<<<< + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Buffer_view_does_not_expose_stri, 0, 0); + __PYX_ERR(1, 575, __pyx_L1_error) + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + } + + /* "View.MemoryView":577 + * raise ValueError, "Buffer view does not expose strides" + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.strides + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.strides; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr1__pyx_v_stride = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr1__pyx_v_stride); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.strides.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr2__pyx_v_suboffset; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + __pyx_t_1 = (__pyx_v_self->view.suboffsets == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PySequence_Multiply(__pyx_tuple__4, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + } + + /* "View.MemoryView":584 + * return (-1,) * self.view.ndim + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.suboffsets + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.suboffsets; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr2__pyx_v_suboffset = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr2__pyx_v_suboffset); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.suboffsets.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":588 + * @property + * def ndim(self): + * return self.view.ndim # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 588, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.ndim.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":592 + * @property + * def itemsize(self): + * return self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 592, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.itemsize.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":596 + * @property + * def nbytes(self): + * return self.size * self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_size); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_Multiply(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.nbytes.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + __pyx_t_1 = (__pyx_v_self->_size == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":601 + * def size(self): + * if self._size is None: + * result = 1 # <<<<<<<<<<<<<< + * + * for length in self.view.shape[:self.view.ndim]: + */ + __Pyx_INCREF(__pyx_int_1); + __pyx_v_result = __pyx_int_1; + + /* "View.MemoryView":603 + * result = 1 + * + * for length in self.view.shape[:self.view.ndim]: # <<<<<<<<<<<<<< + * result *= length + * + */ + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_t_5 = PyInt_FromSsize_t((__pyx_t_2[0])); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 603, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_5); + __pyx_t_5 = 0; + + /* "View.MemoryView":604 + * + * for length in self.view.shape[:self.view.ndim]: + * result *= length # <<<<<<<<<<<<<< + * + * self._size = result + */ + __pyx_t_5 = PyNumber_InPlaceMultiply(__pyx_v_result, __pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 604, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF_SET(__pyx_v_result, __pyx_t_5); + __pyx_t_5 = 0; + } + + /* "View.MemoryView":606 + * result *= length + * + * self._size = result # <<<<<<<<<<<<<< + * + * return self._size + */ + __Pyx_INCREF(__pyx_v_result); + __Pyx_GIVEREF(__pyx_v_result); + __Pyx_GOTREF(__pyx_v_self->_size); + __Pyx_DECREF(__pyx_v_self->_size); + __pyx_v_self->_size = __pyx_v_result; + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + } + + /* "View.MemoryView":608 + * self._size = result + * + * return self._size # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->_size); + __pyx_r = __pyx_v_self->_size; + goto __pyx_L0; + + /* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.size.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + __pyx_t_1 = (__pyx_v_self->view.ndim >= 1); + if (__pyx_t_1) { + + /* "View.MemoryView":612 + * def __len__(self): + * if self.view.ndim >= 1: + * return self.view.shape[0] # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_r = (__pyx_v_self->view.shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + } + + /* "View.MemoryView":614 + * return self.view.shape[0] + * + * return 0 # <<<<<<<<<<<<<< + * + * def __repr__(self): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":618 + * def __repr__(self): + * return "" % (self.base.__class__.__name__, + * id(self)) # <<<<<<<<<<<<<< + * + * def __str__(self): + */ + __pyx_t_2 = __Pyx_PyObject_CallOneArg(__pyx_builtin_id, ((PyObject *)__pyx_v_self)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 618, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_t_3); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__repr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__str__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__str__", 1); + + /* "View.MemoryView":621 + * + * def __str__(self): + * return "" % (self.base.__class__.__name__,) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = PyTuple_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_object, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.__str__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_c_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_c_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_c_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_c_contig", 1); + + /* "View.MemoryView":627 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 627, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":628 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'C', self.view.ndim) # <<<<<<<<<<<<<< + * + * def is_f_contig(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'C', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 628, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_c_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_f_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_f_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_f_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_f_contig", 1); + + /* "View.MemoryView":633 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 633, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":634 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'F', self.view.ndim) # <<<<<<<<<<<<<< + * + * def copy(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'F', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 634, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_f_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_mslice; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy", 1); + + /* "View.MemoryView":638 + * def copy(self): + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &mslice) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_F_CONTIGUOUS)); + + /* "View.MemoryView":640 + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + * + * slice_copy(self, &mslice) # <<<<<<<<<<<<<< + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_mslice)); + + /* "View.MemoryView":641 + * + * slice_copy(self, &mslice) + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_C_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_mslice), ((char *)"c"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_C_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 641, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":646 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &mslice) # <<<<<<<<<<<<<< + * + * def copy_fortran(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_mslice)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 646, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy_fortran (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy_fortran", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy_fortran", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy_fortran", 1); + + /* "View.MemoryView":650 + * def copy_fortran(self): + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &src) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_C_CONTIGUOUS)); + + /* "View.MemoryView":652 + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + * + * slice_copy(self, &src) # <<<<<<<<<<<<<< + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_src)); + + /* "View.MemoryView":653 + * + * slice_copy(self, &src) + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_F_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_src), ((char *)"fortran"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_F_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 653, __pyx_L1_error) + __pyx_v_dst = __pyx_t_1; + + /* "View.MemoryView":658 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &dst) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_dst)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 658, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy_fortran", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryview___reduce_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryview_2__setstate_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + +static PyObject *__pyx_memoryview_new(PyObject *__pyx_v_o, int __pyx_v_flags, int __pyx_v_dtype_is_object, __Pyx_TypeInfo *__pyx_v_typeinfo) { + struct __pyx_memoryview_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_cwrapper", 1); + + /* "View.MemoryView":663 + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) # <<<<<<<<<<<<<< + * result.typeinfo = typeinfo + * return result + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_o); + __Pyx_GIVEREF(__pyx_v_o); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_o)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":664 + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_v_result->typeinfo = __pyx_v_typeinfo; + + /* "View.MemoryView":665 + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_check') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *__pyx_v_o) { + int __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":669 + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: + * return isinstance(o, memoryview) # <<<<<<<<<<<<<< + * + * cdef tuple _unellipsify(object index, int ndim): + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_o, __pyx_memoryview_type); + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + +static PyObject *_unellipsify(PyObject *__pyx_v_index, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_idx; + PyObject *__pyx_v_tup = NULL; + PyObject *__pyx_v_result = NULL; + int __pyx_v_have_slices; + int __pyx_v_seen_ellipsis; + PyObject *__pyx_v_item = NULL; + Py_ssize_t __pyx_v_nslices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_UCS4 __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_unellipsify", 1); + + /* "View.MemoryView":677 + * """ + * cdef Py_ssize_t idx + * tup = index if isinstance(index, tuple) else (index,) # <<<<<<<<<<<<<< + * + * result = [slice(None)] * ndim + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_index); + if (__pyx_t_2) { + __Pyx_INCREF(((PyObject*)__pyx_v_index)); + __pyx_t_1 = __pyx_v_index; + } else { + __pyx_t_3 = PyTuple_New(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 677, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_index); + __Pyx_GIVEREF(__pyx_v_index); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_index)) __PYX_ERR(1, 677, __pyx_L1_error); + __pyx_t_1 = __pyx_t_3; + __pyx_t_3 = 0; + } + __pyx_v_tup = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_t_1 = PyList_New(1 * ((__pyx_v_ndim<0) ? 0:__pyx_v_ndim)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + { Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < __pyx_v_ndim; __pyx_temp++) { + __Pyx_INCREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, __pyx_temp, __pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error); + } + } + __pyx_v_result = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":680 + * + * result = [slice(None)] * ndim + * have_slices = False # <<<<<<<<<<<<<< + * seen_ellipsis = False + * idx = 0 + */ + __pyx_v_have_slices = 0; + + /* "View.MemoryView":681 + * result = [slice(None)] * ndim + * have_slices = False + * seen_ellipsis = False # <<<<<<<<<<<<<< + * idx = 0 + * for item in tup: + */ + __pyx_v_seen_ellipsis = 0; + + /* "View.MemoryView":682 + * have_slices = False + * seen_ellipsis = False + * idx = 0 # <<<<<<<<<<<<<< + * for item in tup: + * if item is Ellipsis: + */ + __pyx_v_idx = 0; + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); + __PYX_ERR(1, 683, __pyx_L1_error) + } + __pyx_t_1 = __pyx_v_tup; __Pyx_INCREF(__pyx_t_1); + __pyx_t_4 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_1); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #endif + if (__pyx_t_4 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_3); __pyx_t_4++; if (unlikely((0 < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #else + __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 683, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_XDECREF_SET(__pyx_v_item, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + __pyx_t_2 = (__pyx_v_item == __pyx_builtin_Ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + __pyx_t_2 = (!__pyx_v_seen_ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":686 + * if item is Ellipsis: + * if not seen_ellipsis: + * idx += ndim - len(tup) # <<<<<<<<<<<<<< + * seen_ellipsis = True + * have_slices = True + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 686, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_PyTuple_GET_SIZE(__pyx_v_tup); if (unlikely(__pyx_t_5 == ((Py_ssize_t)-1))) __PYX_ERR(1, 686, __pyx_L1_error) + __pyx_v_idx = (__pyx_v_idx + (__pyx_v_ndim - __pyx_t_5)); + + /* "View.MemoryView":687 + * if not seen_ellipsis: + * idx += ndim - len(tup) + * seen_ellipsis = True # <<<<<<<<<<<<<< + * have_slices = True + * else: + */ + __pyx_v_seen_ellipsis = 1; + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + } + + /* "View.MemoryView":688 + * idx += ndim - len(tup) + * seen_ellipsis = True + * have_slices = True # <<<<<<<<<<<<<< + * else: + * if isinstance(item, slice): + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + /*else*/ { + __pyx_t_2 = PySlice_Check(__pyx_v_item); + if (__pyx_t_2) { + + /* "View.MemoryView":691 + * else: + * if isinstance(item, slice): + * have_slices = True # <<<<<<<<<<<<<< + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + goto __pyx_L7; + } + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + __pyx_t_2 = (!(PyIndex_Check(__pyx_v_item) != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":693 + * have_slices = True + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" # <<<<<<<<<<<<<< + * result[idx] = item + * idx += 1 + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = 0; + __pyx_t_6 = 127; + __Pyx_INCREF(__pyx_kp_u_Cannot_index_with_type); + __pyx_t_5 += 24; + __Pyx_GIVEREF(__pyx_kp_u_Cannot_index_with_type); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Cannot_index_with_type); + __pyx_t_7 = __Pyx_PyObject_FormatSimple(((PyObject *)Py_TYPE(__pyx_v_item)), __pyx_empty_unicode); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_6 = (__Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) > __pyx_t_6) ? __Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) : __pyx_t_6; + __pyx_t_5 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7); + __pyx_t_7 = 0; + __Pyx_INCREF(__pyx_kp_u__6); + __pyx_t_5 += 1; + __Pyx_GIVEREF(__pyx_kp_u__6); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__6); + __pyx_t_7 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_5, __pyx_t_6); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_t_7, 0, 0); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __PYX_ERR(1, 693, __pyx_L1_error) + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + } + __pyx_L7:; + + /* "View.MemoryView":694 + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item # <<<<<<<<<<<<<< + * idx += 1 + * + */ + if (unlikely((__Pyx_SetItemInt(__pyx_v_result, __pyx_v_idx, __pyx_v_item, Py_ssize_t, 1, PyInt_FromSsize_t, 1, 1, 1) < 0))) __PYX_ERR(1, 694, __pyx_L1_error) + } + __pyx_L5:; + + /* "View.MemoryView":695 + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + * idx += 1 # <<<<<<<<<<<<<< + * + * nslices = ndim - idx + */ + __pyx_v_idx = (__pyx_v_idx + 1); + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":697 + * idx += 1 + * + * nslices = ndim - idx # <<<<<<<<<<<<<< + * return have_slices or nslices, tuple(result) + * + */ + __pyx_v_nslices = (__pyx_v_ndim - __pyx_v_idx); + + /* "View.MemoryView":698 + * + * nslices = ndim - idx + * return have_slices or nslices, tuple(result) # <<<<<<<<<<<<<< + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + */ + __Pyx_XDECREF(__pyx_r); + if (!__pyx_v_have_slices) { + } else { + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_have_slices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_7 = PyInt_FromSsize_t(__pyx_v_nslices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + __pyx_L9_bool_binop_done:; + __pyx_t_7 = PyList_AsTuple(__pyx_v_result); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 698, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_7 = 0; + __pyx_r = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView._unellipsify", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_tup); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_item); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + +static int assert_direct_dimensions(Py_ssize_t *__pyx_v_suboffsets, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_suboffset; + int __pyx_r; + Py_ssize_t *__pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":701 + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + */ + __pyx_t_2 = (__pyx_v_suboffsets + __pyx_v_ndim); + for (__pyx_t_3 = __pyx_v_suboffsets; __pyx_t_3 < __pyx_t_2; __pyx_t_3++) { + __pyx_t_1 = __pyx_t_3; + __pyx_v_suboffset = (__pyx_t_1[0]); + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + __pyx_t_4 = (__pyx_v_suboffset >= 0); + if (unlikely(__pyx_t_4)) { + + /* "View.MemoryView":703 + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" # <<<<<<<<<<<<<< + * return 0 # return type just used as an error flag + * + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Indirect_dimensions_not_supporte, 0, 0); + __PYX_ERR(1, 703, __pyx_L1_error) + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + } + } + + /* "View.MemoryView":704 + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.assert_direct_dimensions", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *__pyx_v_memview, PyObject *__pyx_v_indices) { + int __pyx_v_new_ndim; + int __pyx_v_suboffset_dim; + int __pyx_v_dim; + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + __Pyx_memviewslice *__pyx_v_p_src; + struct __pyx_memoryviewslice_obj *__pyx_v_memviewsliceobj = 0; + __Pyx_memviewslice *__pyx_v_p_dst; + int *__pyx_v_p_suboffset_dim; + Py_ssize_t __pyx_v_start; + Py_ssize_t __pyx_v_stop; + Py_ssize_t __pyx_v_step; + Py_ssize_t __pyx_v_cindex; + int __pyx_v_have_start; + int __pyx_v_have_stop; + int __pyx_v_have_step; + PyObject *__pyx_v_index = NULL; + struct __pyx_memoryview_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + struct __pyx_memoryview_obj *__pyx_t_3; + char *__pyx_t_4; + int __pyx_t_5; + Py_ssize_t __pyx_t_6; + PyObject *(*__pyx_t_7)(PyObject *); + PyObject *__pyx_t_8 = NULL; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + Py_ssize_t __pyx_t_11; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memview_slice", 1); + + /* "View.MemoryView":712 + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): + * cdef int new_ndim = 0, suboffset_dim = -1, dim # <<<<<<<<<<<<<< + * cdef bint negative_step + * cdef __Pyx_memviewslice src, dst + */ + __pyx_v_new_ndim = 0; + __pyx_v_suboffset_dim = -1; + + /* "View.MemoryView":719 + * + * + * memset(&dst, 0, sizeof(dst)) # <<<<<<<<<<<<<< + * + * cdef _memoryviewslice memviewsliceobj + */ + (void)(memset((&__pyx_v_dst), 0, (sizeof(__pyx_v_dst)))); + + /* "View.MemoryView":723 + * cdef _memoryviewslice memviewsliceobj + * + * assert memview.view.ndim > 0 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_1 = (__pyx_v_memview->view.ndim > 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 723, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 723, __pyx_L1_error) + #endif + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":726 + * + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview # <<<<<<<<<<<<<< + * p_src = &memviewsliceobj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 726, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_memviewsliceobj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":727 + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, &src) + */ + __pyx_v_p_src = (&__pyx_v_memviewsliceobj->from_slice); + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + goto __pyx_L3; + } + + /* "View.MemoryView":729 + * p_src = &memviewsliceobj.from_slice + * else: + * slice_copy(memview, &src) # <<<<<<<<<<<<<< + * p_src = &src + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_src)); + + /* "View.MemoryView":730 + * else: + * slice_copy(memview, &src) + * p_src = &src # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_p_src = (&__pyx_v_src); + } + __pyx_L3:; + + /* "View.MemoryView":736 + * + * + * dst.memview = p_src.memview # <<<<<<<<<<<<<< + * dst.data = p_src.data + * + */ + __pyx_t_3 = __pyx_v_p_src->memview; + __pyx_v_dst.memview = __pyx_t_3; + + /* "View.MemoryView":737 + * + * dst.memview = p_src.memview + * dst.data = p_src.data # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_v_p_src->data; + __pyx_v_dst.data = __pyx_t_4; + + /* "View.MemoryView":742 + * + * + * cdef __Pyx_memviewslice *p_dst = &dst # <<<<<<<<<<<<<< + * cdef int *p_suboffset_dim = &suboffset_dim + * cdef Py_ssize_t start, stop, step, cindex + */ + __pyx_v_p_dst = (&__pyx_v_dst); + + /* "View.MemoryView":743 + * + * cdef __Pyx_memviewslice *p_dst = &dst + * cdef int *p_suboffset_dim = &suboffset_dim # <<<<<<<<<<<<<< + * cdef Py_ssize_t start, stop, step, cindex + * cdef bint have_start, have_stop, have_step + */ + __pyx_v_p_suboffset_dim = (&__pyx_v_suboffset_dim); + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + __pyx_t_5 = 0; + if (likely(PyList_CheckExact(__pyx_v_indices)) || PyTuple_CheckExact(__pyx_v_indices)) { + __pyx_t_2 = __pyx_v_indices; __Pyx_INCREF(__pyx_t_2); + __pyx_t_6 = 0; + __pyx_t_7 = NULL; + } else { + __pyx_t_6 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_indices); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_7 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 747, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_7)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } + } else { + __pyx_t_8 = __pyx_t_7(__pyx_t_2); + if (unlikely(!__pyx_t_8)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 747, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_8); + } + __Pyx_XDECREF_SET(__pyx_v_index, __pyx_t_8); + __pyx_t_8 = 0; + __pyx_v_dim = __pyx_t_5; + __pyx_t_5 = (__pyx_t_5 + 1); + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + __pyx_t_1 = (PyIndex_Check(__pyx_v_index) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":749 + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): + * cindex = index # <<<<<<<<<<<<<< + * slice_memviewslice( + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + */ + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_v_index); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 749, __pyx_L1_error) + __pyx_v_cindex = __pyx_t_9; + + /* "View.MemoryView":750 + * if PyIndex_Check(index): + * cindex = index + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_cindex, 0, 0, 0, 0, 0, 0); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 750, __pyx_L1_error) + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + goto __pyx_L6; + } + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + __pyx_t_1 = (__pyx_v_index == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":757 + * False) + * elif index is None: + * p_dst.shape[new_ndim] = 1 # <<<<<<<<<<<<<< + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + */ + (__pyx_v_p_dst->shape[__pyx_v_new_ndim]) = 1; + + /* "View.MemoryView":758 + * elif index is None: + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 # <<<<<<<<<<<<<< + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 + */ + (__pyx_v_p_dst->strides[__pyx_v_new_ndim]) = 0; + + /* "View.MemoryView":759 + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 # <<<<<<<<<<<<<< + * new_ndim += 1 + * else: + */ + (__pyx_v_p_dst->suboffsets[__pyx_v_new_ndim]) = -1L; + + /* "View.MemoryView":760 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 # <<<<<<<<<<<<<< + * else: + * start = index.start or 0 + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + goto __pyx_L6; + } + + /* "View.MemoryView":762 + * new_ndim += 1 + * else: + * start = index.start or 0 # <<<<<<<<<<<<<< + * stop = index.stop or 0 + * step = index.step or 0 + */ + /*else*/ { + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 762, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 762, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 762, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L7_bool_binop_done:; + __pyx_v_start = __pyx_t_9; + + /* "View.MemoryView":763 + * else: + * start = index.start or 0 + * stop = index.stop or 0 # <<<<<<<<<<<<<< + * step = index.step or 0 + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 763, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 763, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 763, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L9_bool_binop_done:; + __pyx_v_stop = __pyx_t_9; + + /* "View.MemoryView":764 + * start = index.start or 0 + * stop = index.stop or 0 + * step = index.step or 0 # <<<<<<<<<<<<<< + * + * have_start = index.start is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 764, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 764, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 764, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L11_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L11_bool_binop_done:; + __pyx_v_step = __pyx_t_9; + + /* "View.MemoryView":766 + * step = index.step or 0 + * + * have_start = index.start is not None # <<<<<<<<<<<<<< + * have_stop = index.stop is not None + * have_step = index.step is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 766, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_start = __pyx_t_1; + + /* "View.MemoryView":767 + * + * have_start = index.start is not None + * have_stop = index.stop is not None # <<<<<<<<<<<<<< + * have_step = index.step is not None + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 767, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_stop = __pyx_t_1; + + /* "View.MemoryView":768 + * have_start = index.start is not None + * have_stop = index.stop is not None + * have_step = index.step is not None # <<<<<<<<<<<<<< + * + * slice_memviewslice( + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 768, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_step = __pyx_t_1; + + /* "View.MemoryView":770 + * have_step = index.step is not None + * + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_start, __pyx_v_stop, __pyx_v_step, __pyx_v_have_start, __pyx_v_have_stop, __pyx_v_have_step, 1); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 770, __pyx_L1_error) + + /* "View.MemoryView":776 + * have_start, have_stop, have_step, + * True) + * new_ndim += 1 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + } + __pyx_L6:; + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":780 + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, # <<<<<<<<<<<<<< + * memviewsliceobj.to_dtype_func, + * memview.dtype_is_object) + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 780, __pyx_L1_error) } + + /* "View.MemoryView":781 + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * else: + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 781, __pyx_L1_error) } + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, __pyx_v_memviewsliceobj->to_object_func, __pyx_v_memviewsliceobj->to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 779, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 779, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + } + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + /*else*/ { + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":785 + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, NULL, NULL, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 784, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 784, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memview_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_memviewsliceobj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *__pyx_v_dst, Py_ssize_t __pyx_v_shape, Py_ssize_t __pyx_v_stride, Py_ssize_t __pyx_v_suboffset, int __pyx_v_dim, int __pyx_v_new_ndim, int *__pyx_v_suboffset_dim, Py_ssize_t __pyx_v_start, Py_ssize_t __pyx_v_stop, Py_ssize_t __pyx_v_step, int __pyx_v_have_start, int __pyx_v_have_stop, int __pyx_v_have_step, int __pyx_v_is_slice) { + Py_ssize_t __pyx_v_new_shape; + int __pyx_v_negative_step; + int __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + __pyx_t_1 = (!__pyx_v_is_slice); + if (__pyx_t_1) { + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + __pyx_t_1 = (__pyx_v_start < 0); + if (__pyx_t_1) { + + /* "View.MemoryView":816 + * + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + } + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + __pyx_t_1 = (0 <= __pyx_v_start); + if (__pyx_t_1) { + __pyx_t_1 = (__pyx_v_start < __pyx_v_shape); + } + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":818 + * start += shape + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 818, __pyx_L1_error) + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_have_step != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":822 + * + * if have_step: + * negative_step = step < 0 # <<<<<<<<<<<<<< + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + */ + __pyx_v_negative_step = (__pyx_v_step < 0); + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + __pyx_t_2 = (__pyx_v_step == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":824 + * negative_step = step < 0 + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * negative_step = False + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 824, __pyx_L1_error) + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":826 + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + * negative_step = False # <<<<<<<<<<<<<< + * step = 1 + * + */ + /*else*/ { + __pyx_v_negative_step = 0; + + /* "View.MemoryView":827 + * else: + * negative_step = False + * step = 1 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_step = 1; + } + __pyx_L6:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + __pyx_t_2 = (__pyx_v_have_start != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":832 + * if have_start: + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if start < 0: + * start = 0 + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":834 + * start += shape + * if start < 0: + * start = 0 # <<<<<<<<<<<<<< + * elif start >= shape: + * if negative_step: + */ + __pyx_v_start = 0; + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + } + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + __pyx_t_2 = (__pyx_v_start >= __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + if (__pyx_v_negative_step) { + + /* "View.MemoryView":837 + * elif start >= shape: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = shape + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":839 + * start = shape - 1 + * else: + * start = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + /*else*/ { + __pyx_v_start = __pyx_v_shape; + } + __pyx_L11:; + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + } + __pyx_L9:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + goto __pyx_L8; + } + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":842 + * else: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = 0 + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L12; + } + + /* "View.MemoryView":844 + * start = shape - 1 + * else: + * start = 0 # <<<<<<<<<<<<<< + * + * if have_stop: + */ + /*else*/ { + __pyx_v_start = 0; + } + __pyx_L12:; + } + __pyx_L8:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + __pyx_t_2 = (__pyx_v_have_stop != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":848 + * if have_stop: + * if stop < 0: + * stop += shape # <<<<<<<<<<<<<< + * if stop < 0: + * stop = 0 + */ + __pyx_v_stop = (__pyx_v_stop + __pyx_v_shape); + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":850 + * stop += shape + * if stop < 0: + * stop = 0 # <<<<<<<<<<<<<< + * elif stop > shape: + * stop = shape + */ + __pyx_v_stop = 0; + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + } + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + goto __pyx_L14; + } + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + __pyx_t_2 = (__pyx_v_stop > __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":852 + * stop = 0 + * elif stop > shape: + * stop = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + __pyx_v_stop = __pyx_v_shape; + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + } + __pyx_L14:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + goto __pyx_L13; + } + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":855 + * else: + * if negative_step: + * stop = -1 # <<<<<<<<<<<<<< + * else: + * stop = shape + */ + __pyx_v_stop = -1L; + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + goto __pyx_L16; + } + + /* "View.MemoryView":857 + * stop = -1 + * else: + * stop = shape # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_v_stop = __pyx_v_shape; + } + __pyx_L16:; + } + __pyx_L13:; + + /* "View.MemoryView":861 + * + * with cython.cdivision(True): + * new_shape = (stop - start) // step # <<<<<<<<<<<<<< + * + * if (stop - start) - step * new_shape: + */ + __pyx_v_new_shape = ((__pyx_v_stop - __pyx_v_start) / __pyx_v_step); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + __pyx_t_2 = (((__pyx_v_stop - __pyx_v_start) - (__pyx_v_step * __pyx_v_new_shape)) != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":864 + * + * if (stop - start) - step * new_shape: + * new_shape += 1 # <<<<<<<<<<<<<< + * + * if new_shape < 0: + */ + __pyx_v_new_shape = (__pyx_v_new_shape + 1); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + } + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + __pyx_t_2 = (__pyx_v_new_shape < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":867 + * + * if new_shape < 0: + * new_shape = 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_new_shape = 0; + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + } + + /* "View.MemoryView":870 + * + * + * dst.strides[new_ndim] = stride * step # <<<<<<<<<<<<<< + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset + */ + (__pyx_v_dst->strides[__pyx_v_new_ndim]) = (__pyx_v_stride * __pyx_v_step); + + /* "View.MemoryView":871 + * + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape # <<<<<<<<<<<<<< + * dst.suboffsets[new_ndim] = suboffset + * + */ + (__pyx_v_dst->shape[__pyx_v_new_ndim]) = __pyx_v_new_shape; + + /* "View.MemoryView":872 + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_dst->suboffsets[__pyx_v_new_ndim]) = __pyx_v_suboffset; + } + __pyx_L3:; + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + __pyx_t_2 = ((__pyx_v_suboffset_dim[0]) < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":876 + * + * if suboffset_dim[0] < 0: + * dst.data += start * stride # <<<<<<<<<<<<<< + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride + */ + __pyx_v_dst->data = (__pyx_v_dst->data + (__pyx_v_start * __pyx_v_stride)); + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + goto __pyx_L19; + } + + /* "View.MemoryView":878 + * dst.data += start * stride + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride # <<<<<<<<<<<<<< + * + * if suboffset >= 0: + */ + /*else*/ { + __pyx_t_3 = (__pyx_v_suboffset_dim[0]); + (__pyx_v_dst->suboffsets[__pyx_t_3]) = ((__pyx_v_dst->suboffsets[__pyx_t_3]) + (__pyx_v_start * __pyx_v_stride)); + } + __pyx_L19:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + __pyx_t_2 = (!__pyx_v_is_slice); + if (__pyx_t_2) { + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + __pyx_t_2 = (__pyx_v_new_ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":883 + * if not is_slice: + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset # <<<<<<<<<<<<<< + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + */ + __pyx_v_dst->data = ((((char **)__pyx_v_dst->data)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + goto __pyx_L22; + } + + /* "View.MemoryView":885 + * dst.data = ( dst.data)[0] + suboffset + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " # <<<<<<<<<<<<<< + * "must be indexed and not sliced", dim) + * else: + */ + /*else*/ { + + /* "View.MemoryView":886 + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + * "must be indexed and not sliced", dim) # <<<<<<<<<<<<<< + * else: + * suboffset_dim[0] = new_ndim + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 885, __pyx_L1_error) + } + __pyx_L22:; + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + goto __pyx_L21; + } + + /* "View.MemoryView":888 + * "must be indexed and not sliced", dim) + * else: + * suboffset_dim[0] = new_ndim # <<<<<<<<<<<<<< + * + * return 0 + */ + /*else*/ { + (__pyx_v_suboffset_dim[0]) = __pyx_v_new_ndim; + } + __pyx_L21:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + } + + /* "View.MemoryView":890 + * suboffset_dim[0] = new_ndim + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.slice_memviewslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + +static char *__pyx_pybuffer_index(Py_buffer *__pyx_v_view, char *__pyx_v_bufp, Py_ssize_t __pyx_v_index, Py_ssize_t __pyx_v_dim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_suboffset; + Py_ssize_t __pyx_v_itemsize; + char *__pyx_v_resultp; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_UCS4 __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("pybuffer_index", 1); + + /* "View.MemoryView":898 + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 # <<<<<<<<<<<<<< + * cdef Py_ssize_t itemsize = view.itemsize + * cdef char *resultp + */ + __pyx_v_suboffset = -1L; + + /* "View.MemoryView":899 + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + * cdef Py_ssize_t itemsize = view.itemsize # <<<<<<<<<<<<<< + * cdef char *resultp + * + */ + __pyx_t_1 = __pyx_v_view->itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + __pyx_t_2 = (__pyx_v_view->ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":903 + * + * if view.ndim == 0: + * shape = view.len // itemsize # <<<<<<<<<<<<<< + * stride = itemsize + * else: + */ + if (unlikely(__pyx_v_itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_view->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + __pyx_v_shape = __Pyx_div_Py_ssize_t(__pyx_v_view->len, __pyx_v_itemsize); + + /* "View.MemoryView":904 + * if view.ndim == 0: + * shape = view.len // itemsize + * stride = itemsize # <<<<<<<<<<<<<< + * else: + * shape = view.shape[dim] + */ + __pyx_v_stride = __pyx_v_itemsize; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + goto __pyx_L3; + } + + /* "View.MemoryView":906 + * stride = itemsize + * else: + * shape = view.shape[dim] # <<<<<<<<<<<<<< + * stride = view.strides[dim] + * if view.suboffsets != NULL: + */ + /*else*/ { + __pyx_v_shape = (__pyx_v_view->shape[__pyx_v_dim]); + + /* "View.MemoryView":907 + * else: + * shape = view.shape[dim] + * stride = view.strides[dim] # <<<<<<<<<<<<<< + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] + */ + __pyx_v_stride = (__pyx_v_view->strides[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + __pyx_t_2 = (__pyx_v_view->suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":909 + * stride = view.strides[dim] + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] # <<<<<<<<<<<<<< + * + * if index < 0: + */ + __pyx_v_suboffset = (__pyx_v_view->suboffsets[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + } + } + __pyx_L3:; + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":912 + * + * if index < 0: + * index += view.shape[dim] # <<<<<<<<<<<<<< + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + */ + __pyx_v_index = (__pyx_v_index + (__pyx_v_view->shape[__pyx_v_dim])); + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":914 + * index += view.shape[dim] + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * if index >= shape: + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_5 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_5); + __pyx_t_5 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__7); + __pyx_t_5 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_5, 0, 0); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __PYX_ERR(1, 914, __pyx_L1_error) + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + } + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index >= __pyx_v_shape); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":917 + * + * if index >= shape: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * resultp = bufp + index * stride + */ + __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_3 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_3); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_3); + __pyx_t_3 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u__7); + __pyx_t_3 = __Pyx_PyUnicode_Join(__pyx_t_5, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_3, 0, 0); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __PYX_ERR(1, 917, __pyx_L1_error) + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":919 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * resultp = bufp + index * stride # <<<<<<<<<<<<<< + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset + */ + __pyx_v_resultp = (__pyx_v_bufp + (__pyx_v_index * __pyx_v_stride)); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":921 + * resultp = bufp + index * stride + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset # <<<<<<<<<<<<<< + * + * return resultp + */ + __pyx_v_resultp = ((((char **)__pyx_v_resultp)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + } + + /* "View.MemoryView":923 + * resultp = ( resultp)[0] + suboffset + * + * return resultp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_resultp; + goto __pyx_L0; + + /* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.pybuffer_index", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + +static int __pyx_memslice_transpose(__Pyx_memviewslice *__pyx_v_memslice) { + int __pyx_v_ndim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + int __pyx_v_i; + int __pyx_v_j; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + long __pyx_t_3; + long __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_ssize_t __pyx_t_6; + int __pyx_t_7; + int __pyx_t_8; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":930 + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: + * cdef int ndim = memslice.memview.view.ndim # <<<<<<<<<<<<<< + * + * cdef Py_ssize_t *shape = memslice.shape + */ + __pyx_t_1 = __pyx_v_memslice->memview->view.ndim; + __pyx_v_ndim = __pyx_t_1; + + /* "View.MemoryView":932 + * cdef int ndim = memslice.memview.view.ndim + * + * cdef Py_ssize_t *shape = memslice.shape # <<<<<<<<<<<<<< + * cdef Py_ssize_t *strides = memslice.strides + * + */ + __pyx_t_2 = __pyx_v_memslice->shape; + __pyx_v_shape = __pyx_t_2; + + /* "View.MemoryView":933 + * + * cdef Py_ssize_t *shape = memslice.shape + * cdef Py_ssize_t *strides = memslice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_v_memslice->strides; + __pyx_v_strides = __pyx_t_2; + + /* "View.MemoryView":937 + * + * cdef int i, j + * for i in range(ndim // 2): # <<<<<<<<<<<<<< + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + */ + __pyx_t_3 = __Pyx_div_long(__pyx_v_ndim, 2); + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_1 = 0; __pyx_t_1 < __pyx_t_4; __pyx_t_1+=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":938 + * cdef int i, j + * for i in range(ndim // 2): + * j = ndim - 1 - i # <<<<<<<<<<<<<< + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] + */ + __pyx_v_j = ((__pyx_v_ndim - 1) - __pyx_v_i); + + /* "View.MemoryView":939 + * for i in range(ndim // 2): + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] # <<<<<<<<<<<<<< + * shape[i], shape[j] = shape[j], shape[i] + * + */ + __pyx_t_5 = (__pyx_v_strides[__pyx_v_j]); + __pyx_t_6 = (__pyx_v_strides[__pyx_v_i]); + (__pyx_v_strides[__pyx_v_i]) = __pyx_t_5; + (__pyx_v_strides[__pyx_v_j]) = __pyx_t_6; + + /* "View.MemoryView":940 + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] # <<<<<<<<<<<<<< + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + */ + __pyx_t_6 = (__pyx_v_shape[__pyx_v_j]); + __pyx_t_5 = (__pyx_v_shape[__pyx_v_i]); + (__pyx_v_shape[__pyx_v_i]) = __pyx_t_6; + (__pyx_v_shape[__pyx_v_j]) = __pyx_t_5; + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_i]) >= 0); + if (!__pyx_t_8) { + } else { + __pyx_t_7 = __pyx_t_8; + goto __pyx_L6_bool_binop_done; + } + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_j]) >= 0); + __pyx_t_7 = __pyx_t_8; + __pyx_L6_bool_binop_done:; + if (__pyx_t_7) { + + /* "View.MemoryView":943 + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_t_9 = __pyx_memoryview_err(PyExc_ValueError, __pyx_kp_s_Cannot_transpose_memoryview_with); if (unlikely(__pyx_t_9 == ((int)-1))) __PYX_ERR(1, 943, __pyx_L1_error) + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + } + } + + /* "View.MemoryView":945 + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.transpose_memslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + +/* Python wrapper */ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + + /* "View.MemoryView":964 + * + * def __dealloc__(self): + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __PYX_XCLEAR_MEMVIEW((&__pyx_v_self->from_slice), 1); + + /* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + + /* function exit code */ +} + +/* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_object_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":968 + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) # <<<<<<<<<<<<<< + * else: + * return memoryview.convert_item_to_object(self, itemp) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_v_self->to_object_func(__pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 968, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + } + + /* "View.MemoryView":970 + * return self.to_object_func(itemp) + * else: + * return memoryview.convert_item_to_object(self, itemp) # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_convert_item_to_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 970, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_dtype_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":974 + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) # <<<<<<<<<<<<<< + * else: + * memoryview.assign_item_from_object(self, itemp, value) + */ + __pyx_t_2 = __pyx_v_self->to_dtype_func(__pyx_v_itemp, __pyx_v_value); if (unlikely(__pyx_t_2 == ((int)0))) __PYX_ERR(1, 974, __pyx_L1_error) + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":976 + * self.to_dtype_func(itemp, value) + * else: + * memoryview.assign_item_from_object(self, itemp, value) # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + /*else*/ { + __pyx_t_3 = __pyx_memoryview_assign_item_from_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 976, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":979 + * + * cdef _get_base(self): + * return self.from_object # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->from_object); + __pyx_r = __pyx_v_self->from_object; + goto __pyx_L0; + + /* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryviewslice___reduce_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryviewslice_2__setstate_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice __pyx_v_memviewslice, int __pyx_v_ndim, PyObject *(*__pyx_v_to_object_func)(char *), int (*__pyx_v_to_dtype_func)(char *, PyObject *), int __pyx_v_dtype_is_object) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + Py_ssize_t __pyx_v_suboffset; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + __Pyx_TypeInfo *__pyx_t_4; + Py_buffer __pyx_t_5; + Py_ssize_t *__pyx_t_6; + Py_ssize_t *__pyx_t_7; + Py_ssize_t *__pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_fromslice", 1); + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_1 = (((PyObject *)__pyx_v_memviewslice.memview) == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":1008 + * + * if memviewslice.memview == Py_None: + * return None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + } + + /* "View.MemoryView":1013 + * + * + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) # <<<<<<<<<<<<<< + * + * result.from_slice = memviewslice + */ + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, Py_None)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_0); + __Pyx_GIVEREF(__pyx_int_0); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_0)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_2 = ((PyObject *)__pyx_tp_new__memoryviewslice(((PyTypeObject *)__pyx_memoryviewslice_type), __pyx_t_3, NULL)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1015 + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) + * + * result.from_slice = memviewslice # <<<<<<<<<<<<<< + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + */ + __pyx_v_result->from_slice = __pyx_v_memviewslice; + + /* "View.MemoryView":1016 + * + * result.from_slice = memviewslice + * __PYX_INC_MEMVIEW(&memviewslice, 1) # <<<<<<<<<<<<<< + * + * result.from_object = ( memviewslice.memview)._get_base() + */ + __PYX_INC_MEMVIEW((&__pyx_v_memviewslice), 1); + + /* "View.MemoryView":1018 + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + * result.from_object = ( memviewslice.memview)._get_base() # <<<<<<<<<<<<<< + * result.typeinfo = memviewslice.memview.typeinfo + * + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->__pyx_vtab)->_get_base(((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1018, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_2); + __Pyx_GOTREF(__pyx_v_result->from_object); + __Pyx_DECREF(__pyx_v_result->from_object); + __pyx_v_result->from_object = __pyx_t_2; + __pyx_t_2 = 0; + + /* "View.MemoryView":1019 + * + * result.from_object = ( memviewslice.memview)._get_base() + * result.typeinfo = memviewslice.memview.typeinfo # <<<<<<<<<<<<<< + * + * result.view = memviewslice.memview.view + */ + __pyx_t_4 = __pyx_v_memviewslice.memview->typeinfo; + __pyx_v_result->__pyx_base.typeinfo = __pyx_t_4; + + /* "View.MemoryView":1021 + * result.typeinfo = memviewslice.memview.typeinfo + * + * result.view = memviewslice.memview.view # <<<<<<<<<<<<<< + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + */ + __pyx_t_5 = __pyx_v_memviewslice.memview->view; + __pyx_v_result->__pyx_base.view = __pyx_t_5; + + /* "View.MemoryView":1022 + * + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data # <<<<<<<<<<<<<< + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + */ + __pyx_v_result->__pyx_base.view.buf = ((void *)__pyx_v_memviewslice.data); + + /* "View.MemoryView":1023 + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data + * result.view.ndim = ndim # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_v_result->__pyx_base.view.ndim = __pyx_v_ndim; + + /* "View.MemoryView":1024 + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_result->__pyx_base.view))->obj = Py_None; + + /* "View.MemoryView":1025 + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + __pyx_t_1 = ((((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1028 + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + * result.flags = PyBUF_RECORDS # <<<<<<<<<<<<<< + * else: + * result.flags = PyBUF_RECORDS_RO + */ + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS; + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1030 + * result.flags = PyBUF_RECORDS + * else: + * result.flags = PyBUF_RECORDS_RO # <<<<<<<<<<<<<< + * + * result.view.shape = result.from_slice.shape + */ + /*else*/ { + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS_RO; + } + __pyx_L4:; + + /* "View.MemoryView":1032 + * result.flags = PyBUF_RECORDS_RO + * + * result.view.shape = result.from_slice.shape # <<<<<<<<<<<<<< + * result.view.strides = result.from_slice.strides + * + */ + __pyx_v_result->__pyx_base.view.shape = ((Py_ssize_t *)__pyx_v_result->from_slice.shape); + + /* "View.MemoryView":1033 + * + * result.view.shape = result.from_slice.shape + * result.view.strides = result.from_slice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_result->__pyx_base.view.strides = ((Py_ssize_t *)__pyx_v_result->from_slice.strides); + + /* "View.MemoryView":1036 + * + * + * result.view.suboffsets = NULL # <<<<<<<<<<<<<< + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + */ + __pyx_v_result->__pyx_base.view.suboffsets = NULL; + + /* "View.MemoryView":1037 + * + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + */ + __pyx_t_7 = (__pyx_v_result->from_slice.suboffsets + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->from_slice.suboffsets; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_v_suboffset = (__pyx_t_6[0]); + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + __pyx_t_1 = (__pyx_v_suboffset >= 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1039 + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_result->__pyx_base.view.suboffsets = ((Py_ssize_t *)__pyx_v_result->from_slice.suboffsets); + + /* "View.MemoryView":1040 + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + * break # <<<<<<<<<<<<<< + * + * result.view.len = result.view.itemsize + */ + goto __pyx_L6_break; + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + } + } + __pyx_L6_break:; + + /* "View.MemoryView":1042 + * break + * + * result.view.len = result.view.itemsize # <<<<<<<<<<<<<< + * for length in result.view.shape[:ndim]: + * result.view.len *= length + */ + __pyx_t_9 = __pyx_v_result->__pyx_base.view.itemsize; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + + /* "View.MemoryView":1043 + * + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: # <<<<<<<<<<<<<< + * result.view.len *= length + * + */ + __pyx_t_7 = (__pyx_v_result->__pyx_base.view.shape + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->__pyx_base.view.shape; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_t_2 = PyInt_FromSsize_t((__pyx_t_6[0])); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1043, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1044 + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: + * result.view.len *= length # <<<<<<<<<<<<<< + * + * result.to_object_func = to_object_func + */ + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_result->__pyx_base.view.len); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_InPlaceMultiply(__pyx_t_2, __pyx_v_length); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_t_3); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + } + + /* "View.MemoryView":1046 + * result.view.len *= length + * + * result.to_object_func = to_object_func # <<<<<<<<<<<<<< + * result.to_dtype_func = to_dtype_func + * + */ + __pyx_v_result->to_object_func = __pyx_v_to_object_func; + + /* "View.MemoryView":1047 + * + * result.to_object_func = to_object_func + * result.to_dtype_func = to_dtype_func # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->to_dtype_func = __pyx_v_to_dtype_func; + + /* "View.MemoryView":1049 + * result.to_dtype_func = to_dtype_func + * + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_fromslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_mslice) { + struct __pyx_memoryviewslice_obj *__pyx_v_obj = 0; + __Pyx_memviewslice *__pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_slice_from_memview", 1); + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1056 + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): + * obj = memview # <<<<<<<<<<<<<< + * return &obj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 1056, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_obj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1057 + * if isinstance(memview, _memoryviewslice): + * obj = memview + * return &obj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, mslice) + */ + __pyx_r = (&__pyx_v_obj->from_slice); + goto __pyx_L0; + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + } + + /* "View.MemoryView":1059 + * return &obj.from_slice + * else: + * slice_copy(memview, mslice) # <<<<<<<<<<<<<< + * return mslice + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, __pyx_v_mslice); + + /* "View.MemoryView":1060 + * else: + * slice_copy(memview, mslice) + * return mslice # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_slice_copy') + */ + __pyx_r = __pyx_v_mslice; + goto __pyx_L0; + } + + /* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.get_slice_from_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_obj); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_dst) { + int __pyx_v_dim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + Py_ssize_t *__pyx_v_suboffsets; + Py_ssize_t *__pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + Py_ssize_t __pyx_t_5; + int __pyx_t_6; + + /* "View.MemoryView":1067 + * cdef (Py_ssize_t*) shape, strides, suboffsets + * + * shape = memview.view.shape # <<<<<<<<<<<<<< + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets + */ + __pyx_t_1 = __pyx_v_memview->view.shape; + __pyx_v_shape = __pyx_t_1; + + /* "View.MemoryView":1068 + * + * shape = memview.view.shape + * strides = memview.view.strides # <<<<<<<<<<<<<< + * suboffsets = memview.view.suboffsets + * + */ + __pyx_t_1 = __pyx_v_memview->view.strides; + __pyx_v_strides = __pyx_t_1; + + /* "View.MemoryView":1069 + * shape = memview.view.shape + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets # <<<<<<<<<<<<<< + * + * dst.memview = <__pyx_memoryview *> memview + */ + __pyx_t_1 = __pyx_v_memview->view.suboffsets; + __pyx_v_suboffsets = __pyx_t_1; + + /* "View.MemoryView":1071 + * suboffsets = memview.view.suboffsets + * + * dst.memview = <__pyx_memoryview *> memview # <<<<<<<<<<<<<< + * dst.data = memview.view.buf + * + */ + __pyx_v_dst->memview = ((struct __pyx_memoryview_obj *)__pyx_v_memview); + + /* "View.MemoryView":1072 + * + * dst.memview = <__pyx_memoryview *> memview + * dst.data = memview.view.buf # <<<<<<<<<<<<<< + * + * for dim in range(memview.view.ndim): + */ + __pyx_v_dst->data = ((char *)__pyx_v_memview->view.buf); + + /* "View.MemoryView":1074 + * dst.data = memview.view.buf + * + * for dim in range(memview.view.ndim): # <<<<<<<<<<<<<< + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + */ + __pyx_t_2 = __pyx_v_memview->view.ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_dim = __pyx_t_4; + + /* "View.MemoryView":1075 + * + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] # <<<<<<<<<<<<<< + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + */ + (__pyx_v_dst->shape[__pyx_v_dim]) = (__pyx_v_shape[__pyx_v_dim]); + + /* "View.MemoryView":1076 + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] # <<<<<<<<<<<<<< + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + * + */ + (__pyx_v_dst->strides[__pyx_v_dim]) = (__pyx_v_strides[__pyx_v_dim]); + + /* "View.MemoryView":1077 + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object') + */ + __pyx_t_6 = (__pyx_v_suboffsets != 0); + if (__pyx_t_6) { + __pyx_t_5 = (__pyx_v_suboffsets[__pyx_v_dim]); + } else { + __pyx_t_5 = -1L; + } + (__pyx_v_dst->suboffsets[__pyx_v_dim]) = __pyx_t_5; + } + + /* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + + /* function exit code */ +} + +/* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *__pyx_v_memview) { + __Pyx_memviewslice __pyx_v_memviewslice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy", 1); + + /* "View.MemoryView":1083 + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) # <<<<<<<<<<<<<< + * return memoryview_copy_from_slice(memview, &memviewslice) + * + */ + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_memviewslice)); + + /* "View.MemoryView":1084 + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) + * return memoryview_copy_from_slice(memview, &memviewslice) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object_from_slice') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_memoryview_copy_object_from_slice(__pyx_v_memview, (&__pyx_v_memviewslice)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1084, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_memviewslice) { + PyObject *(*__pyx_v_to_object_func)(char *); + int (*__pyx_v_to_dtype_func)(char *, PyObject *); + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *(*__pyx_t_2)(char *); + int (*__pyx_t_3)(char *, PyObject *); + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy_from_slice", 1); + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1095 + * + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func # <<<<<<<<<<<<<< + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + */ + __pyx_t_2 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_object_func; + __pyx_v_to_object_func = __pyx_t_2; + + /* "View.MemoryView":1096 + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func # <<<<<<<<<<<<<< + * else: + * to_object_func = NULL + */ + __pyx_t_3 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_dtype_func; + __pyx_v_to_dtype_func = __pyx_t_3; + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1098 + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + * to_object_func = NULL # <<<<<<<<<<<<<< + * to_dtype_func = NULL + * + */ + /*else*/ { + __pyx_v_to_object_func = NULL; + + /* "View.MemoryView":1099 + * else: + * to_object_func = NULL + * to_dtype_func = NULL # <<<<<<<<<<<<<< + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + */ + __pyx_v_to_dtype_func = NULL; + } + __pyx_L3:; + + /* "View.MemoryView":1101 + * to_dtype_func = NULL + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, # <<<<<<<<<<<<<< + * to_object_func, to_dtype_func, + * memview.dtype_is_object) + */ + __Pyx_XDECREF(__pyx_r); + + /* "View.MemoryView":1103 + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + * to_object_func, to_dtype_func, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_memoryview_fromslice((__pyx_v_memviewslice[0]), __pyx_v_memview->view.ndim, __pyx_v_to_object_func, __pyx_v_to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_from_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + +static Py_ssize_t abs_py_ssize_t(Py_ssize_t __pyx_v_arg) { + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":1110 + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: + * return -arg if arg < 0 else arg # <<<<<<<<<<<<<< + * + * @cname('__pyx_get_best_slice_order') + */ + __pyx_t_2 = (__pyx_v_arg < 0); + if (__pyx_t_2) { + __pyx_t_1 = (-__pyx_v_arg); + } else { + __pyx_t_1 = __pyx_v_arg; + } + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + +static char __pyx_get_best_slice_order(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim) { + int __pyx_v_i; + Py_ssize_t __pyx_v_c_stride; + Py_ssize_t __pyx_v_f_stride; + char __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1118 + * """ + * cdef int i + * cdef Py_ssize_t c_stride = 0 # <<<<<<<<<<<<<< + * cdef Py_ssize_t f_stride = 0 + * + */ + __pyx_v_c_stride = 0; + + /* "View.MemoryView":1119 + * cdef int i + * cdef Py_ssize_t c_stride = 0 + * cdef Py_ssize_t f_stride = 0 # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_f_stride = 0; + + /* "View.MemoryView":1121 + * cdef Py_ssize_t f_stride = 0 + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1123 + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_c_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1124 + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + goto __pyx_L4_break; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L4_break:; + + /* "View.MemoryView":1126 + * break + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + */ + __pyx_t_1 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_1; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1128 + * for i in range(ndim): + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_f_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1129 + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + */ + goto __pyx_L7_break; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L7_break:; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + __pyx_t_2 = (abs_py_ssize_t(__pyx_v_c_stride) <= abs_py_ssize_t(__pyx_v_f_stride)); + if (__pyx_t_2) { + + /* "View.MemoryView":1132 + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + * return 'C' # <<<<<<<<<<<<<< + * else: + * return 'F' + */ + __pyx_r = 'C'; + goto __pyx_L0; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + } + + /* "View.MemoryView":1134 + * return 'C' + * else: + * return 'F' # <<<<<<<<<<<<<< + * + * @cython.cdivision(True) + */ + /*else*/ { + __pyx_r = 'F'; + goto __pyx_L0; + } + + /* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + +static void _copy_strided_to_strided(char *__pyx_v_src_data, Py_ssize_t *__pyx_v_src_strides, char *__pyx_v_dst_data, Py_ssize_t *__pyx_v_dst_strides, Py_ssize_t *__pyx_v_src_shape, Py_ssize_t *__pyx_v_dst_shape, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + CYTHON_UNUSED Py_ssize_t __pyx_v_src_extent; + Py_ssize_t __pyx_v_dst_extent; + Py_ssize_t __pyx_v_src_stride; + Py_ssize_t __pyx_v_dst_stride; + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + + /* "View.MemoryView":1144 + * + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + */ + __pyx_v_src_extent = (__pyx_v_src_shape[0]); + + /* "View.MemoryView":1145 + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] + */ + __pyx_v_dst_extent = (__pyx_v_dst_shape[0]); + + /* "View.MemoryView":1146 + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + */ + __pyx_v_src_stride = (__pyx_v_src_strides[0]); + + /* "View.MemoryView":1147 + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_dst_stride = (__pyx_v_dst_strides[0]); + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + __pyx_t_2 = (__pyx_v_src_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_dst_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + + /* "View.MemoryView":1151 + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + */ + __pyx_t_2 = (((size_t)__pyx_v_src_stride) == __pyx_v_itemsize); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_itemsize == ((size_t)__pyx_v_dst_stride)); + } + __pyx_t_1 = __pyx_t_2; + __pyx_L5_bool_binop_done:; + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + if (__pyx_t_1) { + + /* "View.MemoryView":1152 + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, (__pyx_v_itemsize * __pyx_v_dst_extent))); + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1154 + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1155 + * else: + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) # <<<<<<<<<<<<<< + * src_data += src_stride + * dst_data += dst_stride + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, __pyx_v_itemsize)); + + /* "View.MemoryView":1156 + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * else: + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1157 + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L4:; + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1159 + * dst_data += dst_stride + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * _copy_strided_to_strided(src_data, src_strides + 1, + * dst_data, dst_strides + 1, + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1160 + * else: + * for i in range(dst_extent): + * _copy_strided_to_strided(src_data, src_strides + 1, # <<<<<<<<<<<<<< + * dst_data, dst_strides + 1, + * src_shape + 1, dst_shape + 1, + */ + _copy_strided_to_strided(__pyx_v_src_data, (__pyx_v_src_strides + 1), __pyx_v_dst_data, (__pyx_v_dst_strides + 1), (__pyx_v_src_shape + 1), (__pyx_v_dst_shape + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize); + + /* "View.MemoryView":1164 + * src_shape + 1, dst_shape + 1, + * ndim - 1, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1165 + * ndim - 1, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + + /* function exit code */ +} + +/* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + +static void copy_strided_to_strided(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + + /* "View.MemoryView":1170 + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + * _copy_strided_to_strided(src.data, src.strides, dst.data, dst.strides, # <<<<<<<<<<<<<< + * src.shape, dst.shape, ndim, itemsize) + * + */ + _copy_strided_to_strided(__pyx_v_src->data, __pyx_v_src->strides, __pyx_v_dst->data, __pyx_v_dst->strides, __pyx_v_src->shape, __pyx_v_dst->shape, __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *__pyx_v_src, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_size; + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + + /* "View.MemoryView":1176 + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize # <<<<<<<<<<<<<< + * + * for shape in src.shape[:ndim]: + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_size = __pyx_t_1; + + /* "View.MemoryView":1178 + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + * + * for shape in src.shape[:ndim]: # <<<<<<<<<<<<<< + * size *= shape + * + */ + __pyx_t_3 = (__pyx_v_src->shape + __pyx_v_ndim); + for (__pyx_t_4 = __pyx_v_src->shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_v_shape = (__pyx_t_2[0]); + + /* "View.MemoryView":1179 + * + * for shape in src.shape[:ndim]: + * size *= shape # <<<<<<<<<<<<<< + * + * return size + */ + __pyx_v_size = (__pyx_v_size * __pyx_v_shape); + } + + /* "View.MemoryView":1181 + * size *= shape + * + * return size # <<<<<<<<<<<<<< + * + * @cname('__pyx_fill_contig_strides_array') + */ + __pyx_r = __pyx_v_size; + goto __pyx_L0; + + /* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, Py_ssize_t __pyx_v_stride, int __pyx_v_ndim, char __pyx_v_order) { + int __pyx_v_idx; + Py_ssize_t __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + __pyx_t_1 = (__pyx_v_order == 'F'); + if (__pyx_t_1) { + + /* "View.MemoryView":1194 + * + * if order == 'F': + * for idx in range(ndim): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + __pyx_t_2 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_idx = __pyx_t_4; + + /* "View.MemoryView":1195 + * if order == 'F': + * for idx in range(ndim): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * else: + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1196 + * for idx in range(ndim): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * else: + * for idx in range(ndim - 1, -1, -1): + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1198 + * stride *= shape[idx] + * else: + * for idx in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + /*else*/ { + for (__pyx_t_2 = (__pyx_v_ndim - 1); __pyx_t_2 > -1; __pyx_t_2-=1) { + __pyx_v_idx = __pyx_t_2; + + /* "View.MemoryView":1199 + * else: + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1200 + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * + * return stride + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + } + __pyx_L3:; + + /* "View.MemoryView":1202 + * stride *= shape[idx] + * + * return stride # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_data_to_temp') + */ + __pyx_r = __pyx_v_stride; + goto __pyx_L0; + + /* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_tmpslice, char __pyx_v_order, int __pyx_v_ndim) { + int __pyx_v_i; + void *__pyx_v_result; + size_t __pyx_v_itemsize; + size_t __pyx_v_size; + void *__pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + struct __pyx_memoryview_obj *__pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1216 + * cdef void *result + * + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef size_t size = slice_get_size(src, ndim) + * + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1217 + * + * cdef size_t itemsize = src.memview.view.itemsize + * cdef size_t size = slice_get_size(src, ndim) # <<<<<<<<<<<<<< + * + * result = malloc(size) + */ + __pyx_v_size = __pyx_memoryview_slice_get_size(__pyx_v_src, __pyx_v_ndim); + + /* "View.MemoryView":1219 + * cdef size_t size = slice_get_size(src, ndim) + * + * result = malloc(size) # <<<<<<<<<<<<<< + * if not result: + * _err_no_memory() + */ + __pyx_v_result = malloc(__pyx_v_size); + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + __pyx_t_2 = (!(__pyx_v_result != 0)); + if (__pyx_t_2) { + + /* "View.MemoryView":1221 + * result = malloc(size) + * if not result: + * _err_no_memory() # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_3 = __pyx_memoryview_err_no_memory(); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 1221, __pyx_L1_error) + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + } + + /* "View.MemoryView":1224 + * + * + * tmpslice.data = result # <<<<<<<<<<<<<< + * tmpslice.memview = src.memview + * for i in range(ndim): + */ + __pyx_v_tmpslice->data = ((char *)__pyx_v_result); + + /* "View.MemoryView":1225 + * + * tmpslice.data = result + * tmpslice.memview = src.memview # <<<<<<<<<<<<<< + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + */ + __pyx_t_4 = __pyx_v_src->memview; + __pyx_v_tmpslice->memview = __pyx_t_4; + + /* "View.MemoryView":1226 + * tmpslice.data = result + * tmpslice.memview = src.memview + * for i in range(ndim): # <<<<<<<<<<<<<< + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1227 + * tmpslice.memview = src.memview + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] # <<<<<<<<<<<<<< + * tmpslice.suboffsets[i] = -1 + * + */ + (__pyx_v_tmpslice->shape[__pyx_v_i]) = (__pyx_v_src->shape[__pyx_v_i]); + + /* "View.MemoryView":1228 + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) + */ + (__pyx_v_tmpslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1230 + * tmpslice.suboffsets[i] = -1 + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) # <<<<<<<<<<<<<< + * + * + */ + (void)(__pyx_fill_contig_strides_array((&(__pyx_v_tmpslice->shape[0])), (&(__pyx_v_tmpslice->strides[0])), __pyx_v_itemsize, __pyx_v_ndim, __pyx_v_order)); + + /* "View.MemoryView":1233 + * + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + __pyx_t_2 = ((__pyx_v_tmpslice->shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1235 + * for i in range(ndim): + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 # <<<<<<<<<<<<<< + * + * if slice_is_contig(src[0], order, ndim): + */ + (__pyx_v_tmpslice->strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + } + } + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + __pyx_t_2 = __pyx_memviewslice_is_contig((__pyx_v_src[0]), __pyx_v_order, __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1238 + * + * if slice_is_contig(src[0], order, ndim): + * memcpy(result, src.data, size) # <<<<<<<<<<<<<< + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + */ + (void)(memcpy(__pyx_v_result, __pyx_v_src->data, __pyx_v_size)); + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":1240 + * memcpy(result, src.data, size) + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) # <<<<<<<<<<<<<< + * + * return result + */ + /*else*/ { + copy_strided_to_strided(__pyx_v_src, __pyx_v_tmpslice, __pyx_v_ndim, __pyx_v_itemsize); + } + __pyx_L9:; + + /* "View.MemoryView":1242 + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.copy_data_to_temp", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + +static int __pyx_memoryview_err_extents(int __pyx_v_i, Py_ssize_t __pyx_v_extent1, Py_ssize_t __pyx_v_extent2) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t __pyx_t_2; + Py_UCS4 __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_extents", 0); + + /* "View.MemoryView":1249 + * cdef int _err_extents(int i, Py_ssize_t extent1, + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_dim') + */ + __pyx_t_1 = PyTuple_New(7); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = 0; + __pyx_t_3 = 127; + __Pyx_INCREF(__pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_2 += 35; + __Pyx_GIVEREF(__pyx_kp_u_got_differing_extents_in_dimensi); + PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_4 = __Pyx_PyUnicode_From_int(__pyx_v_i, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_got); + __pyx_t_2 += 6; + __Pyx_GIVEREF(__pyx_kp_u_got); + PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_kp_u_got); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent1, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_and); + __pyx_t_2 += 5; + __Pyx_GIVEREF(__pyx_kp_u_and); + PyTuple_SET_ITEM(__pyx_t_1, 4, __pyx_kp_u_and); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent2, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 5, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_2 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_1, 6, __pyx_kp_u__7); + __pyx_t_4 = __Pyx_PyUnicode_Join(__pyx_t_1, 7, __pyx_t_2, __pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_4, 0, 0); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __PYX_ERR(1, 1249, __pyx_L1_error) + + /* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView._err_extents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + +static int __pyx_memoryview_err_dim(PyObject *__pyx_v_error, PyObject *__pyx_v_msg, int __pyx_v_dim) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_dim", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1253 + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: + * raise error, msg % dim # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err') + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_dim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyString_FormatSafe(__pyx_v_msg, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_t_2, 0, 0); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __PYX_ERR(1, 1253, __pyx_L1_error) + + /* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._err_dim", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + +static int __pyx_memoryview_err(PyObject *__pyx_v_error, PyObject *__pyx_v_msg) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1257 + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: + * raise error, msg # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_no_memory') + */ + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_v_msg, 0, 0); + __PYX_ERR(1, 1257, __pyx_L1_error) + + /* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + +static int __pyx_memoryview_err_no_memory(void) { + int __pyx_r; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1261 + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: + * raise MemoryError # <<<<<<<<<<<<<< + * + * + */ + PyErr_NoMemory(); __PYX_ERR(1, 1261, __pyx_L1_error) + + /* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err_no_memory", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice __pyx_v_src, __Pyx_memviewslice __pyx_v_dst, int __pyx_v_src_ndim, int __pyx_v_dst_ndim, int __pyx_v_dtype_is_object) { + void *__pyx_v_tmpdata; + size_t __pyx_v_itemsize; + int __pyx_v_i; + char __pyx_v_order; + int __pyx_v_broadcasting; + int __pyx_v_direct_copy; + __Pyx_memviewslice __pyx_v_tmp; + int __pyx_v_ndim; + int __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + void *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1273 + * Check for overlapping memory and verify the shapes. + * """ + * cdef void *tmpdata = NULL # <<<<<<<<<<<<<< + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + */ + __pyx_v_tmpdata = NULL; + + /* "View.MemoryView":1274 + * """ + * cdef void *tmpdata = NULL + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + */ + __pyx_t_1 = __pyx_v_src.memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1276 + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) # <<<<<<<<<<<<<< + * cdef bint broadcasting = False + * cdef bint direct_copy = False + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_src), __pyx_v_src_ndim); + + /* "View.MemoryView":1277 + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False # <<<<<<<<<<<<<< + * cdef bint direct_copy = False + * cdef __Pyx_memviewslice tmp + */ + __pyx_v_broadcasting = 0; + + /* "View.MemoryView":1278 + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False + * cdef bint direct_copy = False # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice tmp + * + */ + __pyx_v_direct_copy = 0; + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + __pyx_t_2 = (__pyx_v_src_ndim < __pyx_v_dst_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1282 + * + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_src), __pyx_v_src_ndim, __pyx_v_dst_ndim); + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + __pyx_t_2 = (__pyx_v_dst_ndim < __pyx_v_src_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1284 + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) # <<<<<<<<<<<<<< + * + * cdef int ndim = max(src_ndim, dst_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_dst), __pyx_v_dst_ndim, __pyx_v_src_ndim); + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + } + __pyx_L3:; + + /* "View.MemoryView":1286 + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + * cdef int ndim = max(src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + __pyx_t_3 = __pyx_v_dst_ndim; + __pyx_t_4 = __pyx_v_src_ndim; + __pyx_t_2 = (__pyx_t_3 > __pyx_t_4); + if (__pyx_t_2) { + __pyx_t_5 = __pyx_t_3; + } else { + __pyx_t_5 = __pyx_t_4; + } + __pyx_v_ndim = __pyx_t_5; + + /* "View.MemoryView":1288 + * cdef int ndim = max(src_ndim, dst_ndim) + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + */ + __pyx_t_5 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_5; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) != (__pyx_v_dst.shape[__pyx_v_i])); + if (__pyx_t_2) { + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1291 + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + * broadcasting = True # <<<<<<<<<<<<<< + * src.strides[i] = 0 + * else: + */ + __pyx_v_broadcasting = 1; + + /* "View.MemoryView":1292 + * if src.shape[i] == 1: + * broadcasting = True + * src.strides[i] = 0 # <<<<<<<<<<<<<< + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) + */ + (__pyx_v_src.strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + goto __pyx_L7; + } + + /* "View.MemoryView":1294 + * src.strides[i] = 0 + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) # <<<<<<<<<<<<<< + * + * if src.suboffsets[i] >= 0: + */ + /*else*/ { + __pyx_t_6 = __pyx_memoryview_err_extents(__pyx_v_i, (__pyx_v_dst.shape[__pyx_v_i]), (__pyx_v_src.shape[__pyx_v_i])); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1294, __pyx_L1_error) + } + __pyx_L7:; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + } + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + __pyx_t_2 = ((__pyx_v_src.suboffsets[__pyx_v_i]) >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":1297 + * + * if src.suboffsets[i] >= 0: + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) # <<<<<<<<<<<<<< + * + * if slices_overlap(&src, &dst, ndim, itemsize): + */ + __pyx_t_6 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Dimension_d_is_not_direct, __pyx_v_i); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1297, __pyx_L1_error) + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + } + } + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + __pyx_t_2 = __pyx_slices_overlap((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + if (__pyx_t_2) { + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + __pyx_t_2 = (!__pyx_memviewslice_is_contig(__pyx_v_src, __pyx_v_order, __pyx_v_ndim)); + if (__pyx_t_2) { + + /* "View.MemoryView":1302 + * + * if not slice_is_contig(src, order, ndim): + * order = get_best_order(&dst, ndim) # <<<<<<<<<<<<<< + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim); + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + } + + /* "View.MemoryView":1304 + * order = get_best_order(&dst, ndim) + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) # <<<<<<<<<<<<<< + * src = tmp + * + */ + __pyx_t_7 = __pyx_memoryview_copy_data_to_temp((&__pyx_v_src), (&__pyx_v_tmp), __pyx_v_order, __pyx_v_ndim); if (unlikely(__pyx_t_7 == ((void *)NULL))) __PYX_ERR(1, 1304, __pyx_L1_error) + __pyx_v_tmpdata = __pyx_t_7; + + /* "View.MemoryView":1305 + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + * src = tmp # <<<<<<<<<<<<<< + * + * if not broadcasting: + */ + __pyx_v_src = __pyx_v_tmp; + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (!__pyx_v_broadcasting); + if (__pyx_t_2) { + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'C', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1311 + * + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) # <<<<<<<<<<<<<< + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'C', __pyx_v_ndim); + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + goto __pyx_L12; + } + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'F', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1313 + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) # <<<<<<<<<<<<<< + * + * if direct_copy: + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'F', __pyx_v_ndim); + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + } + __pyx_L12:; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + if (__pyx_v_direct_copy) { + + /* "View.MemoryView":1317 + * if direct_copy: + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1318 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + */ + (void)(memcpy(__pyx_v_dst.data, __pyx_v_src.data, __pyx_memoryview_slice_get_size((&__pyx_v_src), __pyx_v_ndim))); + + /* "View.MemoryView":1319 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * free(tmpdata) + * return 0 + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1320 + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1321 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * if order == 'F' == get_best_order(&dst, ndim): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (__pyx_v_order == 'F'); + if (__pyx_t_2) { + __pyx_t_2 = ('F' == __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim)); + } + if (__pyx_t_2) { + + /* "View.MemoryView":1326 + * + * + * transpose_memslice(&src) # <<<<<<<<<<<<<< + * transpose_memslice(&dst) + * + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_src)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1326, __pyx_L1_error) + + /* "View.MemoryView":1327 + * + * transpose_memslice(&src) + * transpose_memslice(&dst) # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_dst)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1327, __pyx_L1_error) + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1329 + * transpose_memslice(&dst) + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1330 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + */ + copy_strided_to_strided((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1331 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * free(tmpdata) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1333 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1334 + * + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_broadcast_leading') + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_contents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim, int __pyx_v_ndim_other) { + int __pyx_v_i; + int __pyx_v_offset; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + + /* "View.MemoryView":1341 + * int ndim_other) noexcept nogil: + * cdef int i + * cdef int offset = ndim_other - ndim # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_offset = (__pyx_v_ndim_other - __pyx_v_ndim); + + /* "View.MemoryView":1343 + * cdef int offset = ndim_other - ndim + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1344 + * + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] # <<<<<<<<<<<<<< + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + */ + (__pyx_v_mslice->shape[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->shape[__pyx_v_i]); + + /* "View.MemoryView":1345 + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] # <<<<<<<<<<<<<< + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + */ + (__pyx_v_mslice->strides[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1346 + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] # <<<<<<<<<<<<<< + * + * for i in range(offset): + */ + (__pyx_v_mslice->suboffsets[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->suboffsets[__pyx_v_i]); + } + + /* "View.MemoryView":1348 + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + * for i in range(offset): # <<<<<<<<<<<<<< + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + */ + __pyx_t_1 = __pyx_v_offset; + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1349 + * + * for i in range(offset): + * mslice.shape[i] = 1 # <<<<<<<<<<<<<< + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 + */ + (__pyx_v_mslice->shape[__pyx_v_i]) = 1; + + /* "View.MemoryView":1350 + * for i in range(offset): + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] # <<<<<<<<<<<<<< + * mslice.suboffsets[i] = -1 + * + */ + (__pyx_v_mslice->strides[__pyx_v_i]) = (__pyx_v_mslice->strides[0]); + + /* "View.MemoryView":1351 + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_mslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_dtype_is_object, int __pyx_v_ndim, int __pyx_v_inc) { + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + if (__pyx_v_dtype_is_object) { + + /* "View.MemoryView":1362 + * + * if dtype_is_object: + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + */ + __pyx_memoryview_refcount_objects_in_slice_with_gil(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + } + + /* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1368 + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + * refcount_objects_in_slice(data, shape, strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, __pyx_v_shape, __pyx_v_strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + + /* function exit code */ + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif +} + +/* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + +static void __pyx_memoryview_refcount_objects_in_slice(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1374 + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * + * for i in range(shape[0]): + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1376 + * cdef Py_ssize_t stride = strides[0] + * + * for i in range(shape[0]): # <<<<<<<<<<<<<< + * if ndim == 1: + * if inc: + */ + __pyx_t_1 = (__pyx_v_shape[0]); + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + __pyx_t_4 = (__pyx_v_ndim == 1); + if (__pyx_t_4) { + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + if (__pyx_v_inc) { + + /* "View.MemoryView":1379 + * if ndim == 1: + * if inc: + * Py_INCREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * Py_DECREF(( data)[0]) + */ + Py_INCREF((((PyObject **)__pyx_v_data)[0])); + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":1381 + * Py_INCREF(( data)[0]) + * else: + * Py_DECREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + */ + /*else*/ { + Py_DECREF((((PyObject **)__pyx_v_data)[0])); + } + __pyx_L6:; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":1383 + * Py_DECREF(( data)[0]) + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) # <<<<<<<<<<<<<< + * + * data += stride + */ + /*else*/ { + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_inc); + } + __pyx_L5:; + + /* "View.MemoryView":1385 + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + * + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + + /* function exit code */ +} + +/* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item, int __pyx_v_dtype_is_object) { + + /* "View.MemoryView":1394 + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1395 + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) # <<<<<<<<<<<<<< + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1396 + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + +static void __pyx_memoryview__slice_assign_scalar(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_extent; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + + /* "View.MemoryView":1404 + * size_t itemsize, void *item) noexcept nogil: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t extent = shape[0] + * + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1405 + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] + * cdef Py_ssize_t extent = shape[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_extent = (__pyx_v_shape[0]); + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1408 + * + * if ndim == 1: + * for i in range(extent): # <<<<<<<<<<<<<< + * memcpy(data, item, itemsize) + * data += stride + */ + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1409 + * if ndim == 1: + * for i in range(extent): + * memcpy(data, item, itemsize) # <<<<<<<<<<<<<< + * data += stride + * else: + */ + (void)(memcpy(__pyx_v_data, __pyx_v_item, __pyx_v_itemsize)); + + /* "View.MemoryView":1410 + * for i in range(extent): + * memcpy(data, item, itemsize) + * data += stride # <<<<<<<<<<<<<< + * else: + * for i in range(extent): + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1412 + * data += stride + * else: + * for i in range(extent): # <<<<<<<<<<<<<< + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride + */ + /*else*/ { + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1413 + * else: + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) # <<<<<<<<<<<<<< + * data += stride + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1414 + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + + /* function exit code */ +} + +/* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum = {"__pyx_unpickle_Enum", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_type = 0; + long __pyx_v___pyx_checksum; + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_type)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_checksum)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__pyx_unpickle_Enum") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 3)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + } + __pyx_v___pyx_type = values[0]; + __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_v___pyx_state = values[2]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, __pyx_nargs); __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_v___pyx_PickleError = 0; + PyObject *__pyx_v___pyx_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum", 1); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_t_1 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = (__Pyx_PySequence_ContainsTF(__pyx_t_1, __pyx_tuple__8, Py_NE)); if (unlikely((__pyx_t_2 < 0))) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (__pyx_t_2) { + + /* "(tree fragment)":5 + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + */ + __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_n_s_PickleError); + __Pyx_GIVEREF(__pyx_n_s_PickleError); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_PickleError)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_1, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_t_1); + __pyx_v___pyx_PickleError = __pyx_t_1; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "(tree fragment)":6 + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum # <<<<<<<<<<<<<< + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + */ + __pyx_t_3 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_v___pyx_PickleError, __pyx_t_1, 0, 0); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __PYX_ERR(1, 6, __pyx_L1_error) + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + } + + /* "(tree fragment)":7 + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) # <<<<<<<<<<<<<< + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + */ + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_MemviewEnum_type), __pyx_n_s_new); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = NULL; + __pyx_t_5 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_5 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_4, __pyx_v___pyx_type}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_5, 1+__pyx_t_5); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_v___pyx_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + __pyx_t_2 = (__pyx_v___pyx_state != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":9 + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 9, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(((struct __pyx_MemviewEnum_obj *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + } + + /* "(tree fragment)":10 + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result # <<<<<<<<<<<<<< + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v___pyx_result); + __pyx_r = __pyx_v___pyx_result; + goto __pyx_L0; + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v___pyx_PickleError); + __Pyx_XDECREF(__pyx_v___pyx_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + int __pyx_t_8; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum__set_state", 1); + + /* "(tree fragment)":12 + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] # <<<<<<<<<<<<<< + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_GIVEREF(__pyx_t_1); + __Pyx_GOTREF(__pyx_v___pyx_result->name); + __Pyx_DECREF(__pyx_v___pyx_result->name); + __pyx_v___pyx_result->name = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 13, __pyx_L1_error) + } + __pyx_t_3 = __Pyx_PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_4 = (__pyx_t_3 > 1); + if (__pyx_t_4) { + } else { + __pyx_t_2 = __pyx_t_4; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_4 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_2 = __pyx_t_4; + __pyx_L4_bool_binop_done:; + if (__pyx_t_2) { + + /* "(tree fragment)":14 + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) # <<<<<<<<<<<<<< + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_update); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 14, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_6))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_6); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_6, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_7, __pyx_t_5}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_6, __pyx_callargs+1-__pyx_t_8, 1+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + } + + /* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245 + * + * @property + * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self) { + PyObject *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":248 + * """Returns a borrowed reference to the object owning the data/memory. + * """ + * return PyArray_BASE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_BASE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":245 + * + * @property + * cdef inline PyObject* base(self) nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self) { + PyArray_Descr *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyArray_Descr *__pyx_t_1; + __Pyx_RefNannySetupContext("descr", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":254 + * """Returns an owned reference to the dtype of the array. + * """ + * return PyArray_DESCR(self) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __pyx_t_1 = PyArray_DESCR(__pyx_v_self); + __Pyx_INCREF((PyObject *)((PyArray_Descr *)__pyx_t_1)); + __pyx_r = ((PyArray_Descr *)__pyx_t_1); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":251 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257 + * + * @property + * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":260 + * """Returns the number of dimensions in the array. + * """ + * return PyArray_NDIM(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_NDIM(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":257 + * + * @property + * cdef inline int ndim(self) nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263 + * + * @property + * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":268 + * Can return NULL for 0-dimensional arrays. + * """ + * return PyArray_DIMS(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_DIMS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":263 + * + * @property + * cdef inline npy_intp *shape(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271 + * + * @property + * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":275 + * The number of elements matches the number of dimensions of the array (ndim). + * """ + * return PyArray_STRIDES(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_STRIDES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":271 + * + * @property + * cdef inline npy_intp *strides(self) nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278 + * + * @property + * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":281 + * """Returns the total size (in number of elements) of the array. + * """ + * return PyArray_SIZE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_SIZE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":278 + * + * @property + * cdef inline npy_intp size(self) nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284 + * + * @property + * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self) { + char *__pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":290 + * of `PyArray_DATA()` instead, which returns a 'void*'. + * """ + * return PyArray_BYTES(self) # <<<<<<<<<<<<<< + * + * ctypedef unsigned char npy_bool + */ + __pyx_r = PyArray_BYTES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":284 + * + * @property + * cdef inline char* data(self) nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":773 + * ctypedef npy_cdouble complex_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew1(PyObject *__pyx_v_a) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew1", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":774 + * + * cdef inline object PyArray_MultiIterNew1(a): + * return PyArray_MultiIterNew(1, a) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew2(a, b): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(1, ((void *)__pyx_v_a)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 774, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":773 + * ctypedef npy_cdouble complex_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew1", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew2(PyObject *__pyx_v_a, PyObject *__pyx_v_b) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew2", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":777 + * + * cdef inline object PyArray_MultiIterNew2(a, b): + * return PyArray_MultiIterNew(2, a, b) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(2, ((void *)__pyx_v_a), ((void *)__pyx_v_b)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 777, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":776 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew2", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew3(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew3", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":780 + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + * return PyArray_MultiIterNew(3, a, b, c) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(3, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 780, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":779 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew3", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew4(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew4", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":783 + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + * return PyArray_MultiIterNew(4, a, b, c, d) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(4, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 783, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":782 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew4", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew5(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d, PyObject *__pyx_v_e) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew5", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":786 + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + * return PyArray_MultiIterNew(5, a, b, c, d, e) # <<<<<<<<<<<<<< + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(5, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d), ((void *)__pyx_v_e)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 786, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":785 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew5", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyDataType_SHAPE(PyArray_Descr *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + __Pyx_RefNannySetupContext("PyDataType_SHAPE", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":789 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + __pyx_t_1 = PyDataType_HASSUBARRAY(__pyx_v_d); + if (__pyx_t_1) { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":790 + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape # <<<<<<<<<<<<<< + * else: + * return () + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(((PyObject*)__pyx_v_d->subarray->shape)); + __pyx_r = ((PyObject*)__pyx_v_d->subarray->shape); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":789 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":792 + * return d.subarray.shape + * else: + * return () # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_empty_tuple); + __pyx_r = __pyx_empty_tuple; + goto __pyx_L0; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":788 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":968 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + +static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_arr, PyObject *__pyx_v_base) { + int __pyx_t_1; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":969 + * + * cdef inline void set_array_base(ndarray arr, object base): + * Py_INCREF(base) # important to do this before stealing the reference below! # <<<<<<<<<<<<<< + * PyArray_SetBaseObject(arr, base) + * + */ + Py_INCREF(__pyx_v_base); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":970 + * cdef inline void set_array_base(ndarray arr, object base): + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) # <<<<<<<<<<<<<< + * + * cdef inline object get_array_base(ndarray arr): + */ + __pyx_t_1 = PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_base); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(2, 970, __pyx_L1_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":968 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base): # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("numpy.set_array_base", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_L0:; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":972 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__pyx_v_arr) { + PyObject *__pyx_v_base; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + __Pyx_RefNannySetupContext("get_array_base", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":973 + * + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) # <<<<<<<<<<<<<< + * if base is NULL: + * return None + */ + __pyx_v_base = PyArray_BASE(__pyx_v_arr); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + __pyx_t_1 = (__pyx_v_base == NULL); + if (__pyx_t_1) { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":975 + * base = PyArray_BASE(arr) + * if base is NULL: + * return None # <<<<<<<<<<<<<< + * return base + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":974 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":976 + * if base is NULL: + * return None + * return base # <<<<<<<<<<<<<< + * + * # Versions of the import_* functions which are more suitable for + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(((PyObject *)__pyx_v_base)); + __pyx_r = ((PyObject *)__pyx_v_base); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":972 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":980 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_array(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_array", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":982 + * cdef inline int import_array() except -1: + * try: + * __pyx_import_array() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") + */ + __pyx_t_4 = _import_array(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 982, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":983 + * try: + * __pyx_import_array() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.multiarray failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 983, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":984 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__9, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 984, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 984, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":981 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":980 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986 + * raise ImportError("numpy.core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_umath(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_umath", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":988 + * cdef inline int import_umath() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 988, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":989 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 989, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":990 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 990, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 990, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":987 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":986 + * raise ImportError("numpy.core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992 + * raise ImportError("numpy.core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_ufunc(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_ufunc", 1); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":994 + * cdef inline int import_ufunc() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy.core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 994, __pyx_L3_error) + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":995 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy.core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 995, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":996 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 996, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 996, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":993 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":992 + * raise ImportError("numpy.core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":999 + * + * + * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_timedelta64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1011 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyTimedeltaArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyTimedeltaArrType_Type)); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":999 + * + * + * cdef inline bint is_timedelta64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1014 + * + * + * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_datetime64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1026 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyDatetimeArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyDatetimeArrType_Type)); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1014 + * + * + * cdef inline bint is_datetime64_object(object obj): # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + +static CYTHON_INLINE npy_datetime __pyx_f_5numpy_get_datetime64_value(PyObject *__pyx_v_obj) { + npy_datetime __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1036 + * also needed. That can be found using `get_datetime64_unit`. + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyDatetimeScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1039 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + +static CYTHON_INLINE npy_timedelta __pyx_f_5numpy_get_timedelta64_value(PyObject *__pyx_v_obj) { + npy_timedelta __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1043 + * returns the int64 value underlying scalar numpy timedelta64 object + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyTimedeltaScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1039 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1046 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + +static CYTHON_INLINE NPY_DATETIMEUNIT __pyx_f_5numpy_get_datetime64_unit(PyObject *__pyx_v_obj) { + NPY_DATETIMEUNIT __pyx_r; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1050 + * returns the unit part of the dtype for a numpy datetime64 object. + * """ + * return (obj).obmeta.base # <<<<<<<<<<<<<< + */ + __pyx_r = ((NPY_DATETIMEUNIT)((PyDatetimeScalarObject *)__pyx_v_obj)->obmeta.base); + goto __pyx_L0; + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":1046 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":24 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): # <<<<<<<<<<<<<< + * cdef DTYPE_t total_size = sizes.sum() + * cdef DTYPE_t length = ceil(total_size / block_size) + */ + +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_none_mode(PyArrayObject *__pyx_v_sizes, int __pyx_v_block_size) { + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_total_size; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_length; + PyArrayObject *__pyx_v_slice_indices = 0; + __Pyx_memviewslice __pyx_v_slice_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_start; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_end; + __Pyx_LocalBuf_ND __pyx_pybuffernd_sizes; + __Pyx_Buffer __pyx_pybuffer_sizes; + __Pyx_LocalBuf_ND __pyx_pybuffernd_slice_indices; + __Pyx_Buffer __pyx_pybuffer_slice_indices; + PyArrayObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_5; + PyObject *__pyx_t_6 = NULL; + PyArrayObject *__pyx_t_7 = NULL; + __Pyx_memviewslice __pyx_t_8 = { 0, 0, { 0 }, { 0 }, { 0 } }; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_9; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_10; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_11; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_12; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_13; + int __pyx_t_14; + Py_ssize_t __pyx_t_15; + Py_ssize_t __pyx_t_16; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_get_slice_indices_none_mode", 1); + __pyx_pybuffer_slice_indices.pybuffer.buf = NULL; + __pyx_pybuffer_slice_indices.refcount = 0; + __pyx_pybuffernd_slice_indices.data = NULL; + __pyx_pybuffernd_slice_indices.rcbuffer = &__pyx_pybuffer_slice_indices; + __pyx_pybuffer_sizes.pybuffer.buf = NULL; + __pyx_pybuffer_sizes.refcount = 0; + __pyx_pybuffernd_sizes.data = NULL; + __pyx_pybuffernd_sizes.rcbuffer = &__pyx_pybuffer_sizes; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer, (PyObject*)__pyx_v_sizes, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 24, __pyx_L1_error) + } + __pyx_pybuffernd_sizes.diminfo[0].strides = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_sizes.diminfo[0].shape = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.shape[0]; + + /* "fairseq/data/token_block_utils_fast.pyx":25 + * @cython.nonecheck(False) + * cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): + * cdef DTYPE_t total_size = sizes.sum() # <<<<<<<<<<<<<< + * cdef DTYPE_t length = ceil(total_size / block_size) + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) + */ + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_sizes), __pyx_n_s_sum); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 25, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = NULL; + __pyx_t_4 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_2))) { + __pyx_t_3 = PyMethod_GET_SELF(__pyx_t_2); + if (likely(__pyx_t_3)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_2); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_2, function); + __pyx_t_4 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_3, NULL}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_2, __pyx_callargs+1-__pyx_t_4, 0+__pyx_t_4); + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 25, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + } + __pyx_t_5 = __Pyx_PyInt_As_int64_t(__pyx_t_1); if (unlikely((__pyx_t_5 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(0, 25, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_v_total_size = __pyx_t_5; + + /* "fairseq/data/token_block_utils_fast.pyx":26 + * cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): + * cdef DTYPE_t total_size = sizes.sum() + * cdef DTYPE_t length = ceil(total_size / block_size) # <<<<<<<<<<<<<< + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices + */ + if (unlikely(((double)__pyx_v_block_size) == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "float division"); + __PYX_ERR(0, 26, __pyx_L1_error) + } + __pyx_v_length = ((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t)ceil((((double)__pyx_v_total_size) / ((double)__pyx_v_block_size)))); + + /* "fairseq/data/token_block_utils_fast.pyx":27 + * cdef DTYPE_t total_size = sizes.sum() + * cdef DTYPE_t length = ceil(total_size / block_size) + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) # <<<<<<<<<<<<<< + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices + * cdef DTYPE_t i + */ + __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyInt_From_int64_t(__pyx_v_length); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = PyList_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyList_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_2); + __Pyx_GIVEREF(__pyx_int_2); + if (__Pyx_PyList_SET_ITEM(__pyx_t_3, 1, __pyx_int_2)) __PYX_ERR(0, 27, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_t_3)) __PYX_ERR(0, 27, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GetModuleGlobalName(__pyx_t_6, __pyx_n_s_DTYPE); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (PyDict_SetItem(__pyx_t_3, __pyx_n_s_dtype, __pyx_t_6) < 0) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_t_6 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_1, __pyx_t_3); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 27, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + if (!(likely(((__pyx_t_6) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_6, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 27, __pyx_L1_error) + __pyx_t_7 = ((PyArrayObject *)__pyx_t_6); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_t_7, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + __pyx_v_slice_indices = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.buf = NULL; + __PYX_ERR(0, 27, __pyx_L1_error) + } else {__pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + } + } + __pyx_t_7 = 0; + __pyx_v_slice_indices = ((PyArrayObject *)__pyx_t_6); + __pyx_t_6 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":28 + * cdef DTYPE_t length = ceil(total_size / block_size) + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices # <<<<<<<<<<<<<< + * cdef DTYPE_t i + * cdef DTYPE_t start + */ + __pyx_t_8 = __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(((PyObject *)__pyx_v_slice_indices), PyBUF_WRITABLE); if (unlikely(!__pyx_t_8.memview)) __PYX_ERR(0, 28, __pyx_L1_error) + __pyx_v_slice_indices_view = __pyx_t_8; + __pyx_t_8.memview = NULL; + __pyx_t_8.data = NULL; + + /* "fairseq/data/token_block_utils_fast.pyx":32 + * cdef DTYPE_t start + * cdef DTYPE_t end + * for i in range(length): # <<<<<<<<<<<<<< + * start = i * block_size + * end = min(start + block_size, total_size) + */ + __pyx_t_5 = __pyx_v_length; + __pyx_t_9 = __pyx_t_5; + for (__pyx_t_10 = 0; __pyx_t_10 < __pyx_t_9; __pyx_t_10+=1) { + __pyx_v_i = __pyx_t_10; + + /* "fairseq/data/token_block_utils_fast.pyx":33 + * cdef DTYPE_t end + * for i in range(length): + * start = i * block_size # <<<<<<<<<<<<<< + * end = min(start + block_size, total_size) + * slice_indices_view[i][0] = start + */ + __pyx_v_start = (__pyx_v_i * __pyx_v_block_size); + + /* "fairseq/data/token_block_utils_fast.pyx":34 + * for i in range(length): + * start = i * block_size + * end = min(start + block_size, total_size) # <<<<<<<<<<<<<< + * slice_indices_view[i][0] = start + * slice_indices_view[i][1] = end + */ + __pyx_t_11 = __pyx_v_total_size; + __pyx_t_12 = (__pyx_v_start + __pyx_v_block_size); + __pyx_t_14 = (__pyx_t_11 < __pyx_t_12); + if (__pyx_t_14) { + __pyx_t_13 = __pyx_t_11; + } else { + __pyx_t_13 = __pyx_t_12; + } + __pyx_v_end = __pyx_t_13; + + /* "fairseq/data/token_block_utils_fast.pyx":35 + * start = i * block_size + * end = min(start + block_size, total_size) + * slice_indices_view[i][0] = start # <<<<<<<<<<<<<< + * slice_indices_view[i][1] = end + * return slice_indices + */ + __pyx_t_15 = __pyx_v_i; + __pyx_t_16 = 0; + *((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_slice_indices_view.data + __pyx_t_15 * __pyx_v_slice_indices_view.strides[0]) ) + __pyx_t_16 * __pyx_v_slice_indices_view.strides[1]) )) = __pyx_v_start; + + /* "fairseq/data/token_block_utils_fast.pyx":36 + * end = min(start + block_size, total_size) + * slice_indices_view[i][0] = start + * slice_indices_view[i][1] = end # <<<<<<<<<<<<<< + * return slice_indices + * + */ + __pyx_t_16 = __pyx_v_i; + __pyx_t_15 = 1; + *((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_slice_indices_view.data + __pyx_t_16 * __pyx_v_slice_indices_view.strides[0]) ) + __pyx_t_15 * __pyx_v_slice_indices_view.strides[1]) )) = __pyx_v_end; + } + + /* "fairseq/data/token_block_utils_fast.pyx":37 + * slice_indices_view[i][0] = start + * slice_indices_view[i][1] = end + * return slice_indices # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_slice_indices); + __pyx_r = ((PyArrayObject *)__pyx_v_slice_indices); + goto __pyx_L0; + + /* "fairseq/data/token_block_utils_fast.pyx":24 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): # <<<<<<<<<<<<<< + * cdef DTYPE_t total_size = sizes.sum() + * cdef DTYPE_t length = ceil(total_size / block_size) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_6); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_8, 1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_slice_indices_none_mode", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF((PyObject *)__pyx_v_slice_indices); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_slice_indices_view, 1); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":40 + * + * + * cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list): # <<<<<<<<<<<<<< + * """ + * Faster function to convert DTYPE_t list of list. + */ + +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__fast_convert_to_np_array(PyObject *__pyx_v_list_of_list) { + PyArrayObject *__pyx_v_flat = 0; + __Pyx_LocalBuf_ND __pyx_pybuffernd_flat; + __Pyx_Buffer __pyx_pybuffer_flat; + PyArrayObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + PyArrayObject *__pyx_t_7 = NULL; + Py_ssize_t __pyx_t_8; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_fast_convert_to_np_array", 1); + __pyx_pybuffer_flat.pybuffer.buf = NULL; + __pyx_pybuffer_flat.refcount = 0; + __pyx_pybuffernd_flat.data = NULL; + __pyx_pybuffernd_flat.rcbuffer = &__pyx_pybuffer_flat; + + /* "fairseq/data/token_block_utils_fast.pyx":45 + * Only fast when there are huge number of rows and low number of columns. + * """ + * cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1) # <<<<<<<<<<<<<< + * return flat.reshape((len(list_of_list), -1)) + * + */ + __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_fromiter); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_chain); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_from_iterable); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_4, __pyx_v_list_of_list}; + __pyx_t_2 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_6, 1+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_DTYPE); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_4 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (unlikely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[4] = {__pyx_t_4, __pyx_t_2, __pyx_t_5, __pyx_int_neg_1}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_6, 3+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 45, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 45, __pyx_L1_error) + __pyx_t_7 = ((PyArrayObject *)__pyx_t_1); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_flat.rcbuffer->pybuffer, (PyObject*)__pyx_t_7, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) { + __pyx_v_flat = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_flat.rcbuffer->pybuffer.buf = NULL; + __PYX_ERR(0, 45, __pyx_L1_error) + } else {__pyx_pybuffernd_flat.diminfo[0].strides = __pyx_pybuffernd_flat.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_flat.diminfo[0].shape = __pyx_pybuffernd_flat.rcbuffer->pybuffer.shape[0]; + } + } + __pyx_t_7 = 0; + __pyx_v_flat = ((PyArrayObject *)__pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":46 + * """ + * cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1) + * return flat.reshape((len(list_of_list), -1)) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_flat), __pyx_n_s_reshape); if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 46, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + if (unlikely(__pyx_v_list_of_list == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(0, 46, __pyx_L1_error) + } + __pyx_t_8 = __Pyx_PyList_GET_SIZE(__pyx_v_list_of_list); if (unlikely(__pyx_t_8 == ((Py_ssize_t)-1))) __PYX_ERR(0, 46, __pyx_L1_error) + __pyx_t_5 = PyInt_FromSsize_t(__pyx_t_8); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 46, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_2 = PyTuple_New(2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 46, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_5); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_5)) __PYX_ERR(0, 46, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_neg_1); + __Pyx_GIVEREF(__pyx_int_neg_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_int_neg_1)) __PYX_ERR(0, 46, __pyx_L1_error); + __pyx_t_5 = 0; + __pyx_t_5 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_5 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_5)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_5); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_5, __pyx_t_2}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_6, 1+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 46, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 46, __pyx_L1_error) + __pyx_r = ((PyArrayObject *)__pyx_t_1); + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "fairseq/data/token_block_utils_fast.pyx":40 + * + * + * cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list): # <<<<<<<<<<<<<< + * """ + * Faster function to convert DTYPE_t list of list. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_flat.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._fast_convert_to_np_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_flat.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF((PyObject *)__pyx_v_flat); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":52 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): # <<<<<<<<<<<<<< + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 + */ + +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(PyArrayObject *__pyx_v_sizes, PyObject *__pyx_v_break_mode, int __pyx_v_block_size, int __pyx_v_document_sep_len, CYTHON_UNUSED int __pyx_skip_dispatch) { + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_tok_idx; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_sz_idx; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_curr_size; + CYTHON_UNUSED __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i; + __Pyx_memviewslice __pyx_v_sizes_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + PyArrayObject *__pyx_v_slice_indices = 0; + PyObject *__pyx_v_slice_indices_list = 0; + PyObject *__pyx_v_cumsum = NULL; + __Pyx_LocalBuf_ND __pyx_pybuffernd_sizes; + __Pyx_Buffer __pyx_pybuffer_sizes; + __Pyx_LocalBuf_ND __pyx_pybuffernd_slice_indices; + __Pyx_Buffer __pyx_pybuffer_slice_indices; + PyArrayObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1 = { 0, 0, { 0 }, { 0 }, { 0 } }; + PyObject *__pyx_t_2 = NULL; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + Py_ssize_t __pyx_t_9; + Py_ssize_t __pyx_t_10; + PyObject *__pyx_t_11 = NULL; + PyObject *__pyx_t_12 = NULL; + int __pyx_t_13; + PyObject *__pyx_t_14 = NULL; + PyArrayObject *__pyx_t_15 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_get_slice_indices_fast", 1); + __pyx_pybuffer_slice_indices.pybuffer.buf = NULL; + __pyx_pybuffer_slice_indices.refcount = 0; + __pyx_pybuffernd_slice_indices.data = NULL; + __pyx_pybuffernd_slice_indices.rcbuffer = &__pyx_pybuffer_slice_indices; + __pyx_pybuffer_sizes.pybuffer.buf = NULL; + __pyx_pybuffer_sizes.refcount = 0; + __pyx_pybuffernd_sizes.data = NULL; + __pyx_pybuffernd_sizes.rcbuffer = &__pyx_pybuffer_sizes; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer, (PyObject*)__pyx_v_sizes, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 52, __pyx_L1_error) + } + __pyx_pybuffernd_sizes.diminfo[0].strides = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_sizes.diminfo[0].shape = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.shape[0]; + + /* "fairseq/data/token_block_utils_fast.pyx":53 + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): + * cdef DTYPE_t tok_idx = 0 # <<<<<<<<<<<<<< + * cdef DTYPE_t sz_idx = 0 + * cdef DTYPE_t curr_size = 0 + */ + __pyx_v_tok_idx = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":54 + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 # <<<<<<<<<<<<<< + * cdef DTYPE_t curr_size = 0 + * cdef DTYPE_t i = 0 + */ + __pyx_v_sz_idx = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":55 + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 + * cdef DTYPE_t curr_size = 0 # <<<<<<<<<<<<<< + * cdef DTYPE_t i = 0 + * cdef DTYPE_t length + */ + __pyx_v_curr_size = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":56 + * cdef DTYPE_t sz_idx = 0 + * cdef DTYPE_t curr_size = 0 + * cdef DTYPE_t i = 0 # <<<<<<<<<<<<<< + * cdef DTYPE_t length + * cdef DTYPE_t total_size + */ + __pyx_v_i = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":59 + * cdef DTYPE_t length + * cdef DTYPE_t total_size + * cdef DTYPE_t[:] sizes_view = sizes # <<<<<<<<<<<<<< + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices + * cdef list slice_indices_list = [] + */ + __pyx_t_1 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(((PyObject *)__pyx_v_sizes), PyBUF_WRITABLE); if (unlikely(!__pyx_t_1.memview)) __PYX_ERR(0, 59, __pyx_L1_error) + __pyx_v_sizes_view = __pyx_t_1; + __pyx_t_1.memview = NULL; + __pyx_t_1.data = NULL; + + /* "fairseq/data/token_block_utils_fast.pyx":61 + * cdef DTYPE_t[:] sizes_view = sizes + * cdef np.ndarray[DTYPE_t, ndim=2] slice_indices + * cdef list slice_indices_list = [] # <<<<<<<<<<<<<< + * + * if break_mode is None or break_mode == 'none': + */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 61, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_v_slice_indices_list = ((PyObject*)__pyx_t_2); + __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":63 + * cdef list slice_indices_list = [] + * + * if break_mode is None or break_mode == 'none': # <<<<<<<<<<<<<< + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) + * elif break_mode == 'complete': + */ + __pyx_t_4 = (__pyx_v_break_mode == ((PyObject*)Py_None)); + if (!__pyx_t_4) { + } else { + __pyx_t_3 = __pyx_t_4; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_4 = (__Pyx_PyUnicode_Equals(__pyx_v_break_mode, __pyx_n_u_none, Py_EQ)); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(0, 63, __pyx_L1_error) + __pyx_t_3 = __pyx_t_4; + __pyx_L4_bool_binop_done:; + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":64 + * + * if break_mode is None or break_mode == 'none': + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) # <<<<<<<<<<<<<< + * elif break_mode == 'complete': + * while sz_idx < len(sizes_view): + */ + __pyx_t_2 = ((PyObject *)__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_none_mode(((PyArrayObject *)__pyx_v_sizes), __pyx_v_block_size)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 64, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_t_5 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)((PyArrayObject *)__pyx_t_2), &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack); + if (unlikely(__pyx_t_5 < 0)) { + PyErr_Fetch(&__pyx_t_6, &__pyx_t_7, &__pyx_t_8); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_6); Py_XDECREF(__pyx_t_7); Py_XDECREF(__pyx_t_8); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_6, __pyx_t_7, __pyx_t_8); + } + __pyx_t_6 = __pyx_t_7 = __pyx_t_8 = 0; + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + if (unlikely((__pyx_t_5 < 0))) __PYX_ERR(0, 64, __pyx_L1_error) + } + __pyx_v_slice_indices = ((PyArrayObject *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":63 + * cdef list slice_indices_list = [] + * + * if break_mode is None or break_mode == 'none': # <<<<<<<<<<<<<< + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) + * elif break_mode == 'complete': + */ + goto __pyx_L3; + } + + /* "fairseq/data/token_block_utils_fast.pyx":65 + * if break_mode is None or break_mode == 'none': + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) + * elif break_mode == 'complete': # <<<<<<<<<<<<<< + * while sz_idx < len(sizes_view): + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + */ + __pyx_t_3 = (__Pyx_PyUnicode_Equals(__pyx_v_break_mode, __pyx_n_u_complete, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(0, 65, __pyx_L1_error) + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":66 + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) + * elif break_mode == 'complete': + * while sz_idx < len(sizes_view): # <<<<<<<<<<<<<< + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + * curr_size += sizes_view[sz_idx] + */ + while (1) { + __pyx_t_9 = __Pyx_MemoryView_Len(__pyx_v_sizes_view); + __pyx_t_3 = (__pyx_v_sz_idx < __pyx_t_9); + if (!__pyx_t_3) break; + + /* "fairseq/data/token_block_utils_fast.pyx":67 + * elif break_mode == 'complete': + * while sz_idx < len(sizes_view): + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: # <<<<<<<<<<<<<< + * curr_size += sizes_view[sz_idx] + * sz_idx += 1 + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_t_4 = ((__pyx_v_curr_size + (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) )))) <= __pyx_v_block_size); + if (!__pyx_t_4) { + } else { + __pyx_t_3 = __pyx_t_4; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_4 = (__pyx_v_curr_size == 0); + __pyx_t_3 = __pyx_t_4; + __pyx_L9_bool_binop_done:; + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":68 + * while sz_idx < len(sizes_view): + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + * curr_size += sizes_view[sz_idx] # <<<<<<<<<<<<<< + * sz_idx += 1 + * else: + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_v_curr_size = (__pyx_v_curr_size + (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) )))); + + /* "fairseq/data/token_block_utils_fast.pyx":69 + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + * curr_size += sizes_view[sz_idx] + * sz_idx += 1 # <<<<<<<<<<<<<< + * else: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + */ + __pyx_v_sz_idx = (__pyx_v_sz_idx + 1); + + /* "fairseq/data/token_block_utils_fast.pyx":67 + * elif break_mode == 'complete': + * while sz_idx < len(sizes_view): + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: # <<<<<<<<<<<<<< + * curr_size += sizes_view[sz_idx] + * sz_idx += 1 + */ + goto __pyx_L8; + } + + /* "fairseq/data/token_block_utils_fast.pyx":71 + * sz_idx += 1 + * else: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) # <<<<<<<<<<<<<< + * tok_idx += curr_size + * curr_size = 0 + */ + /*else*/ { + __pyx_t_2 = __Pyx_PyInt_From_int64_t(__pyx_v_tok_idx); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 71, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_11 = __Pyx_PyInt_From_int64_t((__pyx_v_tok_idx + __pyx_v_curr_size)); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 71, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_11); + __pyx_t_12 = PyTuple_New(2); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 71, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 0, __pyx_t_2)) __PYX_ERR(0, 71, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_11); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 1, __pyx_t_11)) __PYX_ERR(0, 71, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_11 = 0; + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_slice_indices_list, __pyx_t_12); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 71, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":72 + * else: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size # <<<<<<<<<<<<<< + * curr_size = 0 + * if curr_size > 0: + */ + __pyx_v_tok_idx = (__pyx_v_tok_idx + __pyx_v_curr_size); + + /* "fairseq/data/token_block_utils_fast.pyx":73 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size + * curr_size = 0 # <<<<<<<<<<<<<< + * if curr_size > 0: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + */ + __pyx_v_curr_size = 0; + } + __pyx_L8:; + } + + /* "fairseq/data/token_block_utils_fast.pyx":74 + * tok_idx += curr_size + * curr_size = 0 + * if curr_size > 0: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + */ + __pyx_t_3 = (__pyx_v_curr_size > 0); + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":75 + * curr_size = 0 + * if curr_size > 0: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) # <<<<<<<<<<<<<< + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'complete_doc': + */ + __pyx_t_12 = __Pyx_PyInt_From_int64_t(__pyx_v_tok_idx); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 75, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __pyx_t_11 = __Pyx_PyInt_From_int64_t((__pyx_v_tok_idx + __pyx_v_curr_size)); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 75, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_11); + __pyx_t_2 = PyTuple_New(2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 75, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_12); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_12)) __PYX_ERR(0, 75, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_11); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_t_11)) __PYX_ERR(0, 75, __pyx_L1_error); + __pyx_t_12 = 0; + __pyx_t_11 = 0; + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_slice_indices_list, __pyx_t_2); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 75, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":74 + * tok_idx += curr_size + * curr_size = 0 + * if curr_size > 0: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + */ + } + + /* "fairseq/data/token_block_utils_fast.pyx":76 + * if curr_size > 0: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) # <<<<<<<<<<<<<< + * elif break_mode == 'complete_doc': + * while sz_idx < len(sizes_view): + */ + __pyx_t_2 = ((PyObject *)__pyx_f_7fairseq_4data_22token_block_utils_fast__fast_convert_to_np_array(__pyx_v_slice_indices_list)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 76, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_t_5 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)((PyArrayObject *)__pyx_t_2), &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack); + if (unlikely(__pyx_t_5 < 0)) { + PyErr_Fetch(&__pyx_t_8, &__pyx_t_7, &__pyx_t_6); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_8); Py_XDECREF(__pyx_t_7); Py_XDECREF(__pyx_t_6); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_8, __pyx_t_7, __pyx_t_6); + } + __pyx_t_8 = __pyx_t_7 = __pyx_t_6 = 0; + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + if (unlikely((__pyx_t_5 < 0))) __PYX_ERR(0, 76, __pyx_L1_error) + } + __pyx_v_slice_indices = ((PyArrayObject *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":65 + * if break_mode is None or break_mode == 'none': + * slice_indices = _get_slice_indices_none_mode(sizes, block_size) + * elif break_mode == 'complete': # <<<<<<<<<<<<<< + * while sz_idx < len(sizes_view): + * if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + */ + goto __pyx_L3; + } + + /* "fairseq/data/token_block_utils_fast.pyx":77 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'complete_doc': # <<<<<<<<<<<<<< + * while sz_idx < len(sizes_view): + * if ( + */ + __pyx_t_3 = (__Pyx_PyUnicode_Equals(__pyx_v_break_mode, __pyx_n_u_complete_doc, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(0, 77, __pyx_L1_error) + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":78 + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'complete_doc': + * while sz_idx < len(sizes_view): # <<<<<<<<<<<<<< + * if ( + * (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + */ + while (1) { + __pyx_t_9 = __Pyx_MemoryView_Len(__pyx_v_sizes_view); + __pyx_t_3 = (__pyx_v_sz_idx < __pyx_t_9); + if (!__pyx_t_3) break; + + /* "fairseq/data/token_block_utils_fast.pyx":80 + * while sz_idx < len(sizes_view): + * if ( + * (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) # <<<<<<<<<<<<<< + * # an empty sentence indicates end-of-document: + * and sizes_view[sz_idx] != document_sep_len + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_t_4 = ((__pyx_v_curr_size + (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) )))) <= __pyx_v_block_size); + if (!__pyx_t_4) { + } else { + goto __pyx_L16_next_and; + } + __pyx_t_4 = (__pyx_v_curr_size == 0); + if (__pyx_t_4) { + } else { + __pyx_t_3 = __pyx_t_4; + goto __pyx_L15_bool_binop_done; + } + __pyx_L16_next_and:; + + /* "fairseq/data/token_block_utils_fast.pyx":82 + * (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + * # an empty sentence indicates end-of-document: + * and sizes_view[sz_idx] != document_sep_len # <<<<<<<<<<<<<< + * ): + * curr_size += sizes_view[sz_idx] + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_t_4 = ((*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) ))) != __pyx_v_document_sep_len); + __pyx_t_3 = __pyx_t_4; + __pyx_L15_bool_binop_done:; + + /* "fairseq/data/token_block_utils_fast.pyx":79 + * elif break_mode == 'complete_doc': + * while sz_idx < len(sizes_view): + * if ( # <<<<<<<<<<<<<< + * (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + * # an empty sentence indicates end-of-document: + */ + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":84 + * and sizes_view[sz_idx] != document_sep_len + * ): + * curr_size += sizes_view[sz_idx] # <<<<<<<<<<<<<< + * sz_idx += 1 + * else: + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_v_curr_size = (__pyx_v_curr_size + (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) )))); + + /* "fairseq/data/token_block_utils_fast.pyx":85 + * ): + * curr_size += sizes_view[sz_idx] + * sz_idx += 1 # <<<<<<<<<<<<<< + * else: + * # Only keep non-empty documents. + */ + __pyx_v_sz_idx = (__pyx_v_sz_idx + 1); + + /* "fairseq/data/token_block_utils_fast.pyx":79 + * elif break_mode == 'complete_doc': + * while sz_idx < len(sizes_view): + * if ( # <<<<<<<<<<<<<< + * (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + * # an empty sentence indicates end-of-document: + */ + goto __pyx_L14; + } + + /* "fairseq/data/token_block_utils_fast.pyx":88 + * else: + * # Only keep non-empty documents. + * if curr_size > 1: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size + */ + /*else*/ { + __pyx_t_3 = (__pyx_v_curr_size > 1); + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":89 + * # Only keep non-empty documents. + * if curr_size > 1: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) # <<<<<<<<<<<<<< + * tok_idx += curr_size + * curr_size = 0 + */ + __pyx_t_2 = __Pyx_PyInt_From_int64_t(__pyx_v_tok_idx); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 89, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_11 = __Pyx_PyInt_From_int64_t((__pyx_v_tok_idx + __pyx_v_curr_size)); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 89, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_11); + __pyx_t_12 = PyTuple_New(2); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 89, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 0, __pyx_t_2)) __PYX_ERR(0, 89, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_11); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 1, __pyx_t_11)) __PYX_ERR(0, 89, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_11 = 0; + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_slice_indices_list, __pyx_t_12); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 89, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":88 + * else: + * # Only keep non-empty documents. + * if curr_size > 1: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size + */ + } + + /* "fairseq/data/token_block_utils_fast.pyx":90 + * if curr_size > 1: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size # <<<<<<<<<<<<<< + * curr_size = 0 + * if sizes_view[sz_idx] == document_sep_len: + */ + __pyx_v_tok_idx = (__pyx_v_tok_idx + __pyx_v_curr_size); + + /* "fairseq/data/token_block_utils_fast.pyx":91 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * tok_idx += curr_size + * curr_size = 0 # <<<<<<<<<<<<<< + * if sizes_view[sz_idx] == document_sep_len: + * tok_idx += sizes_view[sz_idx] + */ + __pyx_v_curr_size = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":92 + * tok_idx += curr_size + * curr_size = 0 + * if sizes_view[sz_idx] == document_sep_len: # <<<<<<<<<<<<<< + * tok_idx += sizes_view[sz_idx] + * sz_idx += 1 + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_t_3 = ((*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) ))) == __pyx_v_document_sep_len); + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":93 + * curr_size = 0 + * if sizes_view[sz_idx] == document_sep_len: + * tok_idx += sizes_view[sz_idx] # <<<<<<<<<<<<<< + * sz_idx += 1 + * if curr_size > 1: + */ + __pyx_t_10 = __pyx_v_sz_idx; + __pyx_v_tok_idx = (__pyx_v_tok_idx + (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_sizes_view.data + __pyx_t_10 * __pyx_v_sizes_view.strides[0]) )))); + + /* "fairseq/data/token_block_utils_fast.pyx":94 + * if sizes_view[sz_idx] == document_sep_len: + * tok_idx += sizes_view[sz_idx] + * sz_idx += 1 # <<<<<<<<<<<<<< + * if curr_size > 1: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + */ + __pyx_v_sz_idx = (__pyx_v_sz_idx + 1); + + /* "fairseq/data/token_block_utils_fast.pyx":92 + * tok_idx += curr_size + * curr_size = 0 + * if sizes_view[sz_idx] == document_sep_len: # <<<<<<<<<<<<<< + * tok_idx += sizes_view[sz_idx] + * sz_idx += 1 + */ + } + } + __pyx_L14:; + } + + /* "fairseq/data/token_block_utils_fast.pyx":95 + * tok_idx += sizes_view[sz_idx] + * sz_idx += 1 + * if curr_size > 1: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + */ + __pyx_t_3 = (__pyx_v_curr_size > 1); + if (__pyx_t_3) { + + /* "fairseq/data/token_block_utils_fast.pyx":96 + * sz_idx += 1 + * if curr_size > 1: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) # <<<<<<<<<<<<<< + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'eos': + */ + __pyx_t_12 = __Pyx_PyInt_From_int64_t(__pyx_v_tok_idx); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 96, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __pyx_t_11 = __Pyx_PyInt_From_int64_t((__pyx_v_tok_idx + __pyx_v_curr_size)); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 96, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_11); + __pyx_t_2 = PyTuple_New(2); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 96, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_12); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_12)) __PYX_ERR(0, 96, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_11); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 1, __pyx_t_11)) __PYX_ERR(0, 96, __pyx_L1_error); + __pyx_t_12 = 0; + __pyx_t_11 = 0; + __pyx_t_13 = __Pyx_PyList_Append(__pyx_v_slice_indices_list, __pyx_t_2); if (unlikely(__pyx_t_13 == ((int)-1))) __PYX_ERR(0, 96, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":95 + * tok_idx += sizes_view[sz_idx] + * sz_idx += 1 + * if curr_size > 1: # <<<<<<<<<<<<<< + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + */ + } + + /* "fairseq/data/token_block_utils_fast.pyx":97 + * if curr_size > 1: + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) # <<<<<<<<<<<<<< + * elif break_mode == 'eos': + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + */ + __pyx_t_2 = ((PyObject *)__pyx_f_7fairseq_4data_22token_block_utils_fast__fast_convert_to_np_array(__pyx_v_slice_indices_list)); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 97, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_t_5 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)((PyArrayObject *)__pyx_t_2), &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack); + if (unlikely(__pyx_t_5 < 0)) { + PyErr_Fetch(&__pyx_t_6, &__pyx_t_7, &__pyx_t_8); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_6); Py_XDECREF(__pyx_t_7); Py_XDECREF(__pyx_t_8); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_6, __pyx_t_7, __pyx_t_8); + } + __pyx_t_6 = __pyx_t_7 = __pyx_t_8 = 0; + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + if (unlikely((__pyx_t_5 < 0))) __PYX_ERR(0, 97, __pyx_L1_error) + } + __pyx_v_slice_indices = ((PyArrayObject *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":77 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'complete_doc': # <<<<<<<<<<<<<< + * while sz_idx < len(sizes_view): + * if ( + */ + goto __pyx_L3; + } + + /* "fairseq/data/token_block_utils_fast.pyx":98 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'eos': # <<<<<<<<<<<<<< + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + * cumsum = sizes.cumsum(axis=0) + */ + __pyx_t_3 = (__Pyx_PyUnicode_Equals(__pyx_v_break_mode, __pyx_n_u_eos, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(0, 98, __pyx_L1_error) + if (likely(__pyx_t_3)) { + + /* "fairseq/data/token_block_utils_fast.pyx":99 + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'eos': + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) # <<<<<<<<<<<<<< + * cumsum = sizes.cumsum(axis=0) + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + */ + __Pyx_GetModuleGlobalName(__pyx_t_2, __pyx_n_s_np); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_11 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_zeros); if (unlikely(!__pyx_t_11)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_11); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_9 = PyObject_Length(((PyObject *)__pyx_v_sizes)); if (unlikely(__pyx_t_9 == ((Py_ssize_t)-1))) __PYX_ERR(0, 99, __pyx_L1_error) + __pyx_t_2 = PyInt_FromSsize_t(__pyx_t_9); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_12 = PyTuple_New(2); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 0, __pyx_t_2)) __PYX_ERR(0, 99, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_2); + __Pyx_GIVEREF(__pyx_int_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_12, 1, __pyx_int_2)) __PYX_ERR(0, 99, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_2 = PyTuple_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_12); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_12)) __PYX_ERR(0, 99, __pyx_L1_error); + __pyx_t_12 = 0; + __pyx_t_12 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_GetModuleGlobalName(__pyx_t_14, __pyx_n_s_DTYPE); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_14); + if (PyDict_SetItem(__pyx_t_12, __pyx_n_s_dtype, __pyx_t_14) < 0) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; + __pyx_t_14 = __Pyx_PyObject_Call(__pyx_t_11, __pyx_t_2, __pyx_t_12); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 99, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_14); + __Pyx_DECREF(__pyx_t_11); __pyx_t_11 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + if (!(likely(((__pyx_t_14) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_14, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 99, __pyx_L1_error) + __pyx_t_15 = ((PyArrayObject *)__pyx_t_14); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_t_5 = __Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_t_15, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack); + if (unlikely(__pyx_t_5 < 0)) { + PyErr_Fetch(&__pyx_t_8, &__pyx_t_7, &__pyx_t_6); + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + Py_XDECREF(__pyx_t_8); Py_XDECREF(__pyx_t_7); Py_XDECREF(__pyx_t_6); + __Pyx_RaiseBufferFallbackError(); + } else { + PyErr_Restore(__pyx_t_8, __pyx_t_7, __pyx_t_6); + } + __pyx_t_8 = __pyx_t_7 = __pyx_t_6 = 0; + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + if (unlikely((__pyx_t_5 < 0))) __PYX_ERR(0, 99, __pyx_L1_error) + } + __pyx_t_15 = 0; + __pyx_v_slice_indices = ((PyArrayObject *)__pyx_t_14); + __pyx_t_14 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":100 + * elif break_mode == 'eos': + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + * cumsum = sizes.cumsum(axis=0) # <<<<<<<<<<<<<< + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + * slice_indices[:, 1] = cumsum + */ + __pyx_t_14 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_sizes), __pyx_n_s_cumsum); if (unlikely(!__pyx_t_14)) __PYX_ERR(0, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_14); + __pyx_t_12 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + if (PyDict_SetItem(__pyx_t_12, __pyx_n_s_axis, __pyx_int_0) < 0) __PYX_ERR(0, 100, __pyx_L1_error) + __pyx_t_2 = __Pyx_PyObject_Call(__pyx_t_14, __pyx_empty_tuple, __pyx_t_12); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_14); __pyx_t_14 = 0; + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + __pyx_v_cumsum = __pyx_t_2; + __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":101 + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + * cumsum = sizes.cumsum(axis=0) + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] # <<<<<<<<<<<<<< + * slice_indices[:, 1] = cumsum + * else: + */ + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_cumsum, __pyx_n_s_shape); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_12 = __Pyx_GetItemInt(__pyx_t_2, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 0); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyInt_SubtractObjC(__pyx_t_12, __pyx_int_1, 1, 0, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + __pyx_t_12 = __Pyx_PyObject_GetSlice(__pyx_v_cumsum, 0, 0, NULL, &__pyx_t_2, NULL, 0, 0, 0); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_slice_indices), __pyx_tuple__12, __pyx_t_12) < 0))) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":102 + * cumsum = sizes.cumsum(axis=0) + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + * slice_indices[:, 1] = cumsum # <<<<<<<<<<<<<< + * else: + * raise ValueError('Invalid break_mode: ' + break_mode) + */ + if (unlikely((PyObject_SetItem(((PyObject *)__pyx_v_slice_indices), __pyx_tuple__13, __pyx_v_cumsum) < 0))) __PYX_ERR(0, 102, __pyx_L1_error) + + /* "fairseq/data/token_block_utils_fast.pyx":98 + * slice_indices_list.append((tok_idx, tok_idx + curr_size)) + * slice_indices = _fast_convert_to_np_array(slice_indices_list) + * elif break_mode == 'eos': # <<<<<<<<<<<<<< + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + * cumsum = sizes.cumsum(axis=0) + */ + goto __pyx_L3; + } + + /* "fairseq/data/token_block_utils_fast.pyx":104 + * slice_indices[:, 1] = cumsum + * else: + * raise ValueError('Invalid break_mode: ' + break_mode) # <<<<<<<<<<<<<< + * return slice_indices + * + */ + /*else*/ { + __pyx_t_12 = __Pyx_PyUnicode_ConcatSafe(__pyx_kp_u_Invalid_break_mode, __pyx_v_break_mode); if (unlikely(!__pyx_t_12)) __PYX_ERR(0, 104, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_12); + __pyx_t_2 = __Pyx_PyObject_CallOneArg(__pyx_builtin_ValueError, __pyx_t_12); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 104, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_12); __pyx_t_12 = 0; + __Pyx_Raise(__pyx_t_2, 0, 0, 0); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __PYX_ERR(0, 104, __pyx_L1_error) + } + __pyx_L3:; + + /* "fairseq/data/token_block_utils_fast.pyx":105 + * else: + * raise ValueError('Invalid break_mode: ' + break_mode) + * return slice_indices # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_slice_indices); + __pyx_r = ((PyArrayObject *)__pyx_v_slice_indices); + goto __pyx_L0; + + /* "fairseq/data/token_block_utils_fast.pyx":52 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): # <<<<<<<<<<<<<< + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 + */ + + /* function exit code */ + __pyx_L1_error:; + __PYX_XCLEAR_MEMVIEW(&__pyx_t_1, 1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_11); + __Pyx_XDECREF(__pyx_t_12); + __Pyx_XDECREF(__pyx_t_14); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_slice_indices_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_L2:; + __PYX_XCLEAR_MEMVIEW(&__pyx_v_sizes_view, 1); + __Pyx_XDECREF((PyObject *)__pyx_v_slice_indices); + __Pyx_XDECREF(__pyx_v_slice_indices_list); + __Pyx_XDECREF(__pyx_v_cumsum); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast = {"_get_slice_indices_fast", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyArrayObject *__pyx_v_sizes = 0; + PyObject *__pyx_v_break_mode = 0; + int __pyx_v_block_size; + int __pyx_v_document_sep_len; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[4] = {0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_slice_indices_fast (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_sizes,&__pyx_n_s_break_mode,&__pyx_n_s_block_size,&__pyx_n_s_document_sep_len,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_sizes)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_break_mode)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("_get_slice_indices_fast", 1, 4, 4, 1); __PYX_ERR(0, 52, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_block_size)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("_get_slice_indices_fast", 1, 4, 4, 2); __PYX_ERR(0, 52, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (likely((values[3] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_document_sep_len)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[3]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("_get_slice_indices_fast", 1, 4, 4, 3); __PYX_ERR(0, 52, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "_get_slice_indices_fast") < 0)) __PYX_ERR(0, 52, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 4)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + } + __pyx_v_sizes = ((PyArrayObject *)values[0]); + __pyx_v_break_mode = ((PyObject*)values[1]); + __pyx_v_block_size = __Pyx_PyInt_As_int(values[2]); if (unlikely((__pyx_v_block_size == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + __pyx_v_document_sep_len = __Pyx_PyInt_As_int(values[3]); if (unlikely((__pyx_v_document_sep_len == (int)-1) && PyErr_Occurred())) __PYX_ERR(0, 52, __pyx_L3_error) + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("_get_slice_indices_fast", 1, 4, 4, __pyx_nargs); __PYX_ERR(0, 52, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_slice_indices_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_sizes), __pyx_ptype_5numpy_ndarray, 1, "sizes", 0))) __PYX_ERR(0, 52, __pyx_L1_error) + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_break_mode), (&PyUnicode_Type), 1, "break_mode", 1))) __PYX_ERR(0, 52, __pyx_L1_error) + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(__pyx_self, __pyx_v_sizes, __pyx_v_break_mode, __pyx_v_block_size, __pyx_v_document_sep_len); + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = NULL; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_sizes, PyObject *__pyx_v_break_mode, int __pyx_v_block_size, int __pyx_v_document_sep_len) { + __Pyx_LocalBuf_ND __pyx_pybuffernd_sizes; + __Pyx_Buffer __pyx_pybuffer_sizes; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_get_slice_indices_fast", 1); + __pyx_pybuffer_sizes.pybuffer.buf = NULL; + __pyx_pybuffer_sizes.refcount = 0; + __pyx_pybuffernd_sizes.data = NULL; + __pyx_pybuffernd_sizes.rcbuffer = &__pyx_pybuffer_sizes; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer, (PyObject*)__pyx_v_sizes, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 52, __pyx_L1_error) + } + __pyx_pybuffernd_sizes.diminfo[0].strides = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_sizes.diminfo[0].shape = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.shape[0]; + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((PyObject *)__pyx_f_7fairseq_4data_22token_block_utils_fast__get_slice_indices_fast(__pyx_v_sizes, __pyx_v_break_mode, __pyx_v_block_size, __pyx_v_document_sep_len, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 52, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_slice_indices_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":111 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): # <<<<<<<<<<<<<< + * cdef DTYPE_t start_ds_idx + * cdef DTYPE_t start_offset + */ + +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyArrayObject *__pyx_f_7fairseq_4data_22token_block_utils_fast__get_block_to_dataset_index_fast(PyArrayObject *__pyx_v_sizes, PyArrayObject *__pyx_v_slice_indices, CYTHON_UNUSED int __pyx_skip_dispatch) { + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_start_ds_idx; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_start_offset; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_end_ds_idx; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_s; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_e; + struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_ds = 0; + PyArrayObject *__pyx_v_block_to_dataset_index = 0; + __Pyx_memviewslice __pyx_v_block_to_dataset_index_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_slice_indices_view = { 0, 0, { 0 }, { 0 }, { 0 } }; + Py_ssize_t __pyx_v_x_max; + __Pyx_LocalBuf_ND __pyx_pybuffernd_block_to_dataset_index; + __Pyx_Buffer __pyx_pybuffer_block_to_dataset_index; + __Pyx_LocalBuf_ND __pyx_pybuffernd_sizes; + __Pyx_Buffer __pyx_pybuffer_sizes; + __Pyx_LocalBuf_ND __pyx_pybuffernd_slice_indices; + __Pyx_Buffer __pyx_pybuffer_slice_indices; + PyArrayObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyArrayObject *__pyx_t_6 = NULL; + __Pyx_memviewslice __pyx_t_7 = { 0, 0, { 0 }, { 0 }, { 0 } }; + npy_intp *__pyx_t_8; + Py_ssize_t __pyx_t_9; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_10; + Py_ssize_t __pyx_t_11; + Py_ssize_t __pyx_t_12; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_13; + int __pyx_t_14; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_get_block_to_dataset_index_fast", 1); + __pyx_pybuffer_block_to_dataset_index.pybuffer.buf = NULL; + __pyx_pybuffer_block_to_dataset_index.refcount = 0; + __pyx_pybuffernd_block_to_dataset_index.data = NULL; + __pyx_pybuffernd_block_to_dataset_index.rcbuffer = &__pyx_pybuffer_block_to_dataset_index; + __pyx_pybuffer_sizes.pybuffer.buf = NULL; + __pyx_pybuffer_sizes.refcount = 0; + __pyx_pybuffernd_sizes.data = NULL; + __pyx_pybuffernd_sizes.rcbuffer = &__pyx_pybuffer_sizes; + __pyx_pybuffer_slice_indices.pybuffer.buf = NULL; + __pyx_pybuffer_slice_indices.refcount = 0; + __pyx_pybuffernd_slice_indices.data = NULL; + __pyx_pybuffernd_slice_indices.rcbuffer = &__pyx_pybuffer_slice_indices; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer, (PyObject*)__pyx_v_sizes, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 111, __pyx_L1_error) + } + __pyx_pybuffernd_sizes.diminfo[0].strides = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_sizes.diminfo[0].shape = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 111, __pyx_L1_error) + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + + /* "fairseq/data/token_block_utils_fast.pyx":118 + * cdef DTYPE_t s + * cdef DTYPE_t e + * cdef DatasetSearcher ds = DatasetSearcher(sizes) # <<<<<<<<<<<<<< + * cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + * cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + */ + __pyx_t_1 = __Pyx_PyObject_CallOneArg(((PyObject *)__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher), ((PyObject *)__pyx_v_sizes)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 118, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_ds = ((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_t_1); + __pyx_t_1 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":119 + * cdef DTYPE_t e + * cdef DatasetSearcher ds = DatasetSearcher(sizes) + * cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) # <<<<<<<<<<<<<< + * cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices + */ + __Pyx_GetModuleGlobalName(__pyx_t_1, __pyx_n_s_np); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_zeros); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_3 = PyObject_Length(((PyObject *)__pyx_v_slice_indices)); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(0, 119, __pyx_L1_error) + __pyx_t_1 = PyInt_FromSsize_t(__pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_4 = PyList_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyList_SET_ITEM(__pyx_t_4, 0, __pyx_t_1)) __PYX_ERR(0, 119, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_3); + __Pyx_GIVEREF(__pyx_int_3); + if (__Pyx_PyList_SET_ITEM(__pyx_t_4, 1, __pyx_int_3)) __PYX_ERR(0, 119, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_t_4)) __PYX_ERR(0, 119, __pyx_L1_error); + __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GetModuleGlobalName(__pyx_t_5, __pyx_n_s_DTYPE); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_5 = __Pyx_PyObject_Call(__pyx_t_2, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 119, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (!(likely(((__pyx_t_5) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_5, __pyx_ptype_5numpy_ndarray))))) __PYX_ERR(0, 119, __pyx_L1_error) + __pyx_t_6 = ((PyArrayObject *)__pyx_t_5); + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer, (PyObject*)__pyx_t_6, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) { + __pyx_v_block_to_dataset_index = ((PyArrayObject *)Py_None); __Pyx_INCREF(Py_None); __pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer.buf = NULL; + __PYX_ERR(0, 119, __pyx_L1_error) + } else {__pyx_pybuffernd_block_to_dataset_index.diminfo[0].strides = __pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_block_to_dataset_index.diminfo[0].shape = __pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_block_to_dataset_index.diminfo[1].strides = __pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_block_to_dataset_index.diminfo[1].shape = __pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer.shape[1]; + } + } + __pyx_t_6 = 0; + __pyx_v_block_to_dataset_index = ((PyArrayObject *)__pyx_t_5); + __pyx_t_5 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":120 + * cdef DatasetSearcher ds = DatasetSearcher(sizes) + * cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + * cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index # <<<<<<<<<<<<<< + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices + * cdef Py_ssize_t x_max = slice_indices.shape[0] + */ + __pyx_t_7 = __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(((PyObject *)__pyx_v_block_to_dataset_index), PyBUF_WRITABLE); if (unlikely(!__pyx_t_7.memview)) __PYX_ERR(0, 120, __pyx_L1_error) + __pyx_v_block_to_dataset_index_view = __pyx_t_7; + __pyx_t_7.memview = NULL; + __pyx_t_7.data = NULL; + + /* "fairseq/data/token_block_utils_fast.pyx":121 + * cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + * cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices # <<<<<<<<<<<<<< + * cdef Py_ssize_t x_max = slice_indices.shape[0] + * + */ + __pyx_t_7 = __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(((PyObject *)__pyx_v_slice_indices), PyBUF_WRITABLE); if (unlikely(!__pyx_t_7.memview)) __PYX_ERR(0, 121, __pyx_L1_error) + __pyx_v_slice_indices_view = __pyx_t_7; + __pyx_t_7.memview = NULL; + __pyx_t_7.data = NULL; + + /* "fairseq/data/token_block_utils_fast.pyx":122 + * cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + * cdef DTYPE_t[:, :] slice_indices_view = slice_indices + * cdef Py_ssize_t x_max = slice_indices.shape[0] # <<<<<<<<<<<<<< + * + * for i in range(x_max): + */ + __pyx_t_8 = __pyx_f_5numpy_7ndarray_5shape_shape(((PyArrayObject *)__pyx_v_slice_indices)); if (unlikely(__pyx_t_8 == ((npy_intp *)NULL) && PyErr_Occurred())) __PYX_ERR(0, 122, __pyx_L1_error) + __pyx_v_x_max = (__pyx_t_8[0]); + + /* "fairseq/data/token_block_utils_fast.pyx":124 + * cdef Py_ssize_t x_max = slice_indices.shape[0] + * + * for i in range(x_max): # <<<<<<<<<<<<<< + * s = slice_indices_view[i][0] + * e = slice_indices_view[i][1] + */ + __pyx_t_3 = __pyx_v_x_max; + __pyx_t_9 = __pyx_t_3; + for (__pyx_t_10 = 0; __pyx_t_10 < __pyx_t_9; __pyx_t_10+=1) { + __pyx_v_i = __pyx_t_10; + + /* "fairseq/data/token_block_utils_fast.pyx":125 + * + * for i in range(x_max): + * s = slice_indices_view[i][0] # <<<<<<<<<<<<<< + * e = slice_indices_view[i][1] + * ds.seek(s) + */ + __pyx_t_11 = __pyx_v_i; + __pyx_t_12 = 0; + __pyx_v_s = (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_slice_indices_view.data + __pyx_t_11 * __pyx_v_slice_indices_view.strides[0]) ) + __pyx_t_12 * __pyx_v_slice_indices_view.strides[1]) ))); + + /* "fairseq/data/token_block_utils_fast.pyx":126 + * for i in range(x_max): + * s = slice_indices_view[i][0] + * e = slice_indices_view[i][1] # <<<<<<<<<<<<<< + * ds.seek(s) + * start_ds_idx = ds.current_index + */ + __pyx_t_12 = __pyx_v_i; + __pyx_t_11 = 1; + __pyx_v_e = (*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_slice_indices_view.data + __pyx_t_12 * __pyx_v_slice_indices_view.strides[0]) ) + __pyx_t_11 * __pyx_v_slice_indices_view.strides[1]) ))); + + /* "fairseq/data/token_block_utils_fast.pyx":127 + * s = slice_indices_view[i][0] + * e = slice_indices_view[i][1] + * ds.seek(s) # <<<<<<<<<<<<<< + * start_ds_idx = ds.current_index + * start_offset = ds.current_offset + */ + __pyx_t_5 = ((struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_ds->__pyx_vtab)->seek(__pyx_v_ds, __pyx_v_s); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 127, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":128 + * e = slice_indices_view[i][1] + * ds.seek(s) + * start_ds_idx = ds.current_index # <<<<<<<<<<<<<< + * start_offset = ds.current_offset + * if e <= s: + */ + __pyx_t_13 = __pyx_v_ds->current_index; + __pyx_v_start_ds_idx = __pyx_t_13; + + /* "fairseq/data/token_block_utils_fast.pyx":129 + * ds.seek(s) + * start_ds_idx = ds.current_index + * start_offset = ds.current_offset # <<<<<<<<<<<<<< + * if e <= s: + * end_ds_idx = start_ds_idx + */ + __pyx_t_13 = __pyx_v_ds->current_offset; + __pyx_v_start_offset = __pyx_t_13; + + /* "fairseq/data/token_block_utils_fast.pyx":130 + * start_ds_idx = ds.current_index + * start_offset = ds.current_offset + * if e <= s: # <<<<<<<<<<<<<< + * end_ds_idx = start_ds_idx + * else: + */ + __pyx_t_14 = (__pyx_v_e <= __pyx_v_s); + if (__pyx_t_14) { + + /* "fairseq/data/token_block_utils_fast.pyx":131 + * start_offset = ds.current_offset + * if e <= s: + * end_ds_idx = start_ds_idx # <<<<<<<<<<<<<< + * else: + * ds.seek(e - 1) + */ + __pyx_v_end_ds_idx = __pyx_v_start_ds_idx; + + /* "fairseq/data/token_block_utils_fast.pyx":130 + * start_ds_idx = ds.current_index + * start_offset = ds.current_offset + * if e <= s: # <<<<<<<<<<<<<< + * end_ds_idx = start_ds_idx + * else: + */ + goto __pyx_L5; + } + + /* "fairseq/data/token_block_utils_fast.pyx":133 + * end_ds_idx = start_ds_idx + * else: + * ds.seek(e - 1) # <<<<<<<<<<<<<< + * end_ds_idx = ds.current_index + * block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + */ + /*else*/ { + __pyx_t_5 = ((struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_ds->__pyx_vtab)->seek(__pyx_v_ds, (__pyx_v_e - 1)); if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 133, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":134 + * else: + * ds.seek(e - 1) + * end_ds_idx = ds.current_index # <<<<<<<<<<<<<< + * block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + * block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + */ + __pyx_t_13 = __pyx_v_ds->current_index; + __pyx_v_end_ds_idx = __pyx_t_13; + } + __pyx_L5:; + + /* "fairseq/data/token_block_utils_fast.pyx":135 + * ds.seek(e - 1) + * end_ds_idx = ds.current_index + * block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset # <<<<<<<<<<<<<< + * block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + * block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + */ + __pyx_t_11 = __pyx_v_i; + __pyx_t_12 = 0; + *((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_block_to_dataset_index_view.data + __pyx_t_11 * __pyx_v_block_to_dataset_index_view.strides[0]) ) + __pyx_t_12 * __pyx_v_block_to_dataset_index_view.strides[1]) )) = __pyx_v_start_ds_idx; + + /* "fairseq/data/token_block_utils_fast.pyx":136 + * end_ds_idx = ds.current_index + * block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + * block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index # <<<<<<<<<<<<<< + * block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + * return block_to_dataset_index + */ + __pyx_t_12 = __pyx_v_i; + __pyx_t_11 = 1; + *((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_block_to_dataset_index_view.data + __pyx_t_12 * __pyx_v_block_to_dataset_index_view.strides[0]) ) + __pyx_t_11 * __pyx_v_block_to_dataset_index_view.strides[1]) )) = __pyx_v_start_offset; + + /* "fairseq/data/token_block_utils_fast.pyx":137 + * block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + * block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + * block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset # <<<<<<<<<<<<<< + * return block_to_dataset_index + * + */ + __pyx_t_11 = __pyx_v_i; + __pyx_t_12 = 2; + *((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_block_to_dataset_index_view.data + __pyx_t_11 * __pyx_v_block_to_dataset_index_view.strides[0]) ) + __pyx_t_12 * __pyx_v_block_to_dataset_index_view.strides[1]) )) = __pyx_v_end_ds_idx; + } + + /* "fairseq/data/token_block_utils_fast.pyx":138 + * block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + * block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + * return block_to_dataset_index # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_block_to_dataset_index); + __pyx_r = ((PyArrayObject *)__pyx_v_block_to_dataset_index); + goto __pyx_L0; + + /* "fairseq/data/token_block_utils_fast.pyx":111 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): # <<<<<<<<<<<<<< + * cdef DTYPE_t start_ds_idx + * cdef DTYPE_t start_offset + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_7, 1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_block_to_dataset_index_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_block_to_dataset_index.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XDECREF((PyObject *)__pyx_v_ds); + __Pyx_XDECREF((PyObject *)__pyx_v_block_to_dataset_index); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_block_to_dataset_index_view, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_slice_indices_view, 1); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast = {"_get_block_to_dataset_index_fast", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyArrayObject *__pyx_v_sizes = 0; + PyArrayObject *__pyx_v_slice_indices = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[2] = {0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_block_to_dataset_index_fast (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_sizes,&__pyx_n_s_slice_indices,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_sizes)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 111, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_slice_indices)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 111, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("_get_block_to_dataset_index_fast", 1, 2, 2, 1); __PYX_ERR(0, 111, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "_get_block_to_dataset_index_fast") < 0)) __PYX_ERR(0, 111, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 2)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + } + __pyx_v_sizes = ((PyArrayObject *)values[0]); + __pyx_v_slice_indices = ((PyArrayObject *)values[1]); + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("_get_block_to_dataset_index_fast", 1, 2, 2, __pyx_nargs); __PYX_ERR(0, 111, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_block_to_dataset_index_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_sizes), __pyx_ptype_5numpy_ndarray, 1, "sizes", 0))) __PYX_ERR(0, 111, __pyx_L1_error) + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_slice_indices), __pyx_ptype_5numpy_ndarray, 1, "slice_indices", 0))) __PYX_ERR(0, 111, __pyx_L1_error) + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast_2_get_block_to_dataset_index_fast(__pyx_self, __pyx_v_sizes, __pyx_v_slice_indices); + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = NULL; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_2_get_block_to_dataset_index_fast(CYTHON_UNUSED PyObject *__pyx_self, PyArrayObject *__pyx_v_sizes, PyArrayObject *__pyx_v_slice_indices) { + __Pyx_LocalBuf_ND __pyx_pybuffernd_sizes; + __Pyx_Buffer __pyx_pybuffer_sizes; + __Pyx_LocalBuf_ND __pyx_pybuffernd_slice_indices; + __Pyx_Buffer __pyx_pybuffer_slice_indices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_get_block_to_dataset_index_fast", 1); + __pyx_pybuffer_sizes.pybuffer.buf = NULL; + __pyx_pybuffer_sizes.refcount = 0; + __pyx_pybuffernd_sizes.data = NULL; + __pyx_pybuffernd_sizes.rcbuffer = &__pyx_pybuffer_sizes; + __pyx_pybuffer_slice_indices.pybuffer.buf = NULL; + __pyx_pybuffer_slice_indices.refcount = 0; + __pyx_pybuffernd_slice_indices.data = NULL; + __pyx_pybuffernd_slice_indices.rcbuffer = &__pyx_pybuffer_slice_indices; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer, (PyObject*)__pyx_v_sizes, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 1, 0, __pyx_stack) == -1)) __PYX_ERR(0, 111, __pyx_L1_error) + } + __pyx_pybuffernd_sizes.diminfo[0].strides = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_sizes.diminfo[0].shape = __pyx_pybuffernd_sizes.rcbuffer->pybuffer.shape[0]; + { + __Pyx_BufFmt_StackElem __pyx_stack[1]; + if (unlikely(__Pyx_GetBufferAndValidate(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer, (PyObject*)__pyx_v_slice_indices, &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, PyBUF_FORMAT| PyBUF_STRIDES, 2, 0, __pyx_stack) == -1)) __PYX_ERR(0, 111, __pyx_L1_error) + } + __pyx_pybuffernd_slice_indices.diminfo[0].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[0]; __pyx_pybuffernd_slice_indices.diminfo[0].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[0]; __pyx_pybuffernd_slice_indices.diminfo[1].strides = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.strides[1]; __pyx_pybuffernd_slice_indices.diminfo[1].shape = __pyx_pybuffernd_slice_indices.rcbuffer->pybuffer.shape[1]; + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((PyObject *)__pyx_f_7fairseq_4data_22token_block_utils_fast__get_block_to_dataset_index_fast(__pyx_v_sizes, __pyx_v_slice_indices, 0)); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 111, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + { PyObject *__pyx_type, *__pyx_value, *__pyx_tb; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);} + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast._get_block_to_dataset_index_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + goto __pyx_L2; + __pyx_L0:; + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_sizes.rcbuffer->pybuffer); + __Pyx_SafeReleaseBuffer(&__pyx_pybuffernd_slice_indices.rcbuffer->pybuffer); + __pyx_L2:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":149 + * cdef DTYPE_t[:] sizes + * + * def __init__(self, DTYPE_t[:] sizes): # <<<<<<<<<<<<<< + * self.sizes = sizes + * self.reset() + */ + +/* Python wrapper */ +static int __pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_1__init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + __Pyx_memviewslice __pyx_v_sizes = { 0, 0, { 0 }, { 0 }, { 0 } }; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_sizes,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_sizes)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 149, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__init__") < 0)) __PYX_ERR(0, 149, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + } + __pyx_v_sizes = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_sizes.memview)) __PYX_ERR(0, 149, __pyx_L3_error) + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__init__", 1, 1, 1, __pyx_nargs); __PYX_ERR(0, 149, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __PYX_XCLEAR_MEMVIEW(&__pyx_v_sizes, 1); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher___init__(((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self), __pyx_v_sizes); + + /* function exit code */ + __PYX_XCLEAR_MEMVIEW(&__pyx_v_sizes, 1); + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher___init__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __Pyx_memviewslice __pyx_v_sizes) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__init__", 1); + + /* "fairseq/data/token_block_utils_fast.pyx":150 + * + * def __init__(self, DTYPE_t[:] sizes): + * self.sizes = sizes # <<<<<<<<<<<<<< + * self.reset() + * + */ + __PYX_XCLEAR_MEMVIEW(&__pyx_v_self->sizes, 0); + __PYX_INC_MEMVIEW(&__pyx_v_sizes, 1); + __pyx_v_self->sizes = __pyx_v_sizes; + + /* "fairseq/data/token_block_utils_fast.pyx":151 + * def __init__(self, DTYPE_t[:] sizes): + * self.sizes = sizes + * self.reset() # <<<<<<<<<<<<<< + * + * cdef reset(self): + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self->__pyx_vtab)->reset(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 151, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":149 + * cdef DTYPE_t[:] sizes + * + * def __init__(self, DTYPE_t[:] sizes): # <<<<<<<<<<<<<< + * self.sizes = sizes + * self.reset() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":153 + * self.reset() + * + * cdef reset(self): # <<<<<<<<<<<<<< + * self.current_offset = 0 # offset within current index in underlying dataset + * self.current_i = 0 # "flat" index + */ + +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_reset(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("reset", 1); + + /* "fairseq/data/token_block_utils_fast.pyx":154 + * + * cdef reset(self): + * self.current_offset = 0 # offset within current index in underlying dataset # <<<<<<<<<<<<<< + * self.current_i = 0 # "flat" index + * self.current_index = 0 # index in underlying dataset + */ + __pyx_v_self->current_offset = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":155 + * cdef reset(self): + * self.current_offset = 0 # offset within current index in underlying dataset + * self.current_i = 0 # "flat" index # <<<<<<<<<<<<<< + * self.current_index = 0 # index in underlying dataset + * + */ + __pyx_v_self->current_i = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":156 + * self.current_offset = 0 # offset within current index in underlying dataset + * self.current_i = 0 # "flat" index + * self.current_index = 0 # index in underlying dataset # <<<<<<<<<<<<<< + * + * @cython.boundscheck(False) + */ + __pyx_v_self->current_index = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":153 + * self.reset() + * + * cdef reset(self): # <<<<<<<<<<<<<< + * self.current_offset = 0 # offset within current index in underlying dataset + * self.current_i = 0 # "flat" index + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":161 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef int step(self, DTYPE_t i): # <<<<<<<<<<<<<< + * cdef DTYPE_t to_consume + * cdef DTYPE_t remaining + */ + +static int __pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_step(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i) { + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_to_consume; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_remaining; + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t __pyx_t_3; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("step", 1); + + /* "fairseq/data/token_block_utils_fast.pyx":164 + * cdef DTYPE_t to_consume + * cdef DTYPE_t remaining + * if i < self.current_i: # <<<<<<<<<<<<<< + * self.reset() + * if i > self.current_i: + */ + __pyx_t_1 = (__pyx_v_i < __pyx_v_self->current_i); + if (__pyx_t_1) { + + /* "fairseq/data/token_block_utils_fast.pyx":165 + * cdef DTYPE_t remaining + * if i < self.current_i: + * self.reset() # <<<<<<<<<<<<<< + * if i > self.current_i: + * to_consume = i - self.current_i + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self->__pyx_vtab)->reset(__pyx_v_self); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 165, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":164 + * cdef DTYPE_t to_consume + * cdef DTYPE_t remaining + * if i < self.current_i: # <<<<<<<<<<<<<< + * self.reset() + * if i > self.current_i: + */ + } + + /* "fairseq/data/token_block_utils_fast.pyx":166 + * if i < self.current_i: + * self.reset() + * if i > self.current_i: # <<<<<<<<<<<<<< + * to_consume = i - self.current_i + * remaining = self.sizes[self.current_index] - self.current_offset + */ + __pyx_t_1 = (__pyx_v_i > __pyx_v_self->current_i); + if (__pyx_t_1) { + + /* "fairseq/data/token_block_utils_fast.pyx":167 + * self.reset() + * if i > self.current_i: + * to_consume = i - self.current_i # <<<<<<<<<<<<<< + * remaining = self.sizes[self.current_index] - self.current_offset + * if remaining > to_consume: + */ + __pyx_v_to_consume = (__pyx_v_i - __pyx_v_self->current_i); + + /* "fairseq/data/token_block_utils_fast.pyx":168 + * if i > self.current_i: + * to_consume = i - self.current_i + * remaining = self.sizes[self.current_index] - self.current_offset # <<<<<<<<<<<<<< + * if remaining > to_consume: + * self.current_offset += to_consume + */ + if (unlikely(!__pyx_v_self->sizes.memview)) {PyErr_SetString(PyExc_AttributeError,"Memoryview is not initialized");__PYX_ERR(0, 168, __pyx_L1_error)} + __pyx_t_3 = __pyx_v_self->current_index; + __pyx_v_remaining = ((*((__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) ( /* dim=0 */ (__pyx_v_self->sizes.data + __pyx_t_3 * __pyx_v_self->sizes.strides[0]) ))) - __pyx_v_self->current_offset); + + /* "fairseq/data/token_block_utils_fast.pyx":169 + * to_consume = i - self.current_i + * remaining = self.sizes[self.current_index] - self.current_offset + * if remaining > to_consume: # <<<<<<<<<<<<<< + * self.current_offset += to_consume + * self.current_i += to_consume + */ + __pyx_t_1 = (__pyx_v_remaining > __pyx_v_to_consume); + if (__pyx_t_1) { + + /* "fairseq/data/token_block_utils_fast.pyx":170 + * remaining = self.sizes[self.current_index] - self.current_offset + * if remaining > to_consume: + * self.current_offset += to_consume # <<<<<<<<<<<<<< + * self.current_i += to_consume + * else: + */ + __pyx_v_self->current_offset = (__pyx_v_self->current_offset + __pyx_v_to_consume); + + /* "fairseq/data/token_block_utils_fast.pyx":171 + * if remaining > to_consume: + * self.current_offset += to_consume + * self.current_i += to_consume # <<<<<<<<<<<<<< + * else: + * assert remaining >= 0 + */ + __pyx_v_self->current_i = (__pyx_v_self->current_i + __pyx_v_to_consume); + + /* "fairseq/data/token_block_utils_fast.pyx":169 + * to_consume = i - self.current_i + * remaining = self.sizes[self.current_index] - self.current_offset + * if remaining > to_consume: # <<<<<<<<<<<<<< + * self.current_offset += to_consume + * self.current_i += to_consume + */ + goto __pyx_L5; + } + + /* "fairseq/data/token_block_utils_fast.pyx":173 + * self.current_i += to_consume + * else: + * assert remaining >= 0 # <<<<<<<<<<<<<< + * self.current_i += remaining + * self.current_index += 1 + */ + /*else*/ { + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_1 = (__pyx_v_remaining >= 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(0, 173, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(0, 173, __pyx_L1_error) + #endif + + /* "fairseq/data/token_block_utils_fast.pyx":174 + * else: + * assert remaining >= 0 + * self.current_i += remaining # <<<<<<<<<<<<<< + * self.current_index += 1 + * self.current_offset = 0 + */ + __pyx_v_self->current_i = (__pyx_v_self->current_i + __pyx_v_remaining); + + /* "fairseq/data/token_block_utils_fast.pyx":175 + * assert remaining >= 0 + * self.current_i += remaining + * self.current_index += 1 # <<<<<<<<<<<<<< + * self.current_offset = 0 + * return 1 + */ + __pyx_v_self->current_index = (__pyx_v_self->current_index + 1); + + /* "fairseq/data/token_block_utils_fast.pyx":176 + * self.current_i += remaining + * self.current_index += 1 + * self.current_offset = 0 # <<<<<<<<<<<<<< + * return 1 + * return 0 + */ + __pyx_v_self->current_offset = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":177 + * self.current_index += 1 + * self.current_offset = 0 + * return 1 # <<<<<<<<<<<<<< + * return 0 + * + */ + __pyx_r = 1; + goto __pyx_L0; + } + __pyx_L5:; + + /* "fairseq/data/token_block_utils_fast.pyx":166 + * if i < self.current_i: + * self.reset() + * if i > self.current_i: # <<<<<<<<<<<<<< + * to_consume = i - self.current_i + * remaining = self.sizes[self.current_index] - self.current_offset + */ + } + + /* "fairseq/data/token_block_utils_fast.pyx":178 + * self.current_offset = 0 + * return 1 + * return 0 # <<<<<<<<<<<<<< + * + * @cython.boundscheck(False) + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "fairseq/data/token_block_utils_fast.pyx":161 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef int step(self, DTYPE_t i): # <<<<<<<<<<<<<< + * cdef DTYPE_t to_consume + * cdef DTYPE_t remaining + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.step", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "fairseq/data/token_block_utils_fast.pyx":183 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef seek(self, DTYPE_t i): # <<<<<<<<<<<<<< + * cdef int not_done = 1 + * while not_done == 1: + */ + +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_seek(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_v_i) { + int __pyx_v_not_done; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("seek", 1); + + /* "fairseq/data/token_block_utils_fast.pyx":184 + * @cython.nonecheck(False) + * cdef seek(self, DTYPE_t i): + * cdef int not_done = 1 # <<<<<<<<<<<<<< + * while not_done == 1: + * not_done = self.step(i) + */ + __pyx_v_not_done = 1; + + /* "fairseq/data/token_block_utils_fast.pyx":185 + * cdef seek(self, DTYPE_t i): + * cdef int not_done = 1 + * while not_done == 1: # <<<<<<<<<<<<<< + * not_done = self.step(i) + * assert self.current_i == i + */ + while (1) { + __pyx_t_1 = (__pyx_v_not_done == 1); + if (!__pyx_t_1) break; + + /* "fairseq/data/token_block_utils_fast.pyx":186 + * cdef int not_done = 1 + * while not_done == 1: + * not_done = self.step(i) # <<<<<<<<<<<<<< + * assert self.current_i == i + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self->__pyx_vtab)->step(__pyx_v_self, __pyx_v_i); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 186, __pyx_L1_error) + __pyx_v_not_done = __pyx_t_2; + } + + /* "fairseq/data/token_block_utils_fast.pyx":187 + * while not_done == 1: + * not_done = self.step(i) + * assert self.current_i == i # <<<<<<<<<<<<<< + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_1 = (__pyx_v_self->current_i == __pyx_v_i); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(0, 187, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(0, 187, __pyx_L1_error) + #endif + + /* "fairseq/data/token_block_utils_fast.pyx":183 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cdef seek(self, DTYPE_t i): # <<<<<<<<<<<<<< + * cdef int not_done = 1 + * while not_done == 1: + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.seek", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__ = {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_2__reduce_cython__(((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_2__reduce_cython__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self) { + PyObject *__pyx_v_state = 0; + PyObject *__pyx_v__dict = 0; + int __pyx_v_use_setstate; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":5 + * cdef object _dict + * cdef bint use_setstate + * state = (self.current_i, self.current_index, self.current_offset, self.sizes) # <<<<<<<<<<<<<< + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + */ + __pyx_t_1 = __Pyx_PyInt_From_int64_t(__pyx_v_self->current_i); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyInt_From_int64_t(__pyx_v_self->current_index); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyInt_From_int64_t(__pyx_v_self->current_offset); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + if (unlikely(!__pyx_v_self->sizes.memview)) {PyErr_SetString(PyExc_AttributeError,"Memoryview is not initialized");__PYX_ERR(1, 5, __pyx_L1_error)} + __pyx_t_4 = __pyx_memoryview_fromslice(__pyx_v_self->sizes, 1, (PyObject *(*)(char *)) __pyx_memview_get_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, (int (*)(char *, PyObject *)) __pyx_memview_set_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, 0);; if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = PyTuple_New(4); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_2)) __PYX_ERR(1, 5, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 3, __pyx_t_4)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_3 = 0; + __pyx_t_4 = 0; + __pyx_v_state = ((PyObject*)__pyx_t_5); + __pyx_t_5 = 0; + + /* "(tree fragment)":6 + * cdef bint use_setstate + * state = (self.current_i, self.current_index, self.current_offset, self.sizes) + * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< + * if _dict is not None: + * state += (_dict,) + */ + __pyx_t_5 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_v__dict = __pyx_t_5; + __pyx_t_5 = 0; + + /* "(tree fragment)":7 + * state = (self.current_i, self.current_index, self.current_offset, self.sizes) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + __pyx_t_6 = (__pyx_v__dict != Py_None); + if (__pyx_t_6) { + + /* "(tree fragment)":8 + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + * state += (_dict,) # <<<<<<<<<<<<<< + * use_setstate = True + * else: + */ + __pyx_t_5 = PyTuple_New(1); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_INCREF(__pyx_v__dict); + __Pyx_GIVEREF(__pyx_v__dict); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_v__dict)) __PYX_ERR(1, 8, __pyx_L1_error); + __pyx_t_4 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_5); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_4)); + __pyx_t_4 = 0; + + /* "(tree fragment)":9 + * if _dict is not None: + * state += (_dict,) + * use_setstate = True # <<<<<<<<<<<<<< + * else: + * use_setstate = False + */ + __pyx_v_use_setstate = 1; + + /* "(tree fragment)":7 + * state = (self.current_i, self.current_index, self.current_offset, self.sizes) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + goto __pyx_L3; + } + + /* "(tree fragment)":11 + * use_setstate = True + * else: + * use_setstate = False # <<<<<<<<<<<<<< + * if use_setstate: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, None), state + */ + /*else*/ { + __pyx_v_use_setstate = 0; + } + __pyx_L3:; + + /* "(tree fragment)":12 + * else: + * use_setstate = False + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, None), state + * else: + */ + if (__pyx_v_use_setstate) { + + /* "(tree fragment)":13 + * use_setstate = False + * if use_setstate: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, None), state # <<<<<<<<<<<<<< + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pyx_unpickle_DatasetSearcher); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_147225413); + __Pyx_GIVEREF(__pyx_int_147225413); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_int_147225413)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 2, Py_None)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_5); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_5)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_v_state)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_4 = 0; + __pyx_t_5 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "(tree fragment)":12 + * else: + * use_setstate = False + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, None), state + * else: + */ + } + + /* "(tree fragment)":15 + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, None), state + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_pyx_unpickle_DatasetSearcher); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_147225413); + __Pyx_GIVEREF(__pyx_int_147225413); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_int_147225413)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_v_state)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_4 = PyTuple_New(2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_5); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_5)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_5 = 0; + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + } + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_state); + __Pyx_XDECREF(__pyx_v__dict); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":16 + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__ = {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 16, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 16, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 16, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_4__setstate_cython__(((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_4__setstate_cython__(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v_self, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":17 + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) # <<<<<<<<<<<<<< + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 17, __pyx_L1_error) + __pyx_t_1 = __pyx_f_7fairseq_4data_22token_block_utils_fast___pyx_unpickle_DatasetSearcher__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.DatasetSearcher.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __pyx_unpickle_DatasetSearcher(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_5__pyx_unpickle_DatasetSearcher(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_7fairseq_4data_22token_block_utils_fast_5__pyx_unpickle_DatasetSearcher = {"__pyx_unpickle_DatasetSearcher", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_5__pyx_unpickle_DatasetSearcher, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_7fairseq_4data_22token_block_utils_fast_5__pyx_unpickle_DatasetSearcher(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_type = 0; + long __pyx_v___pyx_checksum; + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__pyx_unpickle_DatasetSearcher (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_type)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_checksum)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_DatasetSearcher", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_DatasetSearcher", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__pyx_unpickle_DatasetSearcher") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 3)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + } + __pyx_v___pyx_type = values[0]; + __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_v___pyx_state = values[2]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_DatasetSearcher", 1, 3, 3, __pyx_nargs); __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.__pyx_unpickle_DatasetSearcher", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_7fairseq_4data_22token_block_utils_fast_4__pyx_unpickle_DatasetSearcher(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_7fairseq_4data_22token_block_utils_fast_4__pyx_unpickle_DatasetSearcher(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_v___pyx_PickleError = 0; + PyObject *__pyx_v___pyx_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_DatasetSearcher", 1); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x8c67b45, 0x2e2dd22, 0x6632805): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + */ + __pyx_t_1 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = (__Pyx_PySequence_ContainsTF(__pyx_t_1, __pyx_tuple__14, Py_NE)); if (unlikely((__pyx_t_2 < 0))) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (__pyx_t_2) { + + /* "(tree fragment)":5 + * cdef object __pyx_result + * if __pyx_checksum not in (0x8c67b45, 0x2e2dd22, 0x6632805): + * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + * __pyx_result = DatasetSearcher.__new__(__pyx_type) + */ + __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_n_s_PickleError); + __Pyx_GIVEREF(__pyx_n_s_PickleError); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_PickleError)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_1, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_t_1); + __pyx_v___pyx_PickleError = __pyx_t_1; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "(tree fragment)":6 + * if __pyx_checksum not in (0x8c67b45, 0x2e2dd22, 0x6632805): + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum # <<<<<<<<<<<<<< + * __pyx_result = DatasetSearcher.__new__(__pyx_type) + * if __pyx_state is not None: + */ + __pyx_t_3 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2, __pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_v___pyx_PickleError, __pyx_t_1, 0, 0); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __PYX_ERR(1, 6, __pyx_L1_error) + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x8c67b45, 0x2e2dd22, 0x6632805): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + */ + } + + /* "(tree fragment)":7 + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + * __pyx_result = DatasetSearcher.__new__(__pyx_type) # <<<<<<<<<<<<<< + * if __pyx_state is not None: + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + */ + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher), __pyx_n_s_new); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = NULL; + __pyx_t_5 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_5 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_4, __pyx_v___pyx_type}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_5, 1+__pyx_t_5); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_v___pyx_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + * __pyx_result = DatasetSearcher.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + __pyx_t_2 = (__pyx_v___pyx_state != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":9 + * __pyx_result = DatasetSearcher.__new__(__pyx_type) + * if __pyx_state is not None: + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< + * return __pyx_result + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 9, __pyx_L1_error) + __pyx_t_1 = __pyx_f_7fairseq_4data_22token_block_utils_fast___pyx_unpickle_DatasetSearcher__set_state(((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + * __pyx_result = DatasetSearcher.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + } + + /* "(tree fragment)":10 + * if __pyx_state is not None: + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + * return __pyx_result # <<<<<<<<<<<<<< + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v___pyx_result); + __pyx_r = __pyx_v___pyx_result; + goto __pyx_L0; + + /* "(tree fragment)":1 + * def __pyx_unpickle_DatasetSearcher(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.__pyx_unpickle_DatasetSearcher", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v___pyx_PickleError); + __Pyx_XDECREF(__pyx_v___pyx_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":11 + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): + */ + +static PyObject *__pyx_f_7fairseq_4data_22token_block_utils_fast___pyx_unpickle_DatasetSearcher__set_state(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t __pyx_t_2; + __Pyx_memviewslice __pyx_t_3 = { 0, 0, { 0 }, { 0 }, { 0 } }; + int __pyx_t_4; + Py_ssize_t __pyx_t_5; + int __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + PyObject *__pyx_t_9 = NULL; + int __pyx_t_10; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_DatasetSearcher__set_state", 1); + + /* "(tree fragment)":12 + * return __pyx_result + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] # <<<<<<<<<<<<<< + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[4]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyInt_As_int64_t(__pyx_t_1); if (unlikely((__pyx_t_2 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_v___pyx_result->current_i = __pyx_t_2; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyInt_As_int64_t(__pyx_t_1); if (unlikely((__pyx_t_2 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_v___pyx_result->current_index = __pyx_t_2; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 2, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyInt_As_int64_t(__pyx_t_1); if (unlikely((__pyx_t_2 == ((int64_t)-1)) && PyErr_Occurred())) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_v___pyx_result->current_offset = __pyx_t_2; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 3, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(__pyx_t_1, PyBUF_WRITABLE); if (unlikely(!__pyx_t_3.memview)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __PYX_XCLEAR_MEMVIEW(&__pyx_v___pyx_result->sizes, 0); + __pyx_v___pyx_result->sizes = __pyx_t_3; + __pyx_t_3.memview = NULL; + __pyx_t_3.data = NULL; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[4]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 13, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_5 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_6 = (__pyx_t_5 > 4); + if (__pyx_t_6) { + } else { + __pyx_t_4 = __pyx_t_6; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_6 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_4 = __pyx_t_6; + __pyx_L4_bool_binop_done:; + if (__pyx_t_4) { + + /* "(tree fragment)":14 + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[4]) # <<<<<<<<<<<<<< + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_t_7, __pyx_n_s_update); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 14, __pyx_L1_error) + } + __pyx_t_7 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 4, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_9 = NULL; + __pyx_t_10 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_8))) { + __pyx_t_9 = PyMethod_GET_SELF(__pyx_t_8); + if (likely(__pyx_t_9)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_8); + __Pyx_INCREF(__pyx_t_9); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_8, function); + __pyx_t_10 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_9, __pyx_t_7}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_8, __pyx_callargs+1-__pyx_t_10, 1+__pyx_t_10); + __Pyx_XDECREF(__pyx_t_9); __pyx_t_9 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[4]) + */ + } + + /* "(tree fragment)":11 + * __pyx_unpickle_DatasetSearcher__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_DatasetSearcher__set_state(DatasetSearcher __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.current_i = __pyx_state[0]; __pyx_result.current_index = __pyx_state[1]; __pyx_result.current_offset = __pyx_state[2]; __pyx_result.sizes = __pyx_state[3] + * if len(__pyx_state) > 4 and hasattr(__pyx_result, '__dict__'): + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_3, 1); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_XDECREF(__pyx_t_9); + __Pyx_AddTraceback("fairseq.data.token_block_utils_fast.__pyx_unpickle_DatasetSearcher__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} +static struct __pyx_vtabstruct_7fairseq_4data_22token_block_utils_fast_DatasetSearcher __pyx_vtable_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + +static PyObject *__pyx_tp_new_7fairseq_4data_22token_block_utils_fast_DatasetSearcher(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { + struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)o); + p->__pyx_vtab = __pyx_vtabptr_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + p->sizes.data = NULL; + p->sizes.memview = NULL; + return o; +} + +static void __pyx_tp_dealloc_7fairseq_4data_22token_block_utils_fast_DatasetSearcher(PyObject *o) { + struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *p = (struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && (!PyType_IS_GC(Py_TYPE(o)) || !__Pyx_PyObject_GC_IsFinalized(o))) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + __PYX_XCLEAR_MEMVIEW(&p->sizes, 1); + p->sizes.memview = NULL; p->sizes.data = NULL; + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static PyMethodDef __pyx_methods_7fairseq_4data_22token_block_utils_fast_DatasetSearcher[] = { + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_7fairseq_4data_22token_block_utils_fast_DatasetSearcher}, + {Py_tp_doc, (void *)PyDoc_STR("Helper for mapping \"flat\" indices to indices and offsets in an\n underlying dataset.")}, + {Py_tp_methods, (void *)__pyx_methods_7fairseq_4data_22token_block_utils_fast_DatasetSearcher}, + {Py_tp_init, (void *)__pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_1__init__}, + {Py_tp_new, (void *)__pyx_tp_new_7fairseq_4data_22token_block_utils_fast_DatasetSearcher}, + {0, 0}, +}; +static PyType_Spec __pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher_spec = { + "fairseq.data.token_block_utils_fast.DatasetSearcher", + sizeof(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE, + __pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher_slots, +}; +#else + +static PyTypeObject __pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.token_block_utils_fast.""DatasetSearcher", /*tp_name*/ + sizeof(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE, /*tp_flags*/ + PyDoc_STR("Helper for mapping \"flat\" indices to indices and offsets in an\n underlying dataset."), /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + __pyx_pw_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_1__init__, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct_array __pyx_vtable_array; + +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_array_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_array_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_array; + p->mode = ((PyObject*)Py_None); Py_INCREF(Py_None); + p->_format = ((PyObject*)Py_None); Py_INCREF(Py_None); + if (unlikely(__pyx_array___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_array(PyObject *o) { + struct __pyx_array_obj *p = (struct __pyx_array_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && (!PyType_IS_GC(Py_TYPE(o)) || !__Pyx_PyObject_GC_IsFinalized(o))) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_array) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_array___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->mode); + Py_CLEAR(p->_format); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} +static PyObject *__pyx_sq_item_array(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_array(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_array___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_tp_getattro_array(PyObject *o, PyObject *n) { + PyObject *v = __Pyx_PyObject_GenericGetAttr(o, n); + if (!v && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + v = __pyx_array___getattr__(o, n); + } + return v; +} + +static PyObject *__pyx_getprop___pyx_array_memview(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(o); +} + +static PyMethodDef __pyx_methods_array[] = { + {"__getattr__", (PyCFunction)__pyx_array___getattr__, METH_O|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_array[] = { + {(char *)"memview", __pyx_getprop___pyx_array_memview, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_array_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_array}, + {Py_sq_length, (void *)__pyx_array___len__}, + {Py_sq_item, (void *)__pyx_sq_item_array}, + {Py_mp_length, (void *)__pyx_array___len__}, + {Py_mp_subscript, (void *)__pyx_array___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_array}, + {Py_tp_getattro, (void *)__pyx_tp_getattro_array}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_array_getbuffer}, + #endif + {Py_tp_methods, (void *)__pyx_methods_array}, + {Py_tp_getset, (void *)__pyx_getsets_array}, + {Py_tp_new, (void *)__pyx_tp_new_array}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_array_spec = { + "fairseq.data.token_block_utils_fast.array", + sizeof(struct __pyx_array_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_array_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_array = { + __pyx_array___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_array, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_array = { + __pyx_array___len__, /*mp_length*/ + __pyx_array___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_array, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_array = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.token_block_utils_fast.""array", /*tp_name*/ + sizeof(struct __pyx_array_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_array, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_array, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_array, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + __pyx_tp_getattro_array, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_array, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + 0, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_array, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_array, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_array, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { + struct __pyx_MemviewEnum_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_MemviewEnum_obj *)o); + p->name = Py_None; Py_INCREF(Py_None); + return o; +} + +static void __pyx_tp_dealloc_Enum(PyObject *o) { + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_Enum) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + Py_CLEAR(p->name); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_Enum(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + if (p->name) { + e = (*v)(p->name, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_Enum(PyObject *o) { + PyObject* tmp; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + tmp = ((PyObject*)p->name); + p->name = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + return 0; +} + +static PyObject *__pyx_specialmethod___pyx_MemviewEnum___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_MemviewEnum___repr__(self); +} + +static PyMethodDef __pyx_methods_Enum[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_MemviewEnum___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_MemviewEnum_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_Enum}, + {Py_tp_repr, (void *)__pyx_MemviewEnum___repr__}, + {Py_tp_traverse, (void *)__pyx_tp_traverse_Enum}, + {Py_tp_clear, (void *)__pyx_tp_clear_Enum}, + {Py_tp_methods, (void *)__pyx_methods_Enum}, + {Py_tp_init, (void *)__pyx_MemviewEnum___init__}, + {Py_tp_new, (void *)__pyx_tp_new_Enum}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_MemviewEnum_spec = { + "fairseq.data.token_block_utils_fast.Enum", + sizeof(struct __pyx_MemviewEnum_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_MemviewEnum_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_MemviewEnum = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.token_block_utils_fast.""Enum", /*tp_name*/ + sizeof(struct __pyx_MemviewEnum_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_Enum, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_MemviewEnum___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_Enum, /*tp_traverse*/ + __pyx_tp_clear_Enum, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_Enum, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + __pyx_MemviewEnum___init__, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_Enum, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct_memoryview __pyx_vtable_memoryview; + +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryview_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_memoryview_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_memoryview; + p->obj = Py_None; Py_INCREF(Py_None); + p->_size = Py_None; Py_INCREF(Py_None); + p->_array_interface = Py_None; Py_INCREF(Py_None); + p->view.obj = NULL; + if (unlikely(__pyx_memoryview___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_memoryview(PyObject *o) { + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_memoryview) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryview___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->obj); + Py_CLEAR(p->_size); + Py_CLEAR(p->_array_interface); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_memoryview(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + if (p->obj) { + e = (*v)(p->obj, a); if (e) return e; + } + if (p->_size) { + e = (*v)(p->_size, a); if (e) return e; + } + if (p->_array_interface) { + e = (*v)(p->_array_interface, a); if (e) return e; + } + if (p->view.obj) { + e = (*v)(p->view.obj, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_memoryview(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + tmp = ((PyObject*)p->obj); + p->obj = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_size); + p->_size = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_array_interface); + p->_array_interface = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + Py_CLEAR(p->view.obj); + return 0; +} +static PyObject *__pyx_sq_item_memoryview(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_memoryview(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_memoryview___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_getprop___pyx_memoryview_T(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_base(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_shape(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_strides(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_suboffsets(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_ndim(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_itemsize(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_nbytes(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_size(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(o); +} + +static PyObject *__pyx_specialmethod___pyx_memoryview___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_memoryview___repr__(self); +} + +static PyMethodDef __pyx_methods_memoryview[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_memoryview___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"is_c_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_c_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"is_f_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_f_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy_fortran", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy_fortran, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_memoryview[] = { + {(char *)"T", __pyx_getprop___pyx_memoryview_T, 0, (char *)0, 0}, + {(char *)"base", __pyx_getprop___pyx_memoryview_base, 0, (char *)0, 0}, + {(char *)"shape", __pyx_getprop___pyx_memoryview_shape, 0, (char *)0, 0}, + {(char *)"strides", __pyx_getprop___pyx_memoryview_strides, 0, (char *)0, 0}, + {(char *)"suboffsets", __pyx_getprop___pyx_memoryview_suboffsets, 0, (char *)0, 0}, + {(char *)"ndim", __pyx_getprop___pyx_memoryview_ndim, 0, (char *)0, 0}, + {(char *)"itemsize", __pyx_getprop___pyx_memoryview_itemsize, 0, (char *)0, 0}, + {(char *)"nbytes", __pyx_getprop___pyx_memoryview_nbytes, 0, (char *)0, 0}, + {(char *)"size", __pyx_getprop___pyx_memoryview_size, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_memoryview_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_memoryview}, + {Py_tp_repr, (void *)__pyx_memoryview___repr__}, + {Py_sq_length, (void *)__pyx_memoryview___len__}, + {Py_sq_item, (void *)__pyx_sq_item_memoryview}, + {Py_mp_length, (void *)__pyx_memoryview___len__}, + {Py_mp_subscript, (void *)__pyx_memoryview___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_memoryview}, + {Py_tp_str, (void *)__pyx_memoryview___str__}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_memoryview_getbuffer}, + #endif + {Py_tp_traverse, (void *)__pyx_tp_traverse_memoryview}, + {Py_tp_clear, (void *)__pyx_tp_clear_memoryview}, + {Py_tp_methods, (void *)__pyx_methods_memoryview}, + {Py_tp_getset, (void *)__pyx_getsets_memoryview}, + {Py_tp_new, (void *)__pyx_tp_new_memoryview}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryview_spec = { + "fairseq.data.token_block_utils_fast.memoryview", + sizeof(struct __pyx_memoryview_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_memoryview_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_memoryview = { + __pyx_memoryview___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_memoryview, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_memoryview = { + __pyx_memoryview___len__, /*mp_length*/ + __pyx_memoryview___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_memoryview, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_memoryview = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.token_block_utils_fast.""memoryview", /*tp_name*/ + sizeof(struct __pyx_memoryview_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_memoryview, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_memoryview___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_memoryview, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_memoryview, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + __pyx_memoryview___str__, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_memoryview, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_memoryview, /*tp_traverse*/ + __pyx_tp_clear_memoryview, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_memoryview, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_memoryview, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_memoryview, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct__memoryviewslice __pyx_vtable__memoryviewslice; + +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryviewslice_obj *p; + PyObject *o = __pyx_tp_new_memoryview(t, a, k); + if (unlikely(!o)) return 0; + p = ((struct __pyx_memoryviewslice_obj *)o); + p->__pyx_base.__pyx_vtab = (struct __pyx_vtabstruct_memoryview*)__pyx_vtabptr__memoryviewslice; + new((void*)&(p->from_slice)) __Pyx_memviewslice(); + p->from_object = Py_None; Py_INCREF(Py_None); + p->from_slice.memview = NULL; + return o; +} + +static void __pyx_tp_dealloc__memoryviewslice(PyObject *o) { + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc__memoryviewslice) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryviewslice___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + __Pyx_call_destructor(p->from_slice); + Py_CLEAR(p->from_object); + PyObject_GC_Track(o); + __pyx_tp_dealloc_memoryview(o); +} + +static int __pyx_tp_traverse__memoryviewslice(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + e = __pyx_tp_traverse_memoryview(o, v, a); if (e) return e; + if (p->from_object) { + e = (*v)(p->from_object, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear__memoryviewslice(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + __pyx_tp_clear_memoryview(o); + tmp = ((PyObject*)p->from_object); + p->from_object = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + __PYX_XCLEAR_MEMVIEW(&p->from_slice, 1); + return 0; +} + +static PyMethodDef __pyx_methods__memoryviewslice[] = { + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_memoryviewslice_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc__memoryviewslice}, + {Py_tp_doc, (void *)PyDoc_STR("Internal class for passing memoryview slices to Python")}, + {Py_tp_traverse, (void *)__pyx_tp_traverse__memoryviewslice}, + {Py_tp_clear, (void *)__pyx_tp_clear__memoryviewslice}, + {Py_tp_methods, (void *)__pyx_methods__memoryviewslice}, + {Py_tp_new, (void *)__pyx_tp_new__memoryviewslice}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryviewslice_spec = { + "fairseq.data.token_block_utils_fast._memoryviewslice", + sizeof(struct __pyx_memoryviewslice_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_memoryviewslice_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_memoryviewslice = { + PyVarObject_HEAD_INIT(0, 0) + "fairseq.data.token_block_utils_fast.""_memoryviewslice", /*tp_name*/ + sizeof(struct __pyx_memoryviewslice_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc__memoryviewslice, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___repr__, /*tp_repr*/ + #else + 0, /*tp_repr*/ + #endif + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___str__, /*tp_str*/ + #else + 0, /*tp_str*/ + #endif + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + PyDoc_STR("Internal class for passing memoryview slices to Python"), /*tp_doc*/ + __pyx_tp_traverse__memoryviewslice, /*tp_traverse*/ + __pyx_tp_clear__memoryviewslice, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods__memoryviewslice, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new__memoryviewslice, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyMethodDef __pyx_methods[] = { + {0, 0, 0, 0} +}; +#ifndef CYTHON_SMALL_CODE +#if defined(__clang__) + #define CYTHON_SMALL_CODE +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define CYTHON_SMALL_CODE __attribute__((cold)) +#else + #define CYTHON_SMALL_CODE +#endif +#endif +/* #### Code section: pystring_table ### */ + +static int __Pyx_CreateStringTabAndInitStrings(void) { + __Pyx_StringTabEntry __pyx_string_tab[] = { + {&__pyx_kp_u_, __pyx_k_, sizeof(__pyx_k_), 0, 1, 0, 0}, + {&__pyx_n_s_ASCII, __pyx_k_ASCII, sizeof(__pyx_k_ASCII), 0, 0, 1, 1}, + {&__pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_k_All_dimensions_preceding_dimensi, sizeof(__pyx_k_All_dimensions_preceding_dimensi), 0, 0, 1, 0}, + {&__pyx_n_s_AssertionError, __pyx_k_AssertionError, sizeof(__pyx_k_AssertionError), 0, 0, 1, 1}, + {&__pyx_kp_s_Buffer_view_does_not_expose_stri, __pyx_k_Buffer_view_does_not_expose_stri, sizeof(__pyx_k_Buffer_view_does_not_expose_stri), 0, 0, 1, 0}, + {&__pyx_kp_s_Can_only_create_a_buffer_that_is, __pyx_k_Can_only_create_a_buffer_that_is, sizeof(__pyx_k_Can_only_create_a_buffer_that_is), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_assign_to_read_only_memor, __pyx_k_Cannot_assign_to_read_only_memor, sizeof(__pyx_k_Cannot_assign_to_read_only_memor), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_create_writable_memory_vi, __pyx_k_Cannot_create_writable_memory_vi, sizeof(__pyx_k_Cannot_create_writable_memory_vi), 0, 0, 1, 0}, + {&__pyx_kp_u_Cannot_index_with_type, __pyx_k_Cannot_index_with_type, sizeof(__pyx_k_Cannot_index_with_type), 0, 1, 0, 0}, + {&__pyx_kp_s_Cannot_transpose_memoryview_with, __pyx_k_Cannot_transpose_memoryview_with, sizeof(__pyx_k_Cannot_transpose_memoryview_with), 0, 0, 1, 0}, + {&__pyx_n_s_DTYPE, __pyx_k_DTYPE, sizeof(__pyx_k_DTYPE), 0, 0, 1, 1}, + {&__pyx_n_s_DatasetSearcher, __pyx_k_DatasetSearcher, sizeof(__pyx_k_DatasetSearcher), 0, 0, 1, 1}, + {&__pyx_n_s_DatasetSearcher___reduce_cython, __pyx_k_DatasetSearcher___reduce_cython, sizeof(__pyx_k_DatasetSearcher___reduce_cython), 0, 0, 1, 1}, + {&__pyx_n_s_DatasetSearcher___setstate_cytho, __pyx_k_DatasetSearcher___setstate_cytho, sizeof(__pyx_k_DatasetSearcher___setstate_cytho), 0, 0, 1, 1}, + {&__pyx_kp_s_Dimension_d_is_not_direct, __pyx_k_Dimension_d_is_not_direct, sizeof(__pyx_k_Dimension_d_is_not_direct), 0, 0, 1, 0}, + {&__pyx_n_s_Ellipsis, __pyx_k_Ellipsis, sizeof(__pyx_k_Ellipsis), 0, 0, 1, 1}, + {&__pyx_kp_s_Empty_shape_tuple_for_cython_arr, __pyx_k_Empty_shape_tuple_for_cython_arr, sizeof(__pyx_k_Empty_shape_tuple_for_cython_arr), 0, 0, 1, 0}, + {&__pyx_n_s_ImportError, __pyx_k_ImportError, sizeof(__pyx_k_ImportError), 0, 0, 1, 1}, + {&__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_k_Incompatible_checksums_0x_x_vs_0, sizeof(__pyx_k_Incompatible_checksums_0x_x_vs_0), 0, 0, 1, 0}, + {&__pyx_kp_s_Incompatible_checksums_0x_x_vs_0_2, __pyx_k_Incompatible_checksums_0x_x_vs_0_2, sizeof(__pyx_k_Incompatible_checksums_0x_x_vs_0_2), 0, 0, 1, 0}, + {&__pyx_n_s_IndexError, __pyx_k_IndexError, sizeof(__pyx_k_IndexError), 0, 0, 1, 1}, + {&__pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_k_Index_out_of_bounds_axis_d, sizeof(__pyx_k_Index_out_of_bounds_axis_d), 0, 0, 1, 0}, + {&__pyx_kp_s_Indirect_dimensions_not_supporte, __pyx_k_Indirect_dimensions_not_supporte, sizeof(__pyx_k_Indirect_dimensions_not_supporte), 0, 0, 1, 0}, + {&__pyx_kp_u_Invalid_break_mode, __pyx_k_Invalid_break_mode, sizeof(__pyx_k_Invalid_break_mode), 0, 1, 0, 0}, + {&__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_k_Invalid_mode_expected_c_or_fortr, sizeof(__pyx_k_Invalid_mode_expected_c_or_fortr), 0, 1, 0, 0}, + {&__pyx_kp_u_Invalid_shape_in_axis, __pyx_k_Invalid_shape_in_axis, sizeof(__pyx_k_Invalid_shape_in_axis), 0, 1, 0, 0}, + {&__pyx_n_s_MemoryError, __pyx_k_MemoryError, sizeof(__pyx_k_MemoryError), 0, 0, 1, 1}, + {&__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_k_MemoryView_of_r_at_0x_x, sizeof(__pyx_k_MemoryView_of_r_at_0x_x), 0, 0, 1, 0}, + {&__pyx_kp_s_MemoryView_of_r_object, __pyx_k_MemoryView_of_r_object, sizeof(__pyx_k_MemoryView_of_r_object), 0, 0, 1, 0}, + {&__pyx_n_b_O, __pyx_k_O, sizeof(__pyx_k_O), 0, 0, 0, 1}, + {&__pyx_kp_u_Out_of_bounds_on_buffer_access_a, __pyx_k_Out_of_bounds_on_buffer_access_a, sizeof(__pyx_k_Out_of_bounds_on_buffer_access_a), 0, 1, 0, 0}, + {&__pyx_n_s_PickleError, __pyx_k_PickleError, sizeof(__pyx_k_PickleError), 0, 0, 1, 1}, + {&__pyx_n_s_Sequence, __pyx_k_Sequence, sizeof(__pyx_k_Sequence), 0, 0, 1, 1}, + {&__pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_k_Step_may_not_be_zero_axis_d, sizeof(__pyx_k_Step_may_not_be_zero_axis_d), 0, 0, 1, 0}, + {&__pyx_n_s_TypeError, __pyx_k_TypeError, sizeof(__pyx_k_TypeError), 0, 0, 1, 1}, + {&__pyx_kp_s_Unable_to_convert_item_to_object, __pyx_k_Unable_to_convert_item_to_object, sizeof(__pyx_k_Unable_to_convert_item_to_object), 0, 0, 1, 0}, + {&__pyx_n_s_ValueError, __pyx_k_ValueError, sizeof(__pyx_k_ValueError), 0, 0, 1, 1}, + {&__pyx_n_s_View_MemoryView, __pyx_k_View_MemoryView, sizeof(__pyx_k_View_MemoryView), 0, 0, 1, 1}, + {&__pyx_kp_u__2, __pyx_k__2, sizeof(__pyx_k__2), 0, 1, 0, 0}, + {&__pyx_n_s__3, __pyx_k__3, sizeof(__pyx_k__3), 0, 0, 1, 1}, + {&__pyx_n_s__35, __pyx_k__35, sizeof(__pyx_k__35), 0, 0, 1, 1}, + {&__pyx_kp_u__6, __pyx_k__6, sizeof(__pyx_k__6), 0, 1, 0, 0}, + {&__pyx_kp_u__7, __pyx_k__7, sizeof(__pyx_k__7), 0, 1, 0, 0}, + {&__pyx_n_s_abc, __pyx_k_abc, sizeof(__pyx_k_abc), 0, 0, 1, 1}, + {&__pyx_n_s_allocate_buffer, __pyx_k_allocate_buffer, sizeof(__pyx_k_allocate_buffer), 0, 0, 1, 1}, + {&__pyx_kp_u_and, __pyx_k_and, sizeof(__pyx_k_and), 0, 1, 0, 0}, + {&__pyx_n_s_asyncio_coroutines, __pyx_k_asyncio_coroutines, sizeof(__pyx_k_asyncio_coroutines), 0, 0, 1, 1}, + {&__pyx_n_s_axis, __pyx_k_axis, sizeof(__pyx_k_axis), 0, 0, 1, 1}, + {&__pyx_n_s_base, __pyx_k_base, sizeof(__pyx_k_base), 0, 0, 1, 1}, + {&__pyx_n_s_block_size, __pyx_k_block_size, sizeof(__pyx_k_block_size), 0, 0, 1, 1}, + {&__pyx_n_s_break_mode, __pyx_k_break_mode, sizeof(__pyx_k_break_mode), 0, 0, 1, 1}, + {&__pyx_n_s_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 0, 1, 1}, + {&__pyx_n_u_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 1, 0, 1}, + {&__pyx_n_s_chain, __pyx_k_chain, sizeof(__pyx_k_chain), 0, 0, 1, 1}, + {&__pyx_n_s_class, __pyx_k_class, sizeof(__pyx_k_class), 0, 0, 1, 1}, + {&__pyx_n_s_class_getitem, __pyx_k_class_getitem, sizeof(__pyx_k_class_getitem), 0, 0, 1, 1}, + {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1}, + {&__pyx_n_s_collections, __pyx_k_collections, sizeof(__pyx_k_collections), 0, 0, 1, 1}, + {&__pyx_kp_s_collections_abc, __pyx_k_collections_abc, sizeof(__pyx_k_collections_abc), 0, 0, 1, 0}, + {&__pyx_n_u_complete, __pyx_k_complete, sizeof(__pyx_k_complete), 0, 1, 0, 1}, + {&__pyx_n_u_complete_doc, __pyx_k_complete_doc, sizeof(__pyx_k_complete_doc), 0, 1, 0, 1}, + {&__pyx_kp_s_contiguous_and_direct, __pyx_k_contiguous_and_direct, sizeof(__pyx_k_contiguous_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_contiguous_and_indirect, __pyx_k_contiguous_and_indirect, sizeof(__pyx_k_contiguous_and_indirect), 0, 0, 1, 0}, + {&__pyx_n_s_count, __pyx_k_count, sizeof(__pyx_k_count), 0, 0, 1, 1}, + {&__pyx_n_s_cumsum, __pyx_k_cumsum, sizeof(__pyx_k_cumsum), 0, 0, 1, 1}, + {&__pyx_n_s_dict, __pyx_k_dict, sizeof(__pyx_k_dict), 0, 0, 1, 1}, + {&__pyx_n_s_dict_2, __pyx_k_dict_2, sizeof(__pyx_k_dict_2), 0, 0, 1, 1}, + {&__pyx_kp_u_disable, __pyx_k_disable, sizeof(__pyx_k_disable), 0, 1, 0, 0}, + {&__pyx_n_s_document_sep_len, __pyx_k_document_sep_len, sizeof(__pyx_k_document_sep_len), 0, 0, 1, 1}, + {&__pyx_n_s_dtype, __pyx_k_dtype, sizeof(__pyx_k_dtype), 0, 0, 1, 1}, + {&__pyx_n_s_dtype_is_object, __pyx_k_dtype_is_object, sizeof(__pyx_k_dtype_is_object), 0, 0, 1, 1}, + {&__pyx_kp_u_enable, __pyx_k_enable, sizeof(__pyx_k_enable), 0, 1, 0, 0}, + {&__pyx_n_s_encode, __pyx_k_encode, sizeof(__pyx_k_encode), 0, 0, 1, 1}, + {&__pyx_n_s_enumerate, __pyx_k_enumerate, sizeof(__pyx_k_enumerate), 0, 0, 1, 1}, + {&__pyx_n_u_eos, __pyx_k_eos, sizeof(__pyx_k_eos), 0, 1, 0, 1}, + {&__pyx_n_s_error, __pyx_k_error, sizeof(__pyx_k_error), 0, 0, 1, 1}, + {&__pyx_kp_s_fairseq_data_token_block_utils_f, __pyx_k_fairseq_data_token_block_utils_f, sizeof(__pyx_k_fairseq_data_token_block_utils_f), 0, 0, 1, 0}, + {&__pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_k_fairseq_data_token_block_utils_f_2, sizeof(__pyx_k_fairseq_data_token_block_utils_f_2), 0, 0, 1, 1}, + {&__pyx_n_s_flags, __pyx_k_flags, sizeof(__pyx_k_flags), 0, 0, 1, 1}, + {&__pyx_n_s_format, __pyx_k_format, sizeof(__pyx_k_format), 0, 0, 1, 1}, + {&__pyx_n_s_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 0, 1, 1}, + {&__pyx_n_u_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 1, 0, 1}, + {&__pyx_n_s_from_iterable, __pyx_k_from_iterable, sizeof(__pyx_k_from_iterable), 0, 0, 1, 1}, + {&__pyx_n_s_fromiter, __pyx_k_fromiter, sizeof(__pyx_k_fromiter), 0, 0, 1, 1}, + {&__pyx_kp_u_gc, __pyx_k_gc, sizeof(__pyx_k_gc), 0, 1, 0, 0}, + {&__pyx_n_s_get_block_to_dataset_index_fast, __pyx_k_get_block_to_dataset_index_fast, sizeof(__pyx_k_get_block_to_dataset_index_fast), 0, 0, 1, 1}, + {&__pyx_n_s_get_slice_indices_fast, __pyx_k_get_slice_indices_fast, sizeof(__pyx_k_get_slice_indices_fast), 0, 0, 1, 1}, + {&__pyx_n_s_getstate, __pyx_k_getstate, sizeof(__pyx_k_getstate), 0, 0, 1, 1}, + {&__pyx_kp_u_got, __pyx_k_got, sizeof(__pyx_k_got), 0, 1, 0, 0}, + {&__pyx_kp_u_got_differing_extents_in_dimensi, __pyx_k_got_differing_extents_in_dimensi, sizeof(__pyx_k_got_differing_extents_in_dimensi), 0, 1, 0, 0}, + {&__pyx_n_s_id, __pyx_k_id, sizeof(__pyx_k_id), 0, 0, 1, 1}, + {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1}, + {&__pyx_n_s_index, __pyx_k_index, sizeof(__pyx_k_index), 0, 0, 1, 1}, + {&__pyx_n_s_initializing, __pyx_k_initializing, sizeof(__pyx_k_initializing), 0, 0, 1, 1}, + {&__pyx_n_s_int64, __pyx_k_int64, sizeof(__pyx_k_int64), 0, 0, 1, 1}, + {&__pyx_n_s_is_coroutine, __pyx_k_is_coroutine, sizeof(__pyx_k_is_coroutine), 0, 0, 1, 1}, + {&__pyx_kp_u_isenabled, __pyx_k_isenabled, sizeof(__pyx_k_isenabled), 0, 1, 0, 0}, + {&__pyx_n_s_itemsize, __pyx_k_itemsize, sizeof(__pyx_k_itemsize), 0, 0, 1, 1}, + {&__pyx_kp_s_itemsize_0_for_cython_array, __pyx_k_itemsize_0_for_cython_array, sizeof(__pyx_k_itemsize_0_for_cython_array), 0, 0, 1, 0}, + {&__pyx_n_s_itertools, __pyx_k_itertools, sizeof(__pyx_k_itertools), 0, 0, 1, 1}, + {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1}, + {&__pyx_n_s_memview, __pyx_k_memview, sizeof(__pyx_k_memview), 0, 0, 1, 1}, + {&__pyx_n_s_mode, __pyx_k_mode, sizeof(__pyx_k_mode), 0, 0, 1, 1}, + {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1}, + {&__pyx_n_s_name_2, __pyx_k_name_2, sizeof(__pyx_k_name_2), 0, 0, 1, 1}, + {&__pyx_n_s_ndim, __pyx_k_ndim, sizeof(__pyx_k_ndim), 0, 0, 1, 1}, + {&__pyx_n_s_new, __pyx_k_new, sizeof(__pyx_k_new), 0, 0, 1, 1}, + {&__pyx_kp_s_no_default___reduce___due_to_non, __pyx_k_no_default___reduce___due_to_non, sizeof(__pyx_k_no_default___reduce___due_to_non), 0, 0, 1, 0}, + {&__pyx_n_u_none, __pyx_k_none, sizeof(__pyx_k_none), 0, 1, 0, 1}, + {&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1}, + {&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1}, + {&__pyx_kp_u_numpy_core_multiarray_failed_to, __pyx_k_numpy_core_multiarray_failed_to, sizeof(__pyx_k_numpy_core_multiarray_failed_to), 0, 1, 0, 0}, + {&__pyx_kp_u_numpy_core_umath_failed_to_impor, __pyx_k_numpy_core_umath_failed_to_impor, sizeof(__pyx_k_numpy_core_umath_failed_to_impor), 0, 1, 0, 0}, + {&__pyx_n_s_obj, __pyx_k_obj, sizeof(__pyx_k_obj), 0, 0, 1, 1}, + {&__pyx_n_s_pack, __pyx_k_pack, sizeof(__pyx_k_pack), 0, 0, 1, 1}, + {&__pyx_n_s_pickle, __pyx_k_pickle, sizeof(__pyx_k_pickle), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_PickleError, __pyx_k_pyx_PickleError, sizeof(__pyx_k_pyx_PickleError), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_checksum, __pyx_k_pyx_checksum, sizeof(__pyx_k_pyx_checksum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_result, __pyx_k_pyx_result, sizeof(__pyx_k_pyx_result), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_state, __pyx_k_pyx_state, sizeof(__pyx_k_pyx_state), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_type, __pyx_k_pyx_type, sizeof(__pyx_k_pyx_type), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_unpickle_DatasetSearcher, __pyx_k_pyx_unpickle_DatasetSearcher, sizeof(__pyx_k_pyx_unpickle_DatasetSearcher), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_unpickle_Enum, __pyx_k_pyx_unpickle_Enum, sizeof(__pyx_k_pyx_unpickle_Enum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_vtable, __pyx_k_pyx_vtable, sizeof(__pyx_k_pyx_vtable), 0, 0, 1, 1}, + {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, + {&__pyx_n_s_reduce, __pyx_k_reduce, sizeof(__pyx_k_reduce), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_cython, __pyx_k_reduce_cython, sizeof(__pyx_k_reduce_cython), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_ex, __pyx_k_reduce_ex, sizeof(__pyx_k_reduce_ex), 0, 0, 1, 1}, + {&__pyx_n_s_register, __pyx_k_register, sizeof(__pyx_k_register), 0, 0, 1, 1}, + {&__pyx_n_s_reshape, __pyx_k_reshape, sizeof(__pyx_k_reshape), 0, 0, 1, 1}, + {&__pyx_n_s_self, __pyx_k_self, sizeof(__pyx_k_self), 0, 0, 1, 1}, + {&__pyx_n_s_setstate, __pyx_k_setstate, sizeof(__pyx_k_setstate), 0, 0, 1, 1}, + {&__pyx_n_s_setstate_cython, __pyx_k_setstate_cython, sizeof(__pyx_k_setstate_cython), 0, 0, 1, 1}, + {&__pyx_n_s_shape, __pyx_k_shape, sizeof(__pyx_k_shape), 0, 0, 1, 1}, + {&__pyx_n_s_size, __pyx_k_size, sizeof(__pyx_k_size), 0, 0, 1, 1}, + {&__pyx_n_s_sizes, __pyx_k_sizes, sizeof(__pyx_k_sizes), 0, 0, 1, 1}, + {&__pyx_n_s_slice_indices, __pyx_k_slice_indices, sizeof(__pyx_k_slice_indices), 0, 0, 1, 1}, + {&__pyx_n_s_spec, __pyx_k_spec, sizeof(__pyx_k_spec), 0, 0, 1, 1}, + {&__pyx_n_s_start, __pyx_k_start, sizeof(__pyx_k_start), 0, 0, 1, 1}, + {&__pyx_n_s_state, __pyx_k_state, sizeof(__pyx_k_state), 0, 0, 1, 1}, + {&__pyx_n_s_step, __pyx_k_step, sizeof(__pyx_k_step), 0, 0, 1, 1}, + {&__pyx_n_s_stop, __pyx_k_stop, sizeof(__pyx_k_stop), 0, 0, 1, 1}, + {&__pyx_kp_s_strided_and_direct, __pyx_k_strided_and_direct, sizeof(__pyx_k_strided_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_direct_or_indirect, __pyx_k_strided_and_direct_or_indirect, sizeof(__pyx_k_strided_and_direct_or_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_indirect, __pyx_k_strided_and_indirect, sizeof(__pyx_k_strided_and_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_stringsource, __pyx_k_stringsource, sizeof(__pyx_k_stringsource), 0, 0, 1, 0}, + {&__pyx_n_s_struct, __pyx_k_struct, sizeof(__pyx_k_struct), 0, 0, 1, 1}, + {&__pyx_n_s_sum, __pyx_k_sum, sizeof(__pyx_k_sum), 0, 0, 1, 1}, + {&__pyx_n_s_sys, __pyx_k_sys, sizeof(__pyx_k_sys), 0, 0, 1, 1}, + {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, + {&__pyx_n_s_torch, __pyx_k_torch, sizeof(__pyx_k_torch), 0, 0, 1, 1}, + {&__pyx_kp_s_unable_to_allocate_array_data, __pyx_k_unable_to_allocate_array_data, sizeof(__pyx_k_unable_to_allocate_array_data), 0, 0, 1, 0}, + {&__pyx_kp_s_unable_to_allocate_shape_and_str, __pyx_k_unable_to_allocate_shape_and_str, sizeof(__pyx_k_unable_to_allocate_shape_and_str), 0, 0, 1, 0}, + {&__pyx_n_s_unpack, __pyx_k_unpack, sizeof(__pyx_k_unpack), 0, 0, 1, 1}, + {&__pyx_n_s_update, __pyx_k_update, sizeof(__pyx_k_update), 0, 0, 1, 1}, + {&__pyx_n_s_use_setstate, __pyx_k_use_setstate, sizeof(__pyx_k_use_setstate), 0, 0, 1, 1}, + {&__pyx_n_s_version_info, __pyx_k_version_info, sizeof(__pyx_k_version_info), 0, 0, 1, 1}, + {&__pyx_n_s_zeros, __pyx_k_zeros, sizeof(__pyx_k_zeros), 0, 0, 1, 1}, + {0, 0, 0, 0, 0, 0, 0} + }; + return __Pyx_InitStrings(__pyx_string_tab); +} +/* #### Code section: cached_builtins ### */ +static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) { + __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 32, __pyx_L1_error) + __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(0, 104, __pyx_L1_error) + __pyx_builtin_AssertionError = __Pyx_GetBuiltinName(__pyx_n_s_AssertionError); if (!__pyx_builtin_AssertionError) __PYX_ERR(0, 173, __pyx_L1_error) + __pyx_builtin___import__ = __Pyx_GetBuiltinName(__pyx_n_s_import); if (!__pyx_builtin___import__) __PYX_ERR(1, 100, __pyx_L1_error) + __pyx_builtin_MemoryError = __Pyx_GetBuiltinName(__pyx_n_s_MemoryError); if (!__pyx_builtin_MemoryError) __PYX_ERR(1, 156, __pyx_L1_error) + __pyx_builtin_enumerate = __Pyx_GetBuiltinName(__pyx_n_s_enumerate); if (!__pyx_builtin_enumerate) __PYX_ERR(1, 159, __pyx_L1_error) + __pyx_builtin_TypeError = __Pyx_GetBuiltinName(__pyx_n_s_TypeError); if (!__pyx_builtin_TypeError) __PYX_ERR(1, 2, __pyx_L1_error) + __pyx_builtin_Ellipsis = __Pyx_GetBuiltinName(__pyx_n_s_Ellipsis); if (!__pyx_builtin_Ellipsis) __PYX_ERR(1, 408, __pyx_L1_error) + __pyx_builtin_id = __Pyx_GetBuiltinName(__pyx_n_s_id); if (!__pyx_builtin_id) __PYX_ERR(1, 618, __pyx_L1_error) + __pyx_builtin_IndexError = __Pyx_GetBuiltinName(__pyx_n_s_IndexError); if (!__pyx_builtin_IndexError) __PYX_ERR(1, 914, __pyx_L1_error) + __pyx_builtin_ImportError = __Pyx_GetBuiltinName(__pyx_n_s_ImportError); if (!__pyx_builtin_ImportError) __PYX_ERR(2, 984, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: cached_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0); + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __pyx_tuple__4 = PyTuple_New(1); if (unlikely(!__pyx_tuple__4)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__4); + __Pyx_INCREF(__pyx_int_neg_1); + __Pyx_GIVEREF(__pyx_int_neg_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_tuple__4, 0, __pyx_int_neg_1)) __PYX_ERR(1, 582, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_tuple__4); + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_slice__5 = PySlice_New(Py_None, Py_None, Py_None); if (unlikely(!__pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_tuple__8 = PyTuple_Pack(3, __pyx_int_136983863, __pyx_int_112105877, __pyx_int_184977713); if (unlikely(!__pyx_tuple__8)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__8); + __Pyx_GIVEREF(__pyx_tuple__8); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":984 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy.core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_tuple__9 = PyTuple_Pack(1, __pyx_kp_u_numpy_core_multiarray_failed_to); if (unlikely(!__pyx_tuple__9)) __PYX_ERR(2, 984, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__9); + __Pyx_GIVEREF(__pyx_tuple__9); + + /* "../../../../../../../../../../tmp/pip-build-env-cp2qouzo/overlay/lib/python3.10/site-packages/numpy/__init__.cython-30.pxd":990 + * _import_umath() + * except Exception: + * raise ImportError("numpy.core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_tuple__10 = PyTuple_Pack(1, __pyx_kp_u_numpy_core_umath_failed_to_impor); if (unlikely(!__pyx_tuple__10)) __PYX_ERR(2, 990, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__10); + __Pyx_GIVEREF(__pyx_tuple__10); + + /* "fairseq/data/token_block_utils_fast.pyx":101 + * slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + * cumsum = sizes.cumsum(axis=0) + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] # <<<<<<<<<<<<<< + * slice_indices[:, 1] = cumsum + * else: + */ + __pyx_slice__11 = PySlice_New(__pyx_int_1, Py_None, Py_None); if (unlikely(!__pyx_slice__11)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_slice__11); + __Pyx_GIVEREF(__pyx_slice__11); + __pyx_tuple__12 = PyTuple_Pack(2, __pyx_slice__11, __pyx_int_0); if (unlikely(!__pyx_tuple__12)) __PYX_ERR(0, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__12); + __Pyx_GIVEREF(__pyx_tuple__12); + + /* "fairseq/data/token_block_utils_fast.pyx":102 + * cumsum = sizes.cumsum(axis=0) + * slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + * slice_indices[:, 1] = cumsum # <<<<<<<<<<<<<< + * else: + * raise ValueError('Invalid break_mode: ' + break_mode) + */ + __pyx_tuple__13 = PyTuple_Pack(2, __pyx_slice__5, __pyx_int_1); if (unlikely(!__pyx_tuple__13)) __PYX_ERR(0, 102, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__13); + __Pyx_GIVEREF(__pyx_tuple__13); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x8c67b45, 0x2e2dd22, 0x6632805): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x8c67b45, 0x2e2dd22, 0x6632805) = (current_i, current_index, current_offset, sizes))" % __pyx_checksum + */ + __pyx_tuple__14 = PyTuple_Pack(3, __pyx_int_147225413, __pyx_int_48422178, __pyx_int_107161605); if (unlikely(!__pyx_tuple__14)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__14); + __Pyx_GIVEREF(__pyx_tuple__14); + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_tuple__15 = PyTuple_Pack(1, __pyx_n_s_sys); if (unlikely(!__pyx_tuple__15)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__15); + __Pyx_GIVEREF(__pyx_tuple__15); + __pyx_tuple__16 = PyTuple_Pack(2, __pyx_int_3, __pyx_int_3); if (unlikely(!__pyx_tuple__16)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__16); + __Pyx_GIVEREF(__pyx_tuple__16); + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_tuple__17 = PyTuple_Pack(1, __pyx_kp_s_collections_abc); if (unlikely(!__pyx_tuple__17)) __PYX_ERR(1, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__17); + __Pyx_GIVEREF(__pyx_tuple__17); + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + __pyx_tuple__18 = PyTuple_Pack(1, __pyx_n_s_collections); if (unlikely(!__pyx_tuple__18)) __PYX_ERR(1, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__18); + __Pyx_GIVEREF(__pyx_tuple__18); + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_tuple__19 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct_or_indirect); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__19); + __Pyx_GIVEREF(__pyx_tuple__19); + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_tuple__20 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct); if (unlikely(!__pyx_tuple__20)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__20); + __Pyx_GIVEREF(__pyx_tuple__20); + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__21 = PyTuple_Pack(1, __pyx_kp_s_strided_and_indirect); if (unlikely(!__pyx_tuple__21)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__21); + __Pyx_GIVEREF(__pyx_tuple__21); + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_tuple__22 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_direct); if (unlikely(!__pyx_tuple__22)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__22); + __Pyx_GIVEREF(__pyx_tuple__22); + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__23 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_indirect); if (unlikely(!__pyx_tuple__23)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__23); + __Pyx_GIVEREF(__pyx_tuple__23); + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_tuple__24 = PyTuple_Pack(5, __pyx_n_s_pyx_type, __pyx_n_s_pyx_checksum, __pyx_n_s_pyx_state, __pyx_n_s_pyx_PickleError, __pyx_n_s_pyx_result); if (unlikely(!__pyx_tuple__24)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__24); + __Pyx_GIVEREF(__pyx_tuple__24); + __pyx_codeobj__25 = (PyObject*)__Pyx_PyCode_New(3, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__24, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_Enum, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__25)) __PYX_ERR(1, 1, __pyx_L1_error) + + /* "fairseq/data/token_block_utils_fast.pyx":52 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): # <<<<<<<<<<<<<< + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 + */ + __pyx_tuple__26 = PyTuple_Pack(4, __pyx_n_s_sizes, __pyx_n_s_break_mode, __pyx_n_s_block_size, __pyx_n_s_document_sep_len); if (unlikely(!__pyx_tuple__26)) __PYX_ERR(0, 52, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__26); + __Pyx_GIVEREF(__pyx_tuple__26); + __pyx_codeobj__27 = (PyObject*)__Pyx_PyCode_New(4, 0, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__26, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_fairseq_data_token_block_utils_f, __pyx_n_s_get_slice_indices_fast, 52, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__27)) __PYX_ERR(0, 52, __pyx_L1_error) + + /* "fairseq/data/token_block_utils_fast.pyx":111 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): # <<<<<<<<<<<<<< + * cdef DTYPE_t start_ds_idx + * cdef DTYPE_t start_offset + */ + __pyx_tuple__28 = PyTuple_Pack(2, __pyx_n_s_sizes, __pyx_n_s_slice_indices); if (unlikely(!__pyx_tuple__28)) __PYX_ERR(0, 111, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__28); + __Pyx_GIVEREF(__pyx_tuple__28); + __pyx_codeobj__29 = (PyObject*)__Pyx_PyCode_New(2, 0, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__28, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_fairseq_data_token_block_utils_f, __pyx_n_s_get_block_to_dataset_index_fast, 111, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__29)) __PYX_ERR(0, 111, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + __pyx_tuple__30 = PyTuple_Pack(4, __pyx_n_s_self, __pyx_n_s_state, __pyx_n_s_dict_2, __pyx_n_s_use_setstate); if (unlikely(!__pyx_tuple__30)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__30); + __Pyx_GIVEREF(__pyx_tuple__30); + __pyx_codeobj__31 = (PyObject*)__Pyx_PyCode_New(1, 0, 0, 4, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__30, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_reduce_cython, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__31)) __PYX_ERR(1, 1, __pyx_L1_error) + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) + */ + __pyx_tuple__32 = PyTuple_Pack(2, __pyx_n_s_self, __pyx_n_s_pyx_state); if (unlikely(!__pyx_tuple__32)) __PYX_ERR(1, 16, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__32); + __Pyx_GIVEREF(__pyx_tuple__32); + __pyx_codeobj__33 = (PyObject*)__Pyx_PyCode_New(2, 0, 0, 2, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__32, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_setstate_cython, 16, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__33)) __PYX_ERR(1, 16, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __pyx_unpickle_DatasetSearcher(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_codeobj__34 = (PyObject*)__Pyx_PyCode_New(3, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__24, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_DatasetSearcher, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__34)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_RefNannyFinishContext(); + return -1; +} +/* #### Code section: init_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitConstants(void) { + if (__Pyx_CreateStringTabAndInitStrings() < 0) __PYX_ERR(0, 1, __pyx_L1_error); + __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_2 = PyInt_FromLong(2); if (unlikely(!__pyx_int_2)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_3 = PyInt_FromLong(3); if (unlikely(!__pyx_int_3)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_48422178 = PyInt_FromLong(48422178L); if (unlikely(!__pyx_int_48422178)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_107161605 = PyInt_FromLong(107161605L); if (unlikely(!__pyx_int_107161605)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_112105877 = PyInt_FromLong(112105877L); if (unlikely(!__pyx_int_112105877)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_136983863 = PyInt_FromLong(136983863L); if (unlikely(!__pyx_int_136983863)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_147225413 = PyInt_FromLong(147225413L); if (unlikely(!__pyx_int_147225413)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_184977713 = PyInt_FromLong(184977713L); if (unlikely(!__pyx_int_184977713)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_neg_1 = PyInt_FromLong(-1); if (unlikely(!__pyx_int_neg_1)) __PYX_ERR(0, 1, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_globals ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) { + /* AssertionsEnabled.init */ + if (likely(__Pyx_init_assertions_enabled() == 0)); else + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + /* NumpyImportArray.init */ + /* + * Cython has automatically inserted a call to _import_array since + * you didn't include one when you cimported numpy. To disable this + * add the line + * numpy._import_array + */ +#ifdef NPY_FEATURE_VERSION +#ifndef NO_IMPORT_ARRAY +if (unlikely(_import_array() == -1)) { + PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import " + "(auto-generated because you didn't call 'numpy.import_array()' after cimporting numpy; " + "use 'numpy._import_array' to disable if you are certain you don't need it)."); +} +#endif +#endif + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_module ### */ + +static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_import_code(void); /*proto*/ + +static int __Pyx_modinit_global_init_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0); + /*--- Global init code ---*/ + __pyx_collections_abc_Sequence = Py_None; Py_INCREF(Py_None); + generic = Py_None; Py_INCREF(Py_None); + strided = Py_None; Py_INCREF(Py_None); + indirect = Py_None; Py_INCREF(Py_None); + contiguous = Py_None; Py_INCREF(Py_None); + indirect_contiguous = Py_None; Py_INCREF(Py_None); + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_variable_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0); + /*--- Variable export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0); + /*--- Function export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_type_init_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0); + /*--- Type init code ---*/ + __pyx_vtabptr_7fairseq_4data_22token_block_utils_fast_DatasetSearcher = &__pyx_vtable_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + __pyx_vtable_7fairseq_4data_22token_block_utils_fast_DatasetSearcher.reset = (PyObject *(*)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *))__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_reset; + __pyx_vtable_7fairseq_4data_22token_block_utils_fast_DatasetSearcher.step = (int (*)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t))__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_step; + __pyx_vtable_7fairseq_4data_22token_block_utils_fast_DatasetSearcher.seek = (PyObject *(*)(struct __pyx_obj_7fairseq_4data_22token_block_utils_fast_DatasetSearcher *, __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t))__pyx_f_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_seek; + #if CYTHON_USE_TYPE_SPECS + __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher_spec, NULL); if (unlikely(!__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher)) __PYX_ERR(0, 141, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher_spec, __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #else + __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher = &__pyx_type_7fairseq_4data_22token_block_utils_fast_DatasetSearcher; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher->tp_dictoffset && __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, __pyx_vtabptr_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #endif + if (PyObject_SetAttr(__pyx_m, __pyx_n_s_DatasetSearcher, (PyObject *) __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher) < 0) __PYX_ERR(0, 141, __pyx_L1_error) + #endif + __pyx_vtabptr_array = &__pyx_vtable_array; + __pyx_vtable_array.get_memview = (PyObject *(*)(struct __pyx_array_obj *))__pyx_array_get_memview; + #if CYTHON_USE_TYPE_SPECS + __pyx_array_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_array_spec, NULL); if (unlikely(!__pyx_array_type)) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_array_type->tp_as_buffer = &__pyx_tp_as_buffer_array; + if (!__pyx_array_type->tp_as_buffer->bf_releasebuffer && __pyx_array_type->tp_base->tp_as_buffer && __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_array_type->tp_as_buffer->bf_releasebuffer = __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_array_spec, __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #else + __pyx_array_type = &__pyx_type___pyx_array; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_array_type->tp_print = 0; + #endif + if (__Pyx_SetVtable(__pyx_array_type, __pyx_vtabptr_array) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if CYTHON_USE_TYPE_SPECS + __pyx_MemviewEnum_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_MemviewEnum_spec, NULL); if (unlikely(!__pyx_MemviewEnum_type)) __PYX_ERR(1, 302, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_MemviewEnum_spec, __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #else + __pyx_MemviewEnum_type = &__pyx_type___pyx_MemviewEnum; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_MemviewEnum_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_MemviewEnum_type->tp_dictoffset && __pyx_MemviewEnum_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_MemviewEnum_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + __pyx_vtabptr_memoryview = &__pyx_vtable_memoryview; + __pyx_vtable_memoryview.get_item_pointer = (char *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_get_item_pointer; + __pyx_vtable_memoryview.is_slice = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_is_slice; + __pyx_vtable_memoryview.setitem_slice_assignment = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_slice_assignment; + __pyx_vtable_memoryview.setitem_slice_assign_scalar = (PyObject *(*)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_setitem_slice_assign_scalar; + __pyx_vtable_memoryview.setitem_indexed = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_indexed; + __pyx_vtable_memoryview.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryview_convert_item_to_object; + __pyx_vtable_memoryview.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryview_assign_item_from_object; + __pyx_vtable_memoryview._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryview__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_memoryview_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryview_spec, NULL); if (unlikely(!__pyx_memoryview_type)) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryview_type->tp_as_buffer = &__pyx_tp_as_buffer_memoryview; + if (!__pyx_memoryview_type->tp_as_buffer->bf_releasebuffer && __pyx_memoryview_type->tp_base->tp_as_buffer && __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_memoryview_type->tp_as_buffer->bf_releasebuffer = __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryview_spec, __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #else + __pyx_memoryview_type = &__pyx_type___pyx_memoryview; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryview_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryview_type->tp_dictoffset && __pyx_memoryview_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryview_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryview_type, __pyx_vtabptr_memoryview) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + __pyx_vtabptr__memoryviewslice = &__pyx_vtable__memoryviewslice; + __pyx_vtable__memoryviewslice.__pyx_base = *__pyx_vtabptr_memoryview; + __pyx_vtable__memoryviewslice.__pyx_base.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryviewslice_convert_item_to_object; + __pyx_vtable__memoryviewslice.__pyx_base.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryviewslice_assign_item_from_object; + __pyx_vtable__memoryviewslice.__pyx_base._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryviewslice__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_t_1 = PyTuple_Pack(1, (PyObject *)__pyx_memoryview_type); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 952, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_memoryviewslice_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryviewslice_spec, __pyx_t_1); + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_memoryviewslice_type)) __PYX_ERR(1, 952, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryviewslice_spec, __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #else + __pyx_memoryviewslice_type = &__pyx_type___pyx_memoryviewslice; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryviewslice_type->tp_base = __pyx_memoryview_type; + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryviewslice_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryviewslice_type->tp_dictoffset && __pyx_memoryviewslice_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryviewslice_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryviewslice_type, __pyx_vtabptr__memoryviewslice) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_type_import_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0); + /*--- Type import code ---*/ + __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_7cpython_4type_type = __Pyx_ImportType_3_0_8(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "type", + #if defined(PYPY_VERSION_NUM) && PYPY_VERSION_NUM < 0x050B0000 + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyTypeObject), + #elif CYTHON_COMPILING_IN_LIMITED_API + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyTypeObject), + #else + sizeof(PyHeapTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyHeapTypeObject), + #endif + __Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_7cpython_4type_type) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = PyImport_ImportModule("numpy"); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 202, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_5numpy_dtype = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "dtype", sizeof(PyArray_Descr), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArray_Descr),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_dtype) __PYX_ERR(2, 202, __pyx_L1_error) + __pyx_ptype_5numpy_flatiter = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "flatiter", sizeof(PyArrayIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_flatiter) __PYX_ERR(2, 225, __pyx_L1_error) + __pyx_ptype_5numpy_broadcast = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "broadcast", sizeof(PyArrayMultiIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayMultiIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_broadcast) __PYX_ERR(2, 229, __pyx_L1_error) + __pyx_ptype_5numpy_ndarray = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "ndarray", sizeof(PyArrayObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyArrayObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_ndarray) __PYX_ERR(2, 238, __pyx_L1_error) + __pyx_ptype_5numpy_generic = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "generic", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_generic) __PYX_ERR(2, 809, __pyx_L1_error) + __pyx_ptype_5numpy_number = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "number", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_number) __PYX_ERR(2, 811, __pyx_L1_error) + __pyx_ptype_5numpy_integer = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "integer", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_integer) __PYX_ERR(2, 813, __pyx_L1_error) + __pyx_ptype_5numpy_signedinteger = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "signedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_signedinteger) __PYX_ERR(2, 815, __pyx_L1_error) + __pyx_ptype_5numpy_unsignedinteger = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "unsignedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_unsignedinteger) __PYX_ERR(2, 817, __pyx_L1_error) + __pyx_ptype_5numpy_inexact = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "inexact", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_inexact) __PYX_ERR(2, 819, __pyx_L1_error) + __pyx_ptype_5numpy_floating = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "floating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_floating) __PYX_ERR(2, 821, __pyx_L1_error) + __pyx_ptype_5numpy_complexfloating = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "complexfloating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_complexfloating) __PYX_ERR(2, 823, __pyx_L1_error) + __pyx_ptype_5numpy_flexible = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "flexible", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_flexible) __PYX_ERR(2, 825, __pyx_L1_error) + __pyx_ptype_5numpy_character = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "character", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_8); if (!__pyx_ptype_5numpy_character) __PYX_ERR(2, 827, __pyx_L1_error) + __pyx_ptype_5numpy_ufunc = __Pyx_ImportType_3_0_8(__pyx_t_1, "numpy", "ufunc", sizeof(PyUFuncObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_8(PyUFuncObject),__Pyx_ImportType_CheckSize_Ignore_3_0_8); if (!__pyx_ptype_5numpy_ufunc) __PYX_ERR(2, 866, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_variable_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0); + /*--- Variable import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0); + /*--- Function import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + + +#if PY_MAJOR_VERSION >= 3 +#if CYTHON_PEP489_MULTI_PHASE_INIT +static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/ +static int __pyx_pymod_exec_token_block_utils_fast(PyObject* module); /*proto*/ +static PyModuleDef_Slot __pyx_moduledef_slots[] = { + {Py_mod_create, (void*)__pyx_pymod_create}, + {Py_mod_exec, (void*)__pyx_pymod_exec_token_block_utils_fast}, + {0, NULL} +}; +#endif + +#ifdef __cplusplus +namespace { + struct PyModuleDef __pyx_moduledef = + #else + static struct PyModuleDef __pyx_moduledef = + #endif + { + PyModuleDef_HEAD_INIT, + "token_block_utils_fast", + 0, /* m_doc */ + #if CYTHON_PEP489_MULTI_PHASE_INIT + 0, /* m_size */ + #elif CYTHON_USE_MODULE_STATE + sizeof(__pyx_mstate), /* m_size */ + #else + -1, /* m_size */ + #endif + __pyx_methods /* m_methods */, + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_moduledef_slots, /* m_slots */ + #else + NULL, /* m_reload */ + #endif + #if CYTHON_USE_MODULE_STATE + __pyx_m_traverse, /* m_traverse */ + __pyx_m_clear, /* m_clear */ + NULL /* m_free */ + #else + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL /* m_free */ + #endif + }; + #ifdef __cplusplus +} /* anonymous namespace */ +#endif +#endif + +#ifndef CYTHON_NO_PYINIT_EXPORT +#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC +#elif PY_MAJOR_VERSION < 3 +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" void +#else +#define __Pyx_PyMODINIT_FUNC void +#endif +#else +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" PyObject * +#else +#define __Pyx_PyMODINIT_FUNC PyObject * +#endif +#endif + + +#if PY_MAJOR_VERSION < 3 +__Pyx_PyMODINIT_FUNC inittoken_block_utils_fast(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC inittoken_block_utils_fast(void) +#else +__Pyx_PyMODINIT_FUNC PyInit_token_block_utils_fast(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC PyInit_token_block_utils_fast(void) +#if CYTHON_PEP489_MULTI_PHASE_INIT +{ + return PyModuleDef_Init(&__pyx_moduledef); +} +static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) { + #if PY_VERSION_HEX >= 0x030700A1 + static PY_INT64_T main_interpreter_id = -1; + PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp); + if (main_interpreter_id == -1) { + main_interpreter_id = current_id; + return (unlikely(current_id == -1)) ? -1 : 0; + } else if (unlikely(main_interpreter_id != current_id)) + #else + static PyInterpreterState *main_interpreter = NULL; + PyInterpreterState *current_interpreter = PyThreadState_Get()->interp; + if (!main_interpreter) { + main_interpreter = current_interpreter; + } else if (unlikely(main_interpreter != current_interpreter)) + #endif + { + PyErr_SetString( + PyExc_ImportError, + "Interpreter change detected - this module can only be loaded into one interpreter per process."); + return -1; + } + return 0; +} +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *module, const char* from_name, const char* to_name, int allow_none) +#else +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none) +#endif +{ + PyObject *value = PyObject_GetAttrString(spec, from_name); + int result = 0; + if (likely(value)) { + if (allow_none || value != Py_None) { +#if CYTHON_COMPILING_IN_LIMITED_API + result = PyModule_AddObject(module, to_name, value); +#else + result = PyDict_SetItemString(moddict, to_name, value); +#endif + } + Py_DECREF(value); + } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } else { + result = -1; + } + return result; +} +static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def) { + PyObject *module = NULL, *moddict, *modname; + CYTHON_UNUSED_VAR(def); + if (__Pyx_check_single_interpreter()) + return NULL; + if (__pyx_m) + return __Pyx_NewRef(__pyx_m); + modname = PyObject_GetAttrString(spec, "name"); + if (unlikely(!modname)) goto bad; + module = PyModule_NewObject(modname); + Py_DECREF(modname); + if (unlikely(!module)) goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + moddict = module; +#else + moddict = PyModule_GetDict(module); + if (unlikely(!moddict)) goto bad; +#endif + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad; + return module; +bad: + Py_XDECREF(module); + return NULL; +} + + +static CYTHON_SMALL_CODE int __pyx_pymod_exec_token_block_utils_fast(PyObject *__pyx_pyinit_module) +#endif +#endif +{ + int stringtab_initialized = 0; + #if CYTHON_USE_MODULE_STATE + int pystate_addmodule_run = 0; + #endif + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + static PyThread_type_lock __pyx_t_8[8]; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannyDeclarations + #if CYTHON_PEP489_MULTI_PHASE_INIT + if (__pyx_m) { + if (__pyx_m == __pyx_pyinit_module) return 0; + PyErr_SetString(PyExc_RuntimeError, "Module 'token_block_utils_fast' has already been imported. Re-initialisation is not supported."); + return -1; + } + #elif PY_MAJOR_VERSION >= 3 + if (__pyx_m) return __Pyx_NewRef(__pyx_m); + #endif + /*--- Module creation code ---*/ + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_m = __pyx_pyinit_module; + Py_INCREF(__pyx_m); + #else + #if PY_MAJOR_VERSION < 3 + __pyx_m = Py_InitModule4("token_block_utils_fast", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #elif CYTHON_USE_MODULE_STATE + __pyx_t_1 = PyModule_Create(&__pyx_moduledef); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) + { + int add_module_result = PyState_AddModule(__pyx_t_1, &__pyx_moduledef); + __pyx_t_1 = 0; /* transfer ownership from __pyx_t_1 to "token_block_utils_fast" pseudovariable */ + if (unlikely((add_module_result < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + pystate_addmodule_run = 1; + } + #else + __pyx_m = PyModule_Create(&__pyx_moduledef); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #endif + CYTHON_UNUSED_VAR(__pyx_t_1); + __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error) + Py_INCREF(__pyx_d); + __pyx_b = __Pyx_PyImport_AddModuleRef(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_cython_runtime = __Pyx_PyImport_AddModuleRef((const char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error) + if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if CYTHON_REFNANNY +__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny"); +if (!__Pyx_RefNanny) { + PyErr_Clear(); + __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny"); + if (!__Pyx_RefNanny) + Py_FatalError("failed to import 'refnanny' module"); +} +#endif + __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit_token_block_utils_fast(void)", 0); + if (__Pyx_check_binary_version(__PYX_LIMITED_VERSION_HEX, __Pyx_get_runtime_version(), CYTHON_COMPILING_IN_LIMITED_API) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pxy_PyFrame_Initialize_Offsets + __Pxy_PyFrame_Initialize_Offsets(); + #endif + __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pyx_CyFunction_USED + if (__pyx_CyFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_FusedFunction_USED + if (__pyx_FusedFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Coroutine_USED + if (__pyx_Coroutine_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Generator_USED + if (__pyx_Generator_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_AsyncGen_USED + if (__pyx_AsyncGen_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_StopAsyncIteration_USED + if (__pyx_StopAsyncIteration_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + /*--- Library function declarations ---*/ + /*--- Threads initialization code ---*/ + #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 && defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS + PyEval_InitThreads(); + #endif + /*--- Initialize various global constants etc. ---*/ + if (__Pyx_InitConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + stringtab_initialized = 1; + if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT) + if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + if (__pyx_module_is_main_fairseq__data__token_block_utils_fast) { + if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name_2, __pyx_n_s_main) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + } + #if PY_MAJOR_VERSION >= 3 + { + PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error) + if (!PyDict_GetItemString(modules, "fairseq.data.token_block_utils_fast")) { + if (unlikely((PyDict_SetItemString(modules, "fairseq.data.token_block_utils_fast", __pyx_m) < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + } + } + #endif + /*--- Builtin init code ---*/ + if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Constants init code ---*/ + if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Global type/function init code ---*/ + (void)__Pyx_modinit_global_init_code(); + (void)__Pyx_modinit_variable_export_code(); + (void)__Pyx_modinit_function_export_code(); + if (unlikely((__Pyx_modinit_type_init_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + if (unlikely((__Pyx_modinit_type_import_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + (void)__Pyx_modinit_variable_import_code(); + (void)__Pyx_modinit_function_import_code(); + /*--- Execution code ---*/ + #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED) + if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__15, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_version_info); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = PyObject_RichCompare(__pyx_t_5, __pyx_tuple__16, Py_GE); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (__pyx_t_6) { + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__17, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_abc); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__18, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + __pyx_t_5 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L7_try_end; + __pyx_L2_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + + /* "View.MemoryView":104 + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + * except: # <<<<<<<<<<<<<< + * + * __pyx_collections_abc_Sequence = None + */ + /*except:*/ { + __Pyx_AddTraceback("View.MemoryView", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_4, &__pyx_t_7) < 0) __PYX_ERR(1, 104, __pyx_L4_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_7); + + /* "View.MemoryView":106 + * except: + * + * __pyx_collections_abc_Sequence = None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF(Py_None); + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, Py_None); + __Pyx_GIVEREF(Py_None); + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + goto __pyx_L3_exception_handled; + } + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + __pyx_L4_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L3_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L7_try_end:; + } + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":242 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":243 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L16_try_end; + __pyx_L11_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":244 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L12_exception_handled; + } + __pyx_L12_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L16_try_end:; + } + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__19, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(generic); + __Pyx_DECREF_SET(generic, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__20, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(strided); + __Pyx_DECREF_SET(strided, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__21, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect); + __Pyx_DECREF_SET(indirect, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__22, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(contiguous); + __Pyx_DECREF_SET(contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__23, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect_contiguous); + __Pyx_DECREF_SET(indirect_contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":323 + * + * + * cdef int __pyx_memoryview_thread_locks_used = 0 # <<<<<<<<<<<<<< + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ + * PyThread_allocate_lock(), + */ + __pyx_memoryview_thread_locks_used = 0; + + /* "View.MemoryView":324 + * + * cdef int __pyx_memoryview_thread_locks_used = 0 + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ # <<<<<<<<<<<<<< + * PyThread_allocate_lock(), + * PyThread_allocate_lock(), + */ + __pyx_t_8[0] = PyThread_allocate_lock(); + __pyx_t_8[1] = PyThread_allocate_lock(); + __pyx_t_8[2] = PyThread_allocate_lock(); + __pyx_t_8[3] = PyThread_allocate_lock(); + __pyx_t_8[4] = PyThread_allocate_lock(); + __pyx_t_8[5] = PyThread_allocate_lock(); + __pyx_t_8[6] = PyThread_allocate_lock(); + __pyx_t_8[7] = PyThread_allocate_lock(); + memcpy(&(__pyx_memoryview_thread_locks[0]), __pyx_t_8, sizeof(__pyx_memoryview_thread_locks[0]) * (8)); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":983 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":984 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L22_try_end; + __pyx_L17_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":985 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L18_exception_handled; + } + __pyx_L18_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L22_try_end:; + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_collections_abc_Sequence); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 989, __pyx_L23_error) + if (__pyx_t_6) { + + /* "View.MemoryView":993 + * + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence.register(array) + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_4 = __Pyx_PyObject_CallOneArg(__pyx_t_7, ((PyObject *)__pyx_memoryviewslice_type)); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":994 + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) # <<<<<<<<<<<<<< + * except: + * pass # ignore failure, it's a minor issue + */ + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_7 = __Pyx_PyObject_CallOneArg(__pyx_t_4, ((PyObject *)__pyx_array_type)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L28_try_end; + __pyx_L23_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":995 + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) + * except: # <<<<<<<<<<<<<< + * pass # ignore failure, it's a minor issue + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L24_exception_handled; + } + __pyx_L24_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L28_try_end:; + } + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_t_7 = PyCFunction_NewEx(&__pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum, NULL, __pyx_n_s_View_MemoryView); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_Enum, __pyx_t_7) < 0) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":7 + * # LICENSE file in the root directory of this source tree. + * + * import numpy as np # <<<<<<<<<<<<<< + * import torch + * from itertools import chain + */ + __pyx_t_7 = __Pyx_ImportDottedModule(__pyx_n_s_numpy, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_7) < 0) __PYX_ERR(0, 7, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":8 + * + * import numpy as np + * import torch # <<<<<<<<<<<<<< + * from itertools import chain + * from libc.math cimport ceil + */ + __pyx_t_7 = __Pyx_ImportDottedModule(__pyx_n_s_torch, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_torch, __pyx_t_7) < 0) __PYX_ERR(0, 8, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":9 + * import numpy as np + * import torch + * from itertools import chain # <<<<<<<<<<<<<< + * from libc.math cimport ceil + * + */ + __pyx_t_7 = PyList_New(1); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_INCREF(__pyx_n_s_chain); + __Pyx_GIVEREF(__pyx_n_s_chain); + if (__Pyx_PyList_SET_ITEM(__pyx_t_7, 0, __pyx_n_s_chain)) __PYX_ERR(0, 9, __pyx_L1_error); + __pyx_t_4 = __Pyx_Import(__pyx_n_s_itertools, __pyx_t_7, 0); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __pyx_t_7 = __Pyx_ImportFrom(__pyx_t_4, __pyx_n_s_chain); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_chain, __pyx_t_7) < 0) __PYX_ERR(0, 9, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":17 + * from libc.stdint cimport int32_t, int64_t + * + * DTYPE = np.int64 # <<<<<<<<<<<<<< + * ctypedef int64_t DTYPE_t + * + */ + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_np); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_int64); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (PyDict_SetItem(__pyx_d, __pyx_n_s_DTYPE, __pyx_t_7) < 0) __PYX_ERR(0, 17, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":52 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): # <<<<<<<<<<<<<< + * cdef DTYPE_t tok_idx = 0 + * cdef DTYPE_t sz_idx = 0 + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_22token_block_utils_fast_1_get_slice_indices_fast, 0, __pyx_n_s_get_slice_indices_fast, NULL, __pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_d, ((PyObject *)__pyx_codeobj__27)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 52, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_get_slice_indices_fast, __pyx_t_7) < 0) __PYX_ERR(0, 52, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":111 + * @cython.wraparound(False) + * @cython.nonecheck(False) + * cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): # <<<<<<<<<<<<<< + * cdef DTYPE_t start_ds_idx + * cdef DTYPE_t start_offset + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_22token_block_utils_fast_3_get_block_to_dataset_index_fast, 0, __pyx_n_s_get_block_to_dataset_index_fast, NULL, __pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_d, ((PyObject *)__pyx_codeobj__29)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 111, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_get_block_to_dataset_index_fast, __pyx_t_7) < 0) __PYX_ERR(0, 111, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_3__reduce_cython__, __Pyx_CYFUNCTION_CCLASS, __pyx_n_s_DatasetSearcher___reduce_cython, NULL, __pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_d, ((PyObject *)__pyx_codeobj__31)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict((PyObject *)__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, __pyx_n_s_reduce_cython, __pyx_t_7) < 0) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_DatasetSearcher, (type(self), 0x8c67b45, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_DatasetSearcher__set_state(self, __pyx_state) + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_22token_block_utils_fast_15DatasetSearcher_5__setstate_cython__, __Pyx_CYFUNCTION_CCLASS, __pyx_n_s_DatasetSearcher___setstate_cytho, NULL, __pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_d, ((PyObject *)__pyx_codeobj__33)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 16, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict((PyObject *)__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher, __pyx_n_s_setstate_cython, __pyx_t_7) < 0) __PYX_ERR(1, 16, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_ptype_7fairseq_4data_22token_block_utils_fast_DatasetSearcher); + + /* "(tree fragment)":1 + * def __pyx_unpickle_DatasetSearcher(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_7fairseq_4data_22token_block_utils_fast_5__pyx_unpickle_DatasetSearcher, 0, __pyx_n_s_pyx_unpickle_DatasetSearcher, NULL, __pyx_n_s_fairseq_data_token_block_utils_f_2, __pyx_d, ((PyObject *)__pyx_codeobj__34)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_DatasetSearcher, __pyx_t_7) < 0) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "fairseq/data/token_block_utils_fast.pyx":1 + * # cython: language_level=3 # <<<<<<<<<<<<<< + * # Copyright (c) Facebook, Inc. and its affiliates. + * # + */ + __pyx_t_7 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_7) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /*--- Wrapped vars code ---*/ + + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_7); + if (__pyx_m) { + if (__pyx_d && stringtab_initialized) { + __Pyx_AddTraceback("init fairseq.data.token_block_utils_fast", __pyx_clineno, __pyx_lineno, __pyx_filename); + } + #if !CYTHON_USE_MODULE_STATE + Py_CLEAR(__pyx_m); + #else + Py_DECREF(__pyx_m); + if (pystate_addmodule_run) { + PyObject *tp, *value, *tb; + PyErr_Fetch(&tp, &value, &tb); + PyState_RemoveModule(&__pyx_moduledef); + PyErr_Restore(tp, value, tb); + } + #endif + } else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ImportError, "init fairseq.data.token_block_utils_fast"); + } + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + #if CYTHON_PEP489_MULTI_PHASE_INIT + return (__pyx_m != NULL) ? 0 : -1; + #elif PY_MAJOR_VERSION >= 3 + return __pyx_m; + #else + return; + #endif +} +/* #### Code section: cleanup_globals ### */ +/* #### Code section: cleanup_module ### */ +/* #### Code section: main_method ### */ +/* #### Code section: utility_code_pragmas ### */ +#ifdef _MSC_VER +#pragma warning( push ) +/* Warning 4127: conditional expression is constant + * Cython uses constant conditional expressions to allow in inline functions to be optimized at + * compile-time, so this warning is not useful + */ +#pragma warning( disable : 4127 ) +#endif + + + +/* #### Code section: utility_code_def ### */ + +/* --- Runtime support code --- */ +/* Refnanny */ +#if CYTHON_REFNANNY +static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) { + PyObject *m = NULL, *p = NULL; + void *r = NULL; + m = PyImport_ImportModule(modname); + if (!m) goto end; + p = PyObject_GetAttrString(m, "RefNannyAPI"); + if (!p) goto end; + r = PyLong_AsVoidPtr(p); +end: + Py_XDECREF(p); + Py_XDECREF(m); + return (__Pyx_RefNannyAPIStruct *)r; +} +#endif + +/* PyErrExceptionMatches */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; i= 0x030C00A6 + PyObject *current_exception = tstate->current_exception; + if (unlikely(!current_exception)) return 0; + exc_type = (PyObject*) Py_TYPE(current_exception); + if (exc_type == err) return 1; +#else + exc_type = tstate->curexc_type; + if (exc_type == err) return 1; + if (unlikely(!exc_type)) return 0; +#endif + #if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(exc_type); + #endif + if (unlikely(PyTuple_Check(err))) { + result = __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err); + } else { + result = __Pyx_PyErr_GivenExceptionMatches(exc_type, err); + } + #if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(exc_type); + #endif + return result; +} +#endif + +/* PyErrFetchRestore */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject *tmp_value; + assert(type == NULL || (value != NULL && type == (PyObject*) Py_TYPE(value))); + if (value) { + #if CYTHON_COMPILING_IN_CPYTHON + if (unlikely(((PyBaseExceptionObject*) value)->traceback != tb)) + #endif + PyException_SetTraceback(value, tb); + } + tmp_value = tstate->current_exception; + tstate->current_exception = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = type; + tstate->curexc_value = value; + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#endif +} +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject* exc_value; + exc_value = tstate->current_exception; + tstate->current_exception = 0; + *value = exc_value; + *type = NULL; + *tb = NULL; + if (exc_value) { + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + #if CYTHON_COMPILING_IN_CPYTHON + *tb = ((PyBaseExceptionObject*) exc_value)->traceback; + Py_XINCREF(*tb); + #else + *tb = PyException_GetTraceback(exc_value); + #endif + } +#else + *type = tstate->curexc_type; + *value = tstate->curexc_value; + *tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; +#endif +} +#endif + +/* PyObjectGetAttrStr */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) { + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro)) + return tp->tp_getattro(obj, attr_name); +#if PY_MAJOR_VERSION < 3 + if (likely(tp->tp_getattr)) + return tp->tp_getattr(obj, PyString_AS_STRING(attr_name)); +#endif + return PyObject_GetAttr(obj, attr_name); +} +#endif + +/* PyObjectGetAttrStrNoError */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + __Pyx_PyErr_Clear(); +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) { + PyObject *result; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + (void) PyObject_GetOptionalAttr(obj, attr_name, &result); + return result; +#else +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1 + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) { + return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1); + } +#endif + result = __Pyx_PyObject_GetAttrStr(obj, attr_name); + if (unlikely(!result)) { + __Pyx_PyObject_GetAttrStr_ClearAttributeError(); + } + return result; +#endif +} + +/* GetBuiltinName */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name) { + PyObject* result = __Pyx_PyObject_GetAttrStrNoError(__pyx_b, name); + if (unlikely(!result) && !PyErr_Occurred()) { + PyErr_Format(PyExc_NameError, +#if PY_MAJOR_VERSION >= 3 + "name '%U' is not defined", name); +#else + "name '%.200s' is not defined", PyString_AS_STRING(name)); +#endif + } + return result; +} + +/* TupleAndListFromArray */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE void __Pyx_copy_object_array(PyObject *const *CYTHON_RESTRICT src, PyObject** CYTHON_RESTRICT dest, Py_ssize_t length) { + PyObject *v; + Py_ssize_t i; + for (i = 0; i < length; i++) { + v = dest[i] = src[i]; + Py_INCREF(v); + } +} +static CYTHON_INLINE PyObject * +__Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + Py_INCREF(__pyx_empty_tuple); + return __pyx_empty_tuple; + } + res = PyTuple_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyTupleObject*)res)->ob_item, n); + return res; +} +static CYTHON_INLINE PyObject * +__Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + return PyList_New(0); + } + res = PyList_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyListObject*)res)->ob_item, n); + return res; +} +#endif + +/* BytesEquals */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else + if (s1 == s2) { + return (equals == Py_EQ); + } else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) { + const char *ps1, *ps2; + Py_ssize_t length = PyBytes_GET_SIZE(s1); + if (length != PyBytes_GET_SIZE(s2)) + return (equals == Py_NE); + ps1 = PyBytes_AS_STRING(s1); + ps2 = PyBytes_AS_STRING(s2); + if (ps1[0] != ps2[0]) { + return (equals == Py_NE); + } else if (length == 1) { + return (equals == Py_EQ); + } else { + int result; +#if CYTHON_USE_UNICODE_INTERNALS && (PY_VERSION_HEX < 0x030B0000) + Py_hash_t hash1, hash2; + hash1 = ((PyBytesObject*)s1)->ob_shash; + hash2 = ((PyBytesObject*)s2)->ob_shash; + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + return (equals == Py_NE); + } +#endif + result = memcmp(ps1, ps2, (size_t)length); + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) { + return (equals == Py_NE); + } else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) { + return (equals == Py_NE); + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +#endif +} + +/* UnicodeEquals */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else +#if PY_MAJOR_VERSION < 3 + PyObject* owned_ref = NULL; +#endif + int s1_is_unicode, s2_is_unicode; + if (s1 == s2) { + goto return_eq; + } + s1_is_unicode = PyUnicode_CheckExact(s1); + s2_is_unicode = PyUnicode_CheckExact(s2); +#if PY_MAJOR_VERSION < 3 + if ((s1_is_unicode & (!s2_is_unicode)) && PyString_CheckExact(s2)) { + owned_ref = PyUnicode_FromObject(s2); + if (unlikely(!owned_ref)) + return -1; + s2 = owned_ref; + s2_is_unicode = 1; + } else if ((s2_is_unicode & (!s1_is_unicode)) && PyString_CheckExact(s1)) { + owned_ref = PyUnicode_FromObject(s1); + if (unlikely(!owned_ref)) + return -1; + s1 = owned_ref; + s1_is_unicode = 1; + } else if (((!s2_is_unicode) & (!s1_is_unicode))) { + return __Pyx_PyBytes_Equals(s1, s2, equals); + } +#endif + if (s1_is_unicode & s2_is_unicode) { + Py_ssize_t length; + int kind; + void *data1, *data2; + if (unlikely(__Pyx_PyUnicode_READY(s1) < 0) || unlikely(__Pyx_PyUnicode_READY(s2) < 0)) + return -1; + length = __Pyx_PyUnicode_GET_LENGTH(s1); + if (length != __Pyx_PyUnicode_GET_LENGTH(s2)) { + goto return_ne; + } +#if CYTHON_USE_UNICODE_INTERNALS + { + Py_hash_t hash1, hash2; + #if CYTHON_PEP393_ENABLED + hash1 = ((PyASCIIObject*)s1)->hash; + hash2 = ((PyASCIIObject*)s2)->hash; + #else + hash1 = ((PyUnicodeObject*)s1)->hash; + hash2 = ((PyUnicodeObject*)s2)->hash; + #endif + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + goto return_ne; + } + } +#endif + kind = __Pyx_PyUnicode_KIND(s1); + if (kind != __Pyx_PyUnicode_KIND(s2)) { + goto return_ne; + } + data1 = __Pyx_PyUnicode_DATA(s1); + data2 = __Pyx_PyUnicode_DATA(s2); + if (__Pyx_PyUnicode_READ(kind, data1, 0) != __Pyx_PyUnicode_READ(kind, data2, 0)) { + goto return_ne; + } else if (length == 1) { + goto return_eq; + } else { + int result = memcmp(data1, data2, (size_t)(length * kind)); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & s2_is_unicode) { + goto return_ne; + } else if ((s2 == Py_None) & s1_is_unicode) { + goto return_ne; + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +return_eq: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ); +return_ne: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_NE); +#endif +} + +/* fastcall */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s) +{ + Py_ssize_t i, n = PyTuple_GET_SIZE(kwnames); + for (i = 0; i < n; i++) + { + if (s == PyTuple_GET_ITEM(kwnames, i)) return kwvalues[i]; + } + for (i = 0; i < n; i++) + { + int eq = __Pyx_PyUnicode_Equals(s, PyTuple_GET_ITEM(kwnames, i), Py_EQ); + if (unlikely(eq != 0)) { + if (unlikely(eq < 0)) return NULL; + return kwvalues[i]; + } + } + return NULL; +} +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 +CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues) { + Py_ssize_t i, nkwargs = PyTuple_GET_SIZE(kwnames); + PyObject *dict; + dict = PyDict_New(); + if (unlikely(!dict)) + return NULL; + for (i=0; i= 3 + "%s() got multiple values for keyword argument '%U'", func_name, kw_name); + #else + "%s() got multiple values for keyword argument '%s'", func_name, + PyString_AsString(kw_name)); + #endif +} + +/* ParseKeywords */ +static int __Pyx_ParseOptionalKeywords( + PyObject *kwds, + PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, + PyObject *values[], + Py_ssize_t num_pos_args, + const char* function_name) +{ + PyObject *key = 0, *value = 0; + Py_ssize_t pos = 0; + PyObject*** name; + PyObject*** first_kw_arg = argnames + num_pos_args; + int kwds_is_tuple = CYTHON_METH_FASTCALL && likely(PyTuple_Check(kwds)); + while (1) { + Py_XDECREF(key); key = NULL; + Py_XDECREF(value); value = NULL; + if (kwds_is_tuple) { + Py_ssize_t size; +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(kwds); +#else + size = PyTuple_Size(kwds); + if (size < 0) goto bad; +#endif + if (pos >= size) break; +#if CYTHON_AVOID_BORROWED_REFS + key = __Pyx_PySequence_ITEM(kwds, pos); + if (!key) goto bad; +#elif CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kwds, pos); +#else + key = PyTuple_GetItem(kwds, pos); + if (!key) goto bad; +#endif + value = kwvalues[pos]; + pos++; + } + else + { + if (!PyDict_Next(kwds, &pos, &key, &value)) break; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + } + name = first_kw_arg; + while (*name && (**name != key)) name++; + if (*name) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(value); + Py_DECREF(key); +#endif + key = NULL; + value = NULL; + continue; + } +#if !CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + Py_INCREF(value); + name = first_kw_arg; + #if PY_MAJOR_VERSION < 3 + if (likely(PyString_Check(key))) { + while (*name) { + if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key)) + && _PyString_Eq(**name, key)) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + if ((**argname == key) || ( + (CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key)) + && _PyString_Eq(**argname, key))) { + goto arg_passed_twice; + } + argname++; + } + } + } else + #endif + if (likely(PyUnicode_Check(key))) { + while (*name) { + int cmp = ( + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**name, key) + ); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + int cmp = (**argname == key) ? 0 : + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**argname, key); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) goto arg_passed_twice; + argname++; + } + } + } else + goto invalid_keyword_type; + if (kwds2) { + if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad; + } else { + goto invalid_keyword; + } + } + Py_XDECREF(key); + Py_XDECREF(value); + return 0; +arg_passed_twice: + __Pyx_RaiseDoubleKeywordsError(function_name, key); + goto bad; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + goto bad; +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif +bad: + Py_XDECREF(key); + Py_XDECREF(value); + return -1; +} + +/* ArgTypeTest */ +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact) +{ + __Pyx_TypeName type_name; + __Pyx_TypeName obj_type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + else if (exact) { + #if PY_MAJOR_VERSION == 2 + if ((type == &PyBaseString_Type) && likely(__Pyx_PyBaseString_CheckExact(obj))) return 1; + #endif + } + else { + if (likely(__Pyx_TypeCheck(obj, type))) return 1; + } + type_name = __Pyx_PyType_GetName(type); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "Argument '%.200s' has incorrect type (expected " __Pyx_FMT_TYPENAME + ", got " __Pyx_FMT_TYPENAME ")", name, type_name, obj_type_name); + __Pyx_DECREF_TypeName(type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* RaiseException */ +#if PY_MAJOR_VERSION < 3 +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + __Pyx_PyThreadState_declare + CYTHON_UNUSED_VAR(cause); + Py_XINCREF(type); + if (!value || value == Py_None) + value = NULL; + else + Py_INCREF(value); + if (!tb || tb == Py_None) + tb = NULL; + else { + Py_INCREF(tb); + if (!PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto raise_error; + } + } + if (PyType_Check(type)) { +#if CYTHON_COMPILING_IN_PYPY + if (!value) { + Py_INCREF(Py_None); + value = Py_None; + } +#endif + PyErr_NormalizeException(&type, &value, &tb); + } else { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto raise_error; + } + value = type; + type = (PyObject*) Py_TYPE(type); + Py_INCREF(type); + if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto raise_error; + } + } + __Pyx_PyThreadState_assign + __Pyx_ErrRestore(type, value, tb); + return; +raise_error: + Py_XDECREF(value); + Py_XDECREF(type); + Py_XDECREF(tb); + return; +} +#else +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + PyObject* owned_instance = NULL; + if (tb == Py_None) { + tb = 0; + } else if (tb && !PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto bad; + } + if (value == Py_None) + value = 0; + if (PyExceptionInstance_Check(type)) { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto bad; + } + value = type; + type = (PyObject*) Py_TYPE(value); + } else if (PyExceptionClass_Check(type)) { + PyObject *instance_class = NULL; + if (value && PyExceptionInstance_Check(value)) { + instance_class = (PyObject*) Py_TYPE(value); + if (instance_class != type) { + int is_subclass = PyObject_IsSubclass(instance_class, type); + if (!is_subclass) { + instance_class = NULL; + } else if (unlikely(is_subclass == -1)) { + goto bad; + } else { + type = instance_class; + } + } + } + if (!instance_class) { + PyObject *args; + if (!value) + args = PyTuple_New(0); + else if (PyTuple_Check(value)) { + Py_INCREF(value); + args = value; + } else + args = PyTuple_Pack(1, value); + if (!args) + goto bad; + owned_instance = PyObject_Call(type, args, NULL); + Py_DECREF(args); + if (!owned_instance) + goto bad; + value = owned_instance; + if (!PyExceptionInstance_Check(value)) { + PyErr_Format(PyExc_TypeError, + "calling %R should have returned an instance of " + "BaseException, not %R", + type, Py_TYPE(value)); + goto bad; + } + } + } else { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto bad; + } + if (cause) { + PyObject *fixed_cause; + if (cause == Py_None) { + fixed_cause = NULL; + } else if (PyExceptionClass_Check(cause)) { + fixed_cause = PyObject_CallObject(cause, NULL); + if (fixed_cause == NULL) + goto bad; + } else if (PyExceptionInstance_Check(cause)) { + fixed_cause = cause; + Py_INCREF(fixed_cause); + } else { + PyErr_SetString(PyExc_TypeError, + "exception causes must derive from " + "BaseException"); + goto bad; + } + PyException_SetCause(value, fixed_cause); + } + PyErr_SetObject(type, value); + if (tb) { + #if PY_VERSION_HEX >= 0x030C00A6 + PyException_SetTraceback(value, tb); + #elif CYTHON_FAST_THREAD_STATE + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject* tmp_tb = tstate->curexc_traceback; + if (tb != tmp_tb) { + Py_INCREF(tb); + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_tb); + } +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb); + Py_INCREF(tb); + PyErr_Restore(tmp_type, tmp_value, tb); + Py_XDECREF(tmp_tb); +#endif + } +bad: + Py_XDECREF(owned_instance); + return; +} +#endif + +/* PyFunctionFastCall */ +#if CYTHON_FAST_PYCALL && !CYTHON_VECTORCALL +static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na, + PyObject *globals) { + PyFrameObject *f; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject **fastlocals; + Py_ssize_t i; + PyObject *result; + assert(globals != NULL); + /* XXX Perhaps we should create a specialized + PyFrame_New() that doesn't take locals, but does + take builtins without sanity checking them. + */ + assert(tstate != NULL); + f = PyFrame_New(tstate, co, globals, NULL); + if (f == NULL) { + return NULL; + } + fastlocals = __Pyx_PyFrame_GetLocalsplus(f); + for (i = 0; i < na; i++) { + Py_INCREF(*args); + fastlocals[i] = *args++; + } + result = PyEval_EvalFrameEx(f,0); + ++tstate->recursion_depth; + Py_DECREF(f); + --tstate->recursion_depth; + return result; +} +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) { + PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func); + PyObject *globals = PyFunction_GET_GLOBALS(func); + PyObject *argdefs = PyFunction_GET_DEFAULTS(func); + PyObject *closure; +#if PY_MAJOR_VERSION >= 3 + PyObject *kwdefs; +#endif + PyObject *kwtuple, **k; + PyObject **d; + Py_ssize_t nd; + Py_ssize_t nk; + PyObject *result; + assert(kwargs == NULL || PyDict_Check(kwargs)); + nk = kwargs ? PyDict_Size(kwargs) : 0; + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) { + return NULL; + } + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) { + return NULL; + } + #endif + if ( +#if PY_MAJOR_VERSION >= 3 + co->co_kwonlyargcount == 0 && +#endif + likely(kwargs == NULL || nk == 0) && + co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) { + if (argdefs == NULL && co->co_argcount == nargs) { + result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals); + goto done; + } + else if (nargs == 0 && argdefs != NULL + && co->co_argcount == Py_SIZE(argdefs)) { + /* function called with no arguments, but all parameters have + a default value: use default values as arguments .*/ + args = &PyTuple_GET_ITEM(argdefs, 0); + result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals); + goto done; + } + } + if (kwargs != NULL) { + Py_ssize_t pos, i; + kwtuple = PyTuple_New(2 * nk); + if (kwtuple == NULL) { + result = NULL; + goto done; + } + k = &PyTuple_GET_ITEM(kwtuple, 0); + pos = i = 0; + while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) { + Py_INCREF(k[i]); + Py_INCREF(k[i+1]); + i += 2; + } + nk = i / 2; + } + else { + kwtuple = NULL; + k = NULL; + } + closure = PyFunction_GET_CLOSURE(func); +#if PY_MAJOR_VERSION >= 3 + kwdefs = PyFunction_GET_KW_DEFAULTS(func); +#endif + if (argdefs != NULL) { + d = &PyTuple_GET_ITEM(argdefs, 0); + nd = Py_SIZE(argdefs); + } + else { + d = NULL; + nd = 0; + } +#if PY_MAJOR_VERSION >= 3 + result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, kwdefs, closure); +#else + result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, closure); +#endif + Py_XDECREF(kwtuple); +done: + Py_LeaveRecursiveCall(); + return result; +} +#endif + +/* PyObjectCall */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *result; + ternaryfunc call = Py_TYPE(func)->tp_call; + if (unlikely(!call)) + return PyObject_Call(func, arg, kw); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = (*call)(func, arg, kw); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectCallMethO */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) { + PyObject *self, *result; + PyCFunction cfunc; + cfunc = __Pyx_CyOrPyCFunction_GET_FUNCTION(func); + self = __Pyx_CyOrPyCFunction_GET_SELF(func); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = cfunc(self, arg); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectFastCall */ +#if PY_VERSION_HEX < 0x03090000 || CYTHON_COMPILING_IN_LIMITED_API +static PyObject* __Pyx_PyObject_FastCall_fallback(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs) { + PyObject *argstuple; + PyObject *result = 0; + size_t i; + argstuple = PyTuple_New((Py_ssize_t)nargs); + if (unlikely(!argstuple)) return NULL; + for (i = 0; i < nargs; i++) { + Py_INCREF(args[i]); + if (__Pyx_PyTuple_SET_ITEM(argstuple, (Py_ssize_t)i, args[i]) < 0) goto bad; + } + result = __Pyx_PyObject_Call(func, argstuple, kwargs); + bad: + Py_DECREF(argstuple); + return result; +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t _nargs, PyObject *kwargs) { + Py_ssize_t nargs = __Pyx_PyVectorcall_NARGS(_nargs); +#if CYTHON_COMPILING_IN_CPYTHON + if (nargs == 0 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_NOARGS)) + return __Pyx_PyObject_CallMethO(func, NULL); + } + else if (nargs == 1 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_O)) + return __Pyx_PyObject_CallMethO(func, args[0]); + } +#endif + #if PY_VERSION_HEX < 0x030800B1 + #if CYTHON_FAST_PYCCALL + if (PyCFunction_Check(func)) { + if (kwargs) { + return _PyCFunction_FastCallDict(func, args, nargs, kwargs); + } else { + return _PyCFunction_FastCallKeywords(func, args, nargs, NULL); + } + } + #if PY_VERSION_HEX >= 0x030700A1 + if (!kwargs && __Pyx_IS_TYPE(func, &PyMethodDescr_Type)) { + return _PyMethodDescr_FastCallKeywords(func, args, nargs, NULL); + } + #endif + #endif + #if CYTHON_FAST_PYCALL + if (PyFunction_Check(func)) { + return __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs); + } + #endif + #endif + if (kwargs == NULL) { + #if CYTHON_VECTORCALL + #if PY_VERSION_HEX < 0x03090000 + vectorcallfunc f = _PyVectorcall_Function(func); + #else + vectorcallfunc f = PyVectorcall_Function(func); + #endif + if (f) { + return f(func, args, (size_t)nargs, NULL); + } + #elif defined(__Pyx_CyFunction_USED) && CYTHON_BACKPORT_VECTORCALL + if (__Pyx_CyFunction_CheckExact(func)) { + __pyx_vectorcallfunc f = __Pyx_CyFunction_func_vectorcall(func); + if (f) return f(func, args, (size_t)nargs, NULL); + } + #endif + } + if (nargs == 0) { + return __Pyx_PyObject_Call(func, __pyx_empty_tuple, kwargs); + } + #if PY_VERSION_HEX >= 0x03090000 && !CYTHON_COMPILING_IN_LIMITED_API + return PyObject_VectorcallDict(func, args, (size_t)nargs, kwargs); + #else + return __Pyx_PyObject_FastCall_fallback(func, args, (size_t)nargs, kwargs); + #endif +} + +/* RaiseUnexpectedTypeError */ +static int +__Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj) +{ + __Pyx_TypeName obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, "Expected %s, got " __Pyx_FMT_TYPENAME, + expected, obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* CIntToDigits */ +static const char DIGIT_PAIRS_10[2*10*10+1] = { + "00010203040506070809" + "10111213141516171819" + "20212223242526272829" + "30313233343536373839" + "40414243444546474849" + "50515253545556575859" + "60616263646566676869" + "70717273747576777879" + "80818283848586878889" + "90919293949596979899" +}; +static const char DIGIT_PAIRS_8[2*8*8+1] = { + "0001020304050607" + "1011121314151617" + "2021222324252627" + "3031323334353637" + "4041424344454647" + "5051525354555657" + "6061626364656667" + "7071727374757677" +}; +static const char DIGITS_HEX[2*16+1] = { + "0123456789abcdef" + "0123456789ABCDEF" +}; + +/* BuildPyUnicode */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char) { + PyObject *uval; + Py_ssize_t uoffset = ulength - clength; +#if CYTHON_USE_UNICODE_INTERNALS + Py_ssize_t i; +#if CYTHON_PEP393_ENABLED + void *udata; + uval = PyUnicode_New(ulength, 127); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_DATA(uval); +#else + Py_UNICODE *udata; + uval = PyUnicode_FromUnicode(NULL, ulength); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_AS_UNICODE(uval); +#endif + if (uoffset > 0) { + i = 0; + if (prepend_sign) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, 0, '-'); + i++; + } + for (; i < uoffset; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, i, padding_char); + } + } + for (i=0; i < clength; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, uoffset+i, chars[i]); + } +#else + { + PyObject *sign = NULL, *padding = NULL; + uval = NULL; + if (uoffset > 0) { + prepend_sign = !!prepend_sign; + if (uoffset > prepend_sign) { + padding = PyUnicode_FromOrdinal(padding_char); + if (likely(padding) && uoffset > prepend_sign + 1) { + PyObject *tmp; + PyObject *repeat = PyInt_FromSsize_t(uoffset - prepend_sign); + if (unlikely(!repeat)) goto done_or_error; + tmp = PyNumber_Multiply(padding, repeat); + Py_DECREF(repeat); + Py_DECREF(padding); + padding = tmp; + } + if (unlikely(!padding)) goto done_or_error; + } + if (prepend_sign) { + sign = PyUnicode_FromOrdinal('-'); + if (unlikely(!sign)) goto done_or_error; + } + } + uval = PyUnicode_DecodeASCII(chars, clength, NULL); + if (likely(uval) && padding) { + PyObject *tmp = PyNumber_Add(padding, uval); + Py_DECREF(uval); + uval = tmp; + } + if (likely(uval) && sign) { + PyObject *tmp = PyNumber_Add(sign, uval); + Py_DECREF(uval); + uval = tmp; + } +done_or_error: + Py_XDECREF(padding); + Py_XDECREF(sign); + } +#endif + return uval; +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(int)*3+2]; + char *dpos, *end = digits + sizeof(int)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + int remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (int) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (int) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (int) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(Py_ssize_t)*3+2]; + char *dpos, *end = digits + sizeof(Py_ssize_t)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + Py_ssize_t remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const Py_ssize_t neg_one = (Py_ssize_t) -1, const_zero = (Py_ssize_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (Py_ssize_t) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (Py_ssize_t) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (Py_ssize_t) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* JoinPyUnicode */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char) { +#if CYTHON_USE_UNICODE_INTERNALS && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + PyObject *result_uval; + int result_ukind, kind_shift; + Py_ssize_t i, char_pos; + void *result_udata; + CYTHON_MAYBE_UNUSED_VAR(max_char); +#if CYTHON_PEP393_ENABLED + result_uval = PyUnicode_New(result_ulength, max_char); + if (unlikely(!result_uval)) return NULL; + result_ukind = (max_char <= 255) ? PyUnicode_1BYTE_KIND : (max_char <= 65535) ? PyUnicode_2BYTE_KIND : PyUnicode_4BYTE_KIND; + kind_shift = (result_ukind == PyUnicode_4BYTE_KIND) ? 2 : result_ukind - 1; + result_udata = PyUnicode_DATA(result_uval); +#else + result_uval = PyUnicode_FromUnicode(NULL, result_ulength); + if (unlikely(!result_uval)) return NULL; + result_ukind = sizeof(Py_UNICODE); + kind_shift = (result_ukind == 4) ? 2 : result_ukind - 1; + result_udata = PyUnicode_AS_UNICODE(result_uval); +#endif + assert(kind_shift == 2 || kind_shift == 1 || kind_shift == 0); + char_pos = 0; + for (i=0; i < value_count; i++) { + int ukind; + Py_ssize_t ulength; + void *udata; + PyObject *uval = PyTuple_GET_ITEM(value_tuple, i); + if (unlikely(__Pyx_PyUnicode_READY(uval))) + goto bad; + ulength = __Pyx_PyUnicode_GET_LENGTH(uval); + if (unlikely(!ulength)) + continue; + if (unlikely((PY_SSIZE_T_MAX >> kind_shift) - ulength < char_pos)) + goto overflow; + ukind = __Pyx_PyUnicode_KIND(uval); + udata = __Pyx_PyUnicode_DATA(uval); + if (!CYTHON_PEP393_ENABLED || ukind == result_ukind) { + memcpy((char *)result_udata + (char_pos << kind_shift), udata, (size_t) (ulength << kind_shift)); + } else { + #if PY_VERSION_HEX >= 0x030d0000 + if (unlikely(PyUnicode_CopyCharacters(result_uval, char_pos, uval, 0, ulength) < 0)) goto bad; + #elif CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030300F0 || defined(_PyUnicode_FastCopyCharacters) + _PyUnicode_FastCopyCharacters(result_uval, char_pos, uval, 0, ulength); + #else + Py_ssize_t j; + for (j=0; j < ulength; j++) { + Py_UCS4 uchar = __Pyx_PyUnicode_READ(ukind, udata, j); + __Pyx_PyUnicode_WRITE(result_ukind, result_udata, char_pos+j, uchar); + } + #endif + } + char_pos += ulength; + } + return result_uval; +overflow: + PyErr_SetString(PyExc_OverflowError, "join() result is too long for a Python string"); +bad: + Py_DECREF(result_uval); + return NULL; +#else + CYTHON_UNUSED_VAR(max_char); + CYTHON_UNUSED_VAR(result_ulength); + CYTHON_UNUSED_VAR(value_count); + return PyUnicode_Join(__pyx_empty_unicode, value_tuple); +#endif +} + +/* GetAttr */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *o, PyObject *n) { +#if CYTHON_USE_TYPE_SLOTS +#if PY_MAJOR_VERSION >= 3 + if (likely(PyUnicode_Check(n))) +#else + if (likely(PyString_Check(n))) +#endif + return __Pyx_PyObject_GetAttrStr(o, n); +#endif + return PyObject_GetAttr(o, n); +} + +/* GetItemInt */ +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) { + PyObject *r; + if (unlikely(!j)) return NULL; + r = PyObject_GetItem(o, j); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyList_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) { + PyObject *r = PyList_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyTuple_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o); + if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) { + PyObject *r = PyList_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } + else if (PyTuple_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyTuple_GET_SIZE(o); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_subscript) { + PyObject *r, *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return NULL; + r = mm->mp_subscript(o, key); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return NULL; + PyErr_Clear(); + } + } + return sm->sq_item(o, i); + } + } +#else + if (is_list || !PyMapping_Check(o)) { + return PySequence_GetItem(o, i); + } +#endif + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +} + +/* PyObjectCallOneArg */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { + PyObject *args[2] = {NULL, arg}; + return __Pyx_PyObject_FastCall(func, args+1, 1 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* ObjectGetItem */ +#if CYTHON_USE_TYPE_SLOTS +static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) { + PyObject *runerr = NULL; + Py_ssize_t key_value; + key_value = __Pyx_PyIndex_AsSsize_t(index); + if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) { + return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1); + } + if (PyErr_GivenExceptionMatches(runerr, PyExc_OverflowError)) { + __Pyx_TypeName index_type_name = __Pyx_PyType_GetName(Py_TYPE(index)); + PyErr_Clear(); + PyErr_Format(PyExc_IndexError, + "cannot fit '" __Pyx_FMT_TYPENAME "' into an index-sized integer", index_type_name); + __Pyx_DECREF_TypeName(index_type_name); + } + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) { + __Pyx_TypeName obj_type_name; + if (likely(PyType_Check(obj))) { + PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, __pyx_n_s_class_getitem); + if (!meth) { + PyErr_Clear(); + } else { + PyObject *result = __Pyx_PyObject_CallOneArg(meth, key); + Py_DECREF(meth); + return result; + } + } + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' object is not subscriptable", obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) { + PyTypeObject *tp = Py_TYPE(obj); + PyMappingMethods *mm = tp->tp_as_mapping; + PySequenceMethods *sm = tp->tp_as_sequence; + if (likely(mm && mm->mp_subscript)) { + return mm->mp_subscript(obj, key); + } + if (likely(sm && sm->sq_item)) { + return __Pyx_PyObject_GetIndex(obj, key); + } + return __Pyx_PyObject_GetItem_Slow(obj, key); +} +#endif + +/* KeywordStringCheck */ +static int __Pyx_CheckKeywordStrings( + PyObject *kw, + const char* function_name, + int kw_allowed) +{ + PyObject* key = 0; + Py_ssize_t pos = 0; +#if CYTHON_COMPILING_IN_PYPY + if (!kw_allowed && PyDict_Next(kw, &pos, &key, 0)) + goto invalid_keyword; + return 1; +#else + if (CYTHON_METH_FASTCALL && likely(PyTuple_Check(kw))) { + Py_ssize_t kwsize; +#if CYTHON_ASSUME_SAFE_MACROS + kwsize = PyTuple_GET_SIZE(kw); +#else + kwsize = PyTuple_Size(kw); + if (kwsize < 0) return 0; +#endif + if (unlikely(kwsize == 0)) + return 1; + if (!kw_allowed) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, 0); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + goto invalid_keyword; + } +#if PY_VERSION_HEX < 0x03090000 + for (pos = 0; pos < kwsize; pos++) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, pos); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } +#endif + return 1; + } + while (PyDict_Next(kw, &pos, &key, 0)) { + #if PY_MAJOR_VERSION < 3 + if (unlikely(!PyString_Check(key))) + #endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } + if (!kw_allowed && unlikely(key)) + goto invalid_keyword; + return 1; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + return 0; +#endif +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif + return 0; +} + +/* DivInt[Py_ssize_t] */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t a, Py_ssize_t b) { + Py_ssize_t q = a / b; + Py_ssize_t r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* GetAttr3 */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static PyObject *__Pyx_GetAttr3Default(PyObject *d) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + return NULL; + __Pyx_PyErr_Clear(); + Py_INCREF(d); + return d; +} +#endif +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *o, PyObject *n, PyObject *d) { + PyObject *r; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + int res = PyObject_GetOptionalAttr(o, n, &r); + return (res != 0) ? r : __Pyx_NewRef(d); +#else + #if CYTHON_USE_TYPE_SLOTS + if (likely(PyString_Check(n))) { + r = __Pyx_PyObject_GetAttrStrNoError(o, n); + if (unlikely(!r) && likely(!PyErr_Occurred())) { + r = __Pyx_NewRef(d); + } + return r; + } + #endif + r = PyObject_GetAttr(o, n); + return (likely(r)) ? r : __Pyx_GetAttr3Default(d); +#endif +} + +/* PyDictVersioning */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0; +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) { + PyObject **dictptr = NULL; + Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset; + if (offset) { +#if CYTHON_COMPILING_IN_CPYTHON + dictptr = (likely(offset > 0)) ? (PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj); +#else + dictptr = _PyObject_GetDictPtr(obj); +#endif + } + return (dictptr && *dictptr) ? __PYX_GET_DICT_VERSION(*dictptr) : 0; +} +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict))) + return 0; + return obj_dict_version == __Pyx_get_object_dict_version(obj); +} +#endif + +/* GetModuleGlobalName */ +#if CYTHON_USE_DICT_VERSIONS +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value) +#else +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name) +#endif +{ + PyObject *result; +#if !CYTHON_AVOID_BORROWED_REFS +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && PY_VERSION_HEX < 0x030d0000 + result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } else if (unlikely(PyErr_Occurred())) { + return NULL; + } +#elif CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(!__pyx_m)) { + return NULL; + } + result = PyObject_GetAttr(__pyx_m, name); + if (likely(result)) { + return result; + } +#else + result = PyDict_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } +#endif +#else + result = PyObject_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } + PyErr_Clear(); +#endif + return __Pyx_GetBuiltinName(name); +} + +/* RaiseTooManyValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected) { + PyErr_Format(PyExc_ValueError, + "too many values to unpack (expected %" CYTHON_FORMAT_SSIZE_T "d)", expected); +} + +/* RaiseNeedMoreValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index) { + PyErr_Format(PyExc_ValueError, + "need more than %" CYTHON_FORMAT_SSIZE_T "d value%.1s to unpack", + index, (index == 1) ? "" : "s"); +} + +/* RaiseNoneIterError */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); +} + +/* ExtTypeTest */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { + __Pyx_TypeName obj_type_name; + __Pyx_TypeName type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + if (likely(__Pyx_TypeCheck(obj, type))) + return 1; + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + type_name = __Pyx_PyType_GetName(type); + PyErr_Format(PyExc_TypeError, + "Cannot convert " __Pyx_FMT_TYPENAME " to " __Pyx_FMT_TYPENAME, + obj_type_name, type_name); + __Pyx_DECREF_TypeName(obj_type_name); + __Pyx_DECREF_TypeName(type_name); + return 0; +} + +/* GetTopmostException */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * +__Pyx_PyErr_GetTopmostException(PyThreadState *tstate) +{ + _PyErr_StackItem *exc_info = tstate->exc_info; + while ((exc_info->exc_value == NULL || exc_info->exc_value == Py_None) && + exc_info->previous_item != NULL) + { + exc_info = exc_info->previous_item; + } + return exc_info; +} +#endif + +/* SaveResetException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + PyObject *exc_value = exc_info->exc_value; + if (exc_value == NULL || exc_value == Py_None) { + *value = NULL; + *type = NULL; + *tb = NULL; + } else { + *value = exc_value; + Py_INCREF(*value); + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + *tb = PyException_GetTraceback(exc_value); + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + *type = exc_info->exc_type; + *value = exc_info->exc_value; + *tb = exc_info->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #else + *type = tstate->exc_type; + *value = tstate->exc_value; + *tb = tstate->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #endif +} +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + PyObject *tmp_value = exc_info->exc_value; + exc_info->exc_value = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); + #else + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = type; + exc_info->exc_value = value; + exc_info->exc_traceback = tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = type; + tstate->exc_value = value; + tstate->exc_traceback = tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); + #endif +} +#endif + +/* GetException */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) +#endif +{ + PyObject *local_type = NULL, *local_value, *local_tb = NULL; +#if CYTHON_FAST_THREAD_STATE + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if PY_VERSION_HEX >= 0x030C00A6 + local_value = tstate->current_exception; + tstate->current_exception = 0; + if (likely(local_value)) { + local_type = (PyObject*) Py_TYPE(local_value); + Py_INCREF(local_type); + local_tb = PyException_GetTraceback(local_value); + } + #else + local_type = tstate->curexc_type; + local_value = tstate->curexc_value; + local_tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; + #endif +#else + PyErr_Fetch(&local_type, &local_value, &local_tb); +#endif + PyErr_NormalizeException(&local_type, &local_value, &local_tb); +#if CYTHON_FAST_THREAD_STATE && PY_VERSION_HEX >= 0x030C00A6 + if (unlikely(tstate->current_exception)) +#elif CYTHON_FAST_THREAD_STATE + if (unlikely(tstate->curexc_type)) +#else + if (unlikely(PyErr_Occurred())) +#endif + goto bad; + #if PY_MAJOR_VERSION >= 3 + if (local_tb) { + if (unlikely(PyException_SetTraceback(local_value, local_tb) < 0)) + goto bad; + } + #endif + Py_XINCREF(local_tb); + Py_XINCREF(local_type); + Py_XINCREF(local_value); + *type = local_type; + *value = local_value; + *tb = local_tb; +#if CYTHON_FAST_THREAD_STATE + #if CYTHON_USE_EXC_INFO_STACK + { + _PyErr_StackItem *exc_info = tstate->exc_info; + #if PY_VERSION_HEX >= 0x030B00a4 + tmp_value = exc_info->exc_value; + exc_info->exc_value = local_value; + tmp_type = NULL; + tmp_tb = NULL; + Py_XDECREF(local_type); + Py_XDECREF(local_tb); + #else + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = local_type; + exc_info->exc_value = local_value; + exc_info->exc_traceback = local_tb; + #endif + } + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = local_type; + tstate->exc_value = local_value; + tstate->exc_traceback = local_tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#else + PyErr_SetExcInfo(local_type, local_value, local_tb); +#endif + return 0; +bad: + *type = 0; + *value = 0; + *tb = 0; + Py_XDECREF(local_type); + Py_XDECREF(local_value); + Py_XDECREF(local_tb); + return -1; +} + +/* SwapException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_value = exc_info->exc_value; + exc_info->exc_value = *value; + if (tmp_value == NULL || tmp_value == Py_None) { + Py_XDECREF(tmp_value); + tmp_value = NULL; + tmp_type = NULL; + tmp_tb = NULL; + } else { + tmp_type = (PyObject*) Py_TYPE(tmp_value); + Py_INCREF(tmp_type); + #if CYTHON_COMPILING_IN_CPYTHON + tmp_tb = ((PyBaseExceptionObject*) tmp_value)->traceback; + Py_XINCREF(tmp_tb); + #else + tmp_tb = PyException_GetTraceback(tmp_value); + #endif + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = *type; + exc_info->exc_value = *value; + exc_info->exc_traceback = *tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = *type; + tstate->exc_value = *value; + tstate->exc_traceback = *tb; + #endif + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_GetExcInfo(&tmp_type, &tmp_value, &tmp_tb); + PyErr_SetExcInfo(*type, *value, *tb); + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#endif + +/* Import */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) { + PyObject *module = 0; + PyObject *empty_dict = 0; + PyObject *empty_list = 0; + #if PY_MAJOR_VERSION < 3 + PyObject *py_import; + py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import); + if (unlikely(!py_import)) + goto bad; + if (!from_list) { + empty_list = PyList_New(0); + if (unlikely(!empty_list)) + goto bad; + from_list = empty_list; + } + #endif + empty_dict = PyDict_New(); + if (unlikely(!empty_dict)) + goto bad; + { + #if PY_MAJOR_VERSION >= 3 + if (level == -1) { + if (strchr(__Pyx_MODULE_NAME, '.') != NULL) { + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, 1); + if (unlikely(!module)) { + if (unlikely(!PyErr_ExceptionMatches(PyExc_ImportError))) + goto bad; + PyErr_Clear(); + } + } + level = 0; + } + #endif + if (!module) { + #if PY_MAJOR_VERSION < 3 + PyObject *py_level = PyInt_FromLong(level); + if (unlikely(!py_level)) + goto bad; + module = PyObject_CallFunctionObjArgs(py_import, + name, __pyx_d, empty_dict, from_list, py_level, (PyObject *)NULL); + Py_DECREF(py_level); + #else + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, level); + #endif + } + } +bad: + Py_XDECREF(empty_dict); + Py_XDECREF(empty_list); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_import); + #endif + return module; +} + +/* ImportDottedModule */ +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Error(PyObject *name, PyObject *parts_tuple, Py_ssize_t count) { + PyObject *partial_name = NULL, *slice = NULL, *sep = NULL; + if (unlikely(PyErr_Occurred())) { + PyErr_Clear(); + } + if (likely(PyTuple_GET_SIZE(parts_tuple) == count)) { + partial_name = name; + } else { + slice = PySequence_GetSlice(parts_tuple, 0, count); + if (unlikely(!slice)) + goto bad; + sep = PyUnicode_FromStringAndSize(".", 1); + if (unlikely(!sep)) + goto bad; + partial_name = PyUnicode_Join(sep, slice); + } + PyErr_Format( +#if PY_MAJOR_VERSION < 3 + PyExc_ImportError, + "No module named '%s'", PyString_AS_STRING(partial_name)); +#else +#if PY_VERSION_HEX >= 0x030600B1 + PyExc_ModuleNotFoundError, +#else + PyExc_ImportError, +#endif + "No module named '%U'", partial_name); +#endif +bad: + Py_XDECREF(sep); + Py_XDECREF(slice); + Py_XDECREF(partial_name); + return NULL; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Lookup(PyObject *name) { + PyObject *imported_module; +#if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + return NULL; + imported_module = __Pyx_PyDict_GetItemStr(modules, name); + Py_XINCREF(imported_module); +#else + imported_module = PyImport_GetModule(name); +#endif + return imported_module; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple) { + Py_ssize_t i, nparts; + nparts = PyTuple_GET_SIZE(parts_tuple); + for (i=1; i < nparts && module; i++) { + PyObject *part, *submodule; +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + part = PyTuple_GET_ITEM(parts_tuple, i); +#else + part = PySequence_ITEM(parts_tuple, i); +#endif + submodule = __Pyx_PyObject_GetAttrStrNoError(module, part); +#if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(part); +#endif + Py_DECREF(module); + module = submodule; + } + if (unlikely(!module)) { + return __Pyx__ImportDottedModule_Error(name, parts_tuple, i); + } + return module; +} +#endif +static PyObject *__Pyx__ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if PY_MAJOR_VERSION < 3 + PyObject *module, *from_list, *star = __pyx_n_s__3; + CYTHON_UNUSED_VAR(parts_tuple); + from_list = PyList_New(1); + if (unlikely(!from_list)) + return NULL; + Py_INCREF(star); + PyList_SET_ITEM(from_list, 0, star); + module = __Pyx_Import(name, from_list, 0); + Py_DECREF(from_list); + return module; +#else + PyObject *imported_module; + PyObject *module = __Pyx_Import(name, NULL, 0); + if (!parts_tuple || unlikely(!module)) + return module; + imported_module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(imported_module)) { + Py_DECREF(module); + return imported_module; + } + PyErr_Clear(); + return __Pyx_ImportDottedModule_WalkParts(module, name, parts_tuple); +#endif +} +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030400B1 + PyObject *module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(module)) { + PyObject *spec = __Pyx_PyObject_GetAttrStrNoError(module, __pyx_n_s_spec); + if (likely(spec)) { + PyObject *unsafe = __Pyx_PyObject_GetAttrStrNoError(spec, __pyx_n_s_initializing); + if (likely(!unsafe || !__Pyx_PyObject_IsTrue(unsafe))) { + Py_DECREF(spec); + spec = NULL; + } + Py_XDECREF(unsafe); + } + if (likely(!spec)) { + PyErr_Clear(); + return module; + } + Py_DECREF(spec); + Py_DECREF(module); + } else if (PyErr_Occurred()) { + PyErr_Clear(); + } +#endif + return __Pyx__ImportDottedModule(name, parts_tuple); +} + +/* FastTypeChecks */ +#if CYTHON_COMPILING_IN_CPYTHON +static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) { + while (a) { + a = __Pyx_PyType_GetSlot(a, tp_base, PyTypeObject*); + if (a == b) + return 1; + } + return b == &PyBaseObject_Type; +} +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (a == b) return 1; + mro = a->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(a, b); +} +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (cls == a || cls == b) return 1; + mro = cls->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + PyObject *base = PyTuple_GET_ITEM(mro, i); + if (base == (PyObject *)a || base == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(cls, a) || __Pyx_InBases(cls, b); +} +#if PY_MAJOR_VERSION == 2 +static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) { + PyObject *exception, *value, *tb; + int res; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&exception, &value, &tb); + res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0; + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + if (!res) { + res = PyObject_IsSubclass(err, exc_type2); + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + } + __Pyx_ErrRestore(exception, value, tb); + return res; +} +#else +static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) { + if (exc_type1) { + return __Pyx_IsAnySubtype2((PyTypeObject*)err, (PyTypeObject*)exc_type1, (PyTypeObject*)exc_type2); + } else { + return __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2); + } +} +#endif +static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + assert(PyExceptionClass_Check(exc_type)); + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; itp_as_sequence && type->tp_as_sequence->sq_repeat)) { + return type->tp_as_sequence->sq_repeat(seq, mul); + } else +#endif + { + return __Pyx_PySequence_Multiply_Generic(seq, mul); + } +} + +/* SetItemInt */ +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v) { + int r; + if (unlikely(!j)) return -1; + r = PyObject_SetItem(o, j, v); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, int is_list, + CYTHON_NCP_UNUSED int wraparound, CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = (!wraparound) ? i : ((likely(i >= 0)) ? i : i + PyList_GET_SIZE(o)); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o)))) { + PyObject* old = PyList_GET_ITEM(o, n); + Py_INCREF(v); + PyList_SET_ITEM(o, n, v); + Py_DECREF(old); + return 1; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_ass_subscript) { + int r; + PyObject *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return -1; + r = mm->mp_ass_subscript(o, key, v); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_ass_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return -1; + PyErr_Clear(); + } + } + return sm->sq_ass_item(o, i, v); + } + } +#else + if (is_list || !PyMapping_Check(o)) + { + return PySequence_SetItem(o, i, v); + } +#endif + return __Pyx_SetItemInt_Generic(o, PyInt_FromSsize_t(i), v); +} + +/* RaiseUnboundLocalError */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname) { + PyErr_Format(PyExc_UnboundLocalError, "local variable '%s' referenced before assignment", varname); +} + +/* DivInt[long] */ +static CYTHON_INLINE long __Pyx_div_long(long a, long b) { + long q = a / b; + long r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* ImportFrom */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name) { + PyObject* value = __Pyx_PyObject_GetAttrStr(module, name); + if (unlikely(!value) && PyErr_ExceptionMatches(PyExc_AttributeError)) { + const char* module_name_str = 0; + PyObject* module_name = 0; + PyObject* module_dot = 0; + PyObject* full_name = 0; + PyErr_Clear(); + module_name_str = PyModule_GetName(module); + if (unlikely(!module_name_str)) { goto modbad; } + module_name = PyUnicode_FromString(module_name_str); + if (unlikely(!module_name)) { goto modbad; } + module_dot = PyUnicode_Concat(module_name, __pyx_kp_u__2); + if (unlikely(!module_dot)) { goto modbad; } + full_name = PyUnicode_Concat(module_dot, name); + if (unlikely(!full_name)) { goto modbad; } + #if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + { + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + goto modbad; + value = PyObject_GetItem(modules, full_name); + } + #else + value = PyImport_GetModule(full_name); + #endif + modbad: + Py_XDECREF(full_name); + Py_XDECREF(module_dot); + Py_XDECREF(module_name); + } + if (unlikely(!value)) { + PyErr_Format(PyExc_ImportError, + #if PY_MAJOR_VERSION < 3 + "cannot import name %.230s", PyString_AS_STRING(name)); + #else + "cannot import name %S", name); + #endif + } + return value; +} + +/* HasAttr */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { + PyObject *r; + if (unlikely(!__Pyx_PyBaseString_Check(n))) { + PyErr_SetString(PyExc_TypeError, + "hasattr(): attribute name must be string"); + return -1; + } + r = __Pyx_GetAttr(o, n); + if (!r) { + PyErr_Clear(); + return 0; + } else { + Py_DECREF(r); + return 1; + } +} +#endif + +/* IsLittleEndian */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void) +{ + union { + uint32_t u32; + uint8_t u8[4]; + } S; + S.u32 = 0x01020304; + return S.u8[0] == 4; +} + +/* BufferFormatCheck */ +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type) { + stack[0].field = &ctx->root; + stack[0].parent_offset = 0; + ctx->root.type = type; + ctx->root.name = "buffer dtype"; + ctx->root.offset = 0; + ctx->head = stack; + ctx->head->field = &ctx->root; + ctx->fmt_offset = 0; + ctx->head->parent_offset = 0; + ctx->new_packmode = '@'; + ctx->enc_packmode = '@'; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->is_complex = 0; + ctx->is_valid_array = 0; + ctx->struct_alignment = 0; + while (type->typegroup == 'S') { + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = 0; + type = type->fields->type; + } +} +static int __Pyx_BufFmt_ParseNumber(const char** ts) { + int count; + const char* t = *ts; + if (*t < '0' || *t > '9') { + return -1; + } else { + count = *t++ - '0'; + while (*t >= '0' && *t <= '9') { + count *= 10; + count += *t++ - '0'; + } + } + *ts = t; + return count; +} +static int __Pyx_BufFmt_ExpectNumber(const char **ts) { + int number = __Pyx_BufFmt_ParseNumber(ts); + if (number == -1) + PyErr_Format(PyExc_ValueError,\ + "Does not understand character buffer dtype format string ('%c')", **ts); + return number; +} +static void __Pyx_BufFmt_RaiseUnexpectedChar(char ch) { + PyErr_Format(PyExc_ValueError, + "Unexpected format string character: '%c'", ch); +} +static const char* __Pyx_BufFmt_DescribeTypeChar(char ch, int is_complex) { + switch (ch) { + case '?': return "'bool'"; + case 'c': return "'char'"; + case 'b': return "'signed char'"; + case 'B': return "'unsigned char'"; + case 'h': return "'short'"; + case 'H': return "'unsigned short'"; + case 'i': return "'int'"; + case 'I': return "'unsigned int'"; + case 'l': return "'long'"; + case 'L': return "'unsigned long'"; + case 'q': return "'long long'"; + case 'Q': return "'unsigned long long'"; + case 'f': return (is_complex ? "'complex float'" : "'float'"); + case 'd': return (is_complex ? "'complex double'" : "'double'"); + case 'g': return (is_complex ? "'complex long double'" : "'long double'"); + case 'T': return "a struct"; + case 'O': return "Python object"; + case 'P': return "a pointer"; + case 's': case 'p': return "a string"; + case 0: return "end"; + default: return "unparsable format string"; + } +} +static size_t __Pyx_BufFmt_TypeCharToStandardSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return 2; + case 'i': case 'I': case 'l': case 'L': return 4; + case 'q': case 'Q': return 8; + case 'f': return (is_complex ? 8 : 4); + case 'd': return (is_complex ? 16 : 8); + case 'g': { + PyErr_SetString(PyExc_ValueError, "Python does not define a standard format string size for long double ('g').."); + return 0; + } + case 'O': case 'P': return sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static size_t __Pyx_BufFmt_TypeCharToNativeSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(short); + case 'i': case 'I': return sizeof(int); + case 'l': case 'L': return sizeof(long); + #ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(PY_LONG_LONG); + #endif + case 'f': return sizeof(float) * (is_complex ? 2 : 1); + case 'd': return sizeof(double) * (is_complex ? 2 : 1); + case 'g': return sizeof(long double) * (is_complex ? 2 : 1); + case 'O': case 'P': return sizeof(void*); + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +typedef struct { char c; short x; } __Pyx_st_short; +typedef struct { char c; int x; } __Pyx_st_int; +typedef struct { char c; long x; } __Pyx_st_long; +typedef struct { char c; float x; } __Pyx_st_float; +typedef struct { char c; double x; } __Pyx_st_double; +typedef struct { char c; long double x; } __Pyx_st_longdouble; +typedef struct { char c; void *x; } __Pyx_st_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { char c; PY_LONG_LONG x; } __Pyx_st_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToAlignment(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_st_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_st_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_st_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_st_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_st_float) - sizeof(float); + case 'd': return sizeof(__Pyx_st_double) - sizeof(double); + case 'g': return sizeof(__Pyx_st_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_st_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +/* These are for computing the padding at the end of the struct to align + on the first member of the struct. This will probably the same as above, + but we don't have any guarantees. + */ +typedef struct { short x; char c; } __Pyx_pad_short; +typedef struct { int x; char c; } __Pyx_pad_int; +typedef struct { long x; char c; } __Pyx_pad_long; +typedef struct { float x; char c; } __Pyx_pad_float; +typedef struct { double x; char c; } __Pyx_pad_double; +typedef struct { long double x; char c; } __Pyx_pad_longdouble; +typedef struct { void *x; char c; } __Pyx_pad_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { PY_LONG_LONG x; char c; } __Pyx_pad_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToPadding(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_pad_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_pad_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_pad_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_pad_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_pad_float) - sizeof(float); + case 'd': return sizeof(__Pyx_pad_double) - sizeof(double); + case 'g': return sizeof(__Pyx_pad_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_pad_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static char __Pyx_BufFmt_TypeCharToGroup(char ch, int is_complex) { + switch (ch) { + case 'c': + return 'H'; + case 'b': case 'h': case 'i': + case 'l': case 'q': case 's': case 'p': + return 'I'; + case '?': case 'B': case 'H': case 'I': case 'L': case 'Q': + return 'U'; + case 'f': case 'd': case 'g': + return (is_complex ? 'C' : 'R'); + case 'O': + return 'O'; + case 'P': + return 'P'; + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +static void __Pyx_BufFmt_RaiseExpected(__Pyx_BufFmt_Context* ctx) { + if (ctx->head == NULL || ctx->head->field == &ctx->root) { + const char* expected; + const char* quote; + if (ctx->head == NULL) { + expected = "end"; + quote = ""; + } else { + expected = ctx->head->field->type->name; + quote = "'"; + } + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected %s%s%s but got %s", + quote, expected, quote, + __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex)); + } else { + __Pyx_StructField* field = ctx->head->field; + __Pyx_StructField* parent = (ctx->head - 1)->field; + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected '%s' but got %s in '%s.%s'", + field->type->name, __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex), + parent->type->name, field->name); + } +} +static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) { + char group; + size_t size, offset, arraysize = 1; + if (ctx->enc_type == 0) return 0; + if (ctx->head->field->type->arraysize[0]) { + int i, ndim = 0; + if (ctx->enc_type == 's' || ctx->enc_type == 'p') { + ctx->is_valid_array = ctx->head->field->type->ndim == 1; + ndim = 1; + if (ctx->enc_count != ctx->head->field->type->arraysize[0]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %zu", + ctx->head->field->type->arraysize[0], ctx->enc_count); + return -1; + } + } + if (!ctx->is_valid_array) { + PyErr_Format(PyExc_ValueError, "Expected %d dimensions, got %d", + ctx->head->field->type->ndim, ndim); + return -1; + } + for (i = 0; i < ctx->head->field->type->ndim; i++) { + arraysize *= ctx->head->field->type->arraysize[i]; + } + ctx->is_valid_array = 0; + ctx->enc_count = 1; + } + group = __Pyx_BufFmt_TypeCharToGroup(ctx->enc_type, ctx->is_complex); + do { + __Pyx_StructField* field = ctx->head->field; + __Pyx_TypeInfo* type = field->type; + if (ctx->enc_packmode == '@' || ctx->enc_packmode == '^') { + size = __Pyx_BufFmt_TypeCharToNativeSize(ctx->enc_type, ctx->is_complex); + } else { + size = __Pyx_BufFmt_TypeCharToStandardSize(ctx->enc_type, ctx->is_complex); + } + if (ctx->enc_packmode == '@') { + size_t align_at = __Pyx_BufFmt_TypeCharToAlignment(ctx->enc_type, ctx->is_complex); + size_t align_mod_offset; + if (align_at == 0) return -1; + align_mod_offset = ctx->fmt_offset % align_at; + if (align_mod_offset > 0) ctx->fmt_offset += align_at - align_mod_offset; + if (ctx->struct_alignment == 0) + ctx->struct_alignment = __Pyx_BufFmt_TypeCharToPadding(ctx->enc_type, + ctx->is_complex); + } + if (type->size != size || type->typegroup != group) { + if (type->typegroup == 'C' && type->fields != NULL) { + size_t parent_offset = ctx->head->parent_offset + field->offset; + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = parent_offset; + continue; + } + if ((type->typegroup == 'H' || group == 'H') && type->size == size) { + } else { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + } + offset = ctx->head->parent_offset + field->offset; + if (ctx->fmt_offset != offset) { + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch; next field is at offset %" CYTHON_FORMAT_SSIZE_T "d but %" CYTHON_FORMAT_SSIZE_T "d expected", + (Py_ssize_t)ctx->fmt_offset, (Py_ssize_t)offset); + return -1; + } + ctx->fmt_offset += size; + if (arraysize) + ctx->fmt_offset += (arraysize - 1) * size; + --ctx->enc_count; + while (1) { + if (field == &ctx->root) { + ctx->head = NULL; + if (ctx->enc_count != 0) { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + break; + } + ctx->head->field = ++field; + if (field->type == NULL) { + --ctx->head; + field = ctx->head->field; + continue; + } else if (field->type->typegroup == 'S') { + size_t parent_offset = ctx->head->parent_offset + field->offset; + if (field->type->fields->type == NULL) continue; + field = field->type->fields; + ++ctx->head; + ctx->head->field = field; + ctx->head->parent_offset = parent_offset; + break; + } else { + break; + } + } + } while (ctx->enc_count); + ctx->enc_type = 0; + ctx->is_complex = 0; + return 0; +} +static int +__pyx_buffmt_parse_array(__Pyx_BufFmt_Context* ctx, const char** tsp) +{ + const char *ts = *tsp; + int i = 0, number, ndim; + ++ts; + if (ctx->new_count != 1) { + PyErr_SetString(PyExc_ValueError, + "Cannot handle repeated arrays in format string"); + return -1; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return -1; + ndim = ctx->head->field->type->ndim; + while (*ts && *ts != ')') { + switch (*ts) { + case ' ': case '\f': case '\r': case '\n': case '\t': case '\v': continue; + default: break; + } + number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return -1; + if (i < ndim && (size_t) number != ctx->head->field->type->arraysize[i]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %d", + ctx->head->field->type->arraysize[i], number); + return -1; + } + if (*ts != ',' && *ts != ')') { + PyErr_Format(PyExc_ValueError, + "Expected a comma in format string, got '%c'", *ts); + return -1; + } + if (*ts == ',') ts++; + i++; + } + if (i != ndim) { + PyErr_Format(PyExc_ValueError, "Expected %d dimension(s), got %d", + ctx->head->field->type->ndim, i); + return -1; + } + if (!*ts) { + PyErr_SetString(PyExc_ValueError, + "Unexpected end of format string, expected ')'"); + return -1; + } + ctx->is_valid_array = 1; + ctx->new_count = 1; + *tsp = ++ts; + return 0; +} +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts) { + int got_Z = 0; + while (1) { + switch(*ts) { + case 0: + if (ctx->enc_type != 0 && ctx->head == NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + if (ctx->head != NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + return ts; + case ' ': + case '\r': + case '\n': + ++ts; + break; + case '<': + if (!__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Little-endian buffer not supported on big-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '>': + case '!': + if (__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Big-endian buffer not supported on little-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '=': + case '@': + case '^': + ctx->new_packmode = *ts++; + break; + case 'T': + { + const char* ts_after_sub; + size_t i, struct_count = ctx->new_count; + size_t struct_alignment = ctx->struct_alignment; + ctx->new_count = 1; + ++ts; + if (*ts != '{') { + PyErr_SetString(PyExc_ValueError, "Buffer acquisition: Expected '{' after 'T'"); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + ctx->enc_count = 0; + ctx->struct_alignment = 0; + ++ts; + ts_after_sub = ts; + for (i = 0; i != struct_count; ++i) { + ts_after_sub = __Pyx_BufFmt_CheckString(ctx, ts); + if (!ts_after_sub) return NULL; + } + ts = ts_after_sub; + if (struct_alignment) ctx->struct_alignment = struct_alignment; + } + break; + case '}': + { + size_t alignment = ctx->struct_alignment; + ++ts; + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + if (alignment && ctx->fmt_offset % alignment) { + ctx->fmt_offset += alignment - (ctx->fmt_offset % alignment); + } + } + return ts; + case 'x': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->fmt_offset += ctx->new_count; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->enc_packmode = ctx->new_packmode; + ++ts; + break; + case 'Z': + got_Z = 1; + ++ts; + if (*ts != 'f' && *ts != 'd' && *ts != 'g') { + __Pyx_BufFmt_RaiseUnexpectedChar('Z'); + return NULL; + } + CYTHON_FALLTHROUGH; + case '?': case 'c': case 'b': case 'B': case 'h': case 'H': case 'i': case 'I': + case 'l': case 'L': case 'q': case 'Q': + case 'f': case 'd': case 'g': + case 'O': case 'p': + if ((ctx->enc_type == *ts) && (got_Z == ctx->is_complex) && + (ctx->enc_packmode == ctx->new_packmode) && (!ctx->is_valid_array)) { + ctx->enc_count += ctx->new_count; + ctx->new_count = 1; + got_Z = 0; + ++ts; + break; + } + CYTHON_FALLTHROUGH; + case 's': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_count = ctx->new_count; + ctx->enc_packmode = ctx->new_packmode; + ctx->enc_type = *ts; + ctx->is_complex = got_Z; + ++ts; + ctx->new_count = 1; + got_Z = 0; + break; + case ':': + ++ts; + while(*ts != ':') ++ts; + ++ts; + break; + case '(': + if (__pyx_buffmt_parse_array(ctx, &ts) < 0) return NULL; + break; + default: + { + int number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return NULL; + ctx->new_count = (size_t)number; + } + } + } +} + +/* BufferGetAndValidate */ + static CYTHON_INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) { + if (unlikely(info->buf == NULL)) return; + if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL; + __Pyx_ReleaseBuffer(info); +} +static void __Pyx_ZeroBuffer(Py_buffer* buf) { + buf->buf = NULL; + buf->obj = NULL; + buf->strides = __Pyx_zeros; + buf->shape = __Pyx_zeros; + buf->suboffsets = __Pyx_minusones; +} +static int __Pyx__GetBufferAndValidate( + Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, + int nd, int cast, __Pyx_BufFmt_StackElem* stack) +{ + buf->buf = NULL; + if (unlikely(__Pyx_GetBuffer(obj, buf, flags) == -1)) { + __Pyx_ZeroBuffer(buf); + return -1; + } + if (unlikely(buf->ndim != nd)) { + PyErr_Format(PyExc_ValueError, + "Buffer has wrong number of dimensions (expected %d, got %d)", + nd, buf->ndim); + goto fail; + } + if (!cast) { + __Pyx_BufFmt_Context ctx; + __Pyx_BufFmt_Init(&ctx, stack, dtype); + if (!__Pyx_BufFmt_CheckString(&ctx, buf->format)) goto fail; + } + if (unlikely((size_t)buf->itemsize != dtype->size)) { + PyErr_Format(PyExc_ValueError, + "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "d byte%s) does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "d byte%s)", + buf->itemsize, (buf->itemsize > 1) ? "s" : "", + dtype->name, (Py_ssize_t)dtype->size, (dtype->size > 1) ? "s" : ""); + goto fail; + } + if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; + return 0; +fail:; + __Pyx_SafeReleaseBuffer(buf); + return -1; +} + +/* BufferFallbackError */ + static void __Pyx_RaiseBufferFallbackError(void) { + PyErr_SetString(PyExc_ValueError, + "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!"); +} + +/* PyIntBinop */ + #if !CYTHON_COMPILING_IN_PYPY +static PyObject* __Pyx_PyInt_SubtractObjC(PyObject *op1, PyObject *op2, long intval, int inplace, int zerodivision_check) { + CYTHON_MAYBE_UNUSED_VAR(intval); + CYTHON_MAYBE_UNUSED_VAR(inplace); + CYTHON_UNUSED_VAR(zerodivision_check); + #if PY_MAJOR_VERSION < 3 + if (likely(PyInt_CheckExact(op1))) { + const long b = intval; + long x; + long a = PyInt_AS_LONG(op1); + + x = (long)((unsigned long)a - (unsigned long)b); + if (likely((x^a) >= 0 || (x^~b) >= 0)) + return PyInt_FromLong(x); + return PyLong_Type.tp_as_number->nb_subtract(op1, op2); + } + #endif + #if CYTHON_USE_PYLONG_INTERNALS + if (likely(PyLong_CheckExact(op1))) { + const long b = intval; + long a, x; +#ifdef HAVE_LONG_LONG + const PY_LONG_LONG llb = intval; + PY_LONG_LONG lla, llx; +#endif + if (unlikely(__Pyx_PyLong_IsZero(op1))) { + return PyLong_FromLong(-intval); + } + if (likely(__Pyx_PyLong_IsCompact(op1))) { + a = __Pyx_PyLong_CompactValue(op1); + } else { + const digit* digits = __Pyx_PyLong_Digits(op1); + const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(op1); + switch (size) { + case -2: + if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { + a = -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { + lla = -(PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + case 2: + if (8 * sizeof(long) - 1 > 2 * PyLong_SHIFT) { + a = (long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 2 * PyLong_SHIFT) { + lla = (PY_LONG_LONG) (((((unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + case -3: + if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { + a = -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { + lla = -(PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + case 3: + if (8 * sizeof(long) - 1 > 3 * PyLong_SHIFT) { + a = (long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 3 * PyLong_SHIFT) { + lla = (PY_LONG_LONG) (((((((unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + case -4: + if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { + a = -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { + lla = -(PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + case 4: + if (8 * sizeof(long) - 1 > 4 * PyLong_SHIFT) { + a = (long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0])); + break; + #ifdef HAVE_LONG_LONG + } else if (8 * sizeof(PY_LONG_LONG) - 1 > 4 * PyLong_SHIFT) { + lla = (PY_LONG_LONG) (((((((((unsigned PY_LONG_LONG)digits[3]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[2]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[1]) << PyLong_SHIFT) | (unsigned PY_LONG_LONG)digits[0])); + goto long_long; + #endif + } + CYTHON_FALLTHROUGH; + default: return PyLong_Type.tp_as_number->nb_subtract(op1, op2); + } + } + x = a - b; + return PyLong_FromLong(x); +#ifdef HAVE_LONG_LONG + long_long: + llx = lla - llb; + return PyLong_FromLongLong(llx); +#endif + + + } + #endif + if (PyFloat_CheckExact(op1)) { + const long b = intval; +#if CYTHON_COMPILING_IN_LIMITED_API + double a = __pyx_PyFloat_AsDouble(op1); +#else + double a = PyFloat_AS_DOUBLE(op1); +#endif + double result; + + PyFPE_START_PROTECT("subtract", return NULL) + result = ((double)a) - (double)b; + PyFPE_END_PROTECT(result) + return PyFloat_FromDouble(result); + } + return (inplace ? PyNumber_InPlaceSubtract : PyNumber_Subtract)(op1, op2); +} +#endif + +/* SliceObject */ + static CYTHON_INLINE PyObject* __Pyx_PyObject_GetSlice(PyObject* obj, + Py_ssize_t cstart, Py_ssize_t cstop, + PyObject** _py_start, PyObject** _py_stop, PyObject** _py_slice, + int has_cstart, int has_cstop, int wraparound) { + __Pyx_TypeName obj_type_name; +#if CYTHON_USE_TYPE_SLOTS + PyMappingMethods* mp; +#if PY_MAJOR_VERSION < 3 + PySequenceMethods* ms = Py_TYPE(obj)->tp_as_sequence; + if (likely(ms && ms->sq_slice)) { + if (!has_cstart) { + if (_py_start && (*_py_start != Py_None)) { + cstart = __Pyx_PyIndex_AsSsize_t(*_py_start); + if ((cstart == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad; + } else + cstart = 0; + } + if (!has_cstop) { + if (_py_stop && (*_py_stop != Py_None)) { + cstop = __Pyx_PyIndex_AsSsize_t(*_py_stop); + if ((cstop == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad; + } else + cstop = PY_SSIZE_T_MAX; + } + if (wraparound && unlikely((cstart < 0) | (cstop < 0)) && likely(ms->sq_length)) { + Py_ssize_t l = ms->sq_length(obj); + if (likely(l >= 0)) { + if (cstop < 0) { + cstop += l; + if (cstop < 0) cstop = 0; + } + if (cstart < 0) { + cstart += l; + if (cstart < 0) cstart = 0; + } + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + goto bad; + PyErr_Clear(); + } + } + return ms->sq_slice(obj, cstart, cstop); + } +#else + CYTHON_UNUSED_VAR(wraparound); +#endif + mp = Py_TYPE(obj)->tp_as_mapping; + if (likely(mp && mp->mp_subscript)) +#else + CYTHON_UNUSED_VAR(wraparound); +#endif + { + PyObject* result; + PyObject *py_slice, *py_start, *py_stop; + if (_py_slice) { + py_slice = *_py_slice; + } else { + PyObject* owned_start = NULL; + PyObject* owned_stop = NULL; + if (_py_start) { + py_start = *_py_start; + } else { + if (has_cstart) { + owned_start = py_start = PyInt_FromSsize_t(cstart); + if (unlikely(!py_start)) goto bad; + } else + py_start = Py_None; + } + if (_py_stop) { + py_stop = *_py_stop; + } else { + if (has_cstop) { + owned_stop = py_stop = PyInt_FromSsize_t(cstop); + if (unlikely(!py_stop)) { + Py_XDECREF(owned_start); + goto bad; + } + } else + py_stop = Py_None; + } + py_slice = PySlice_New(py_start, py_stop, Py_None); + Py_XDECREF(owned_start); + Py_XDECREF(owned_stop); + if (unlikely(!py_slice)) goto bad; + } +#if CYTHON_USE_TYPE_SLOTS + result = mp->mp_subscript(obj, py_slice); +#else + result = PyObject_GetItem(obj, py_slice); +#endif + if (!_py_slice) { + Py_DECREF(py_slice); + } + return result; + } + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' object is unsliceable", obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); +bad: + return NULL; +} + +/* PyObject_GenericGetAttrNoDict */ + #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject *__Pyx_RaiseGenericGetAttributeError(PyTypeObject *tp, PyObject *attr_name) { + __Pyx_TypeName type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, attr_name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(attr_name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name) { + PyObject *descr; + PyTypeObject *tp = Py_TYPE(obj); + if (unlikely(!PyString_Check(attr_name))) { + return PyObject_GenericGetAttr(obj, attr_name); + } + assert(!tp->tp_dictoffset); + descr = _PyType_Lookup(tp, attr_name); + if (unlikely(!descr)) { + return __Pyx_RaiseGenericGetAttributeError(tp, attr_name); + } + Py_INCREF(descr); + #if PY_MAJOR_VERSION < 3 + if (likely(PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_HAVE_CLASS))) + #endif + { + descrgetfunc f = Py_TYPE(descr)->tp_descr_get; + if (unlikely(f)) { + PyObject *res = f(descr, obj, (PyObject *)tp); + Py_DECREF(descr); + return res; + } + } + return descr; +} +#endif + +/* PyObject_GenericGetAttr */ + #if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name) { + if (unlikely(Py_TYPE(obj)->tp_dictoffset)) { + return PyObject_GenericGetAttr(obj, attr_name); + } + return __Pyx_PyObject_GenericGetAttrNoDict(obj, attr_name); +} +#endif + +/* FixUpExtensionType */ + #if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type) { +#if PY_VERSION_HEX > 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + CYTHON_UNUSED_VAR(spec); + CYTHON_UNUSED_VAR(type); +#else + const PyType_Slot *slot = spec->slots; + while (slot && slot->slot && slot->slot != Py_tp_members) + slot++; + if (slot && slot->slot == Py_tp_members) { + int changed = 0; +#if !(PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON) + const +#endif + PyMemberDef *memb = (PyMemberDef*) slot->pfunc; + while (memb && memb->name) { + if (memb->name[0] == '_' && memb->name[1] == '_') { +#if PY_VERSION_HEX < 0x030900b1 + if (strcmp(memb->name, "__weaklistoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_weaklistoffset = memb->offset; + changed = 1; + } + else if (strcmp(memb->name, "__dictoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_dictoffset = memb->offset; + changed = 1; + } +#if CYTHON_METH_FASTCALL + else if (strcmp(memb->name, "__vectorcalloffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); +#if PY_VERSION_HEX >= 0x030800b4 + type->tp_vectorcall_offset = memb->offset; +#else + type->tp_print = (printfunc) memb->offset; +#endif + changed = 1; + } +#endif +#else + if ((0)); +#endif +#if PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON + else if (strcmp(memb->name, "__module__") == 0) { + PyObject *descr; + assert(memb->type == T_OBJECT); + assert(memb->flags == 0 || memb->flags == READONLY); + descr = PyDescr_NewMember(type, memb); + if (unlikely(!descr)) + return -1; + if (unlikely(PyDict_SetItem(type->tp_dict, PyDescr_NAME(descr), descr) < 0)) { + Py_DECREF(descr); + return -1; + } + Py_DECREF(descr); + changed = 1; + } +#endif + } + memb++; + } + if (changed) + PyType_Modified(type); + } +#endif + return 0; +} +#endif + +/* PyObjectCallNoArg */ + static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { + PyObject *arg[2] = {NULL, NULL}; + return __Pyx_PyObject_FastCall(func, arg + 1, 0 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* PyObjectGetMethod */ + static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method) { + PyObject *attr; +#if CYTHON_UNPACK_METHODS && CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_PYTYPE_LOOKUP + __Pyx_TypeName type_name; + PyTypeObject *tp = Py_TYPE(obj); + PyObject *descr; + descrgetfunc f = NULL; + PyObject **dictptr, *dict; + int meth_found = 0; + assert (*method == NULL); + if (unlikely(tp->tp_getattro != PyObject_GenericGetAttr)) { + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; + } + if (unlikely(tp->tp_dict == NULL) && unlikely(PyType_Ready(tp) < 0)) { + return 0; + } + descr = _PyType_Lookup(tp, name); + if (likely(descr != NULL)) { + Py_INCREF(descr); +#if defined(Py_TPFLAGS_METHOD_DESCRIPTOR) && Py_TPFLAGS_METHOD_DESCRIPTOR + if (__Pyx_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR)) +#elif PY_MAJOR_VERSION >= 3 + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type))) + #endif +#else + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr))) + #endif +#endif + { + meth_found = 1; + } else { + f = Py_TYPE(descr)->tp_descr_get; + if (f != NULL && PyDescr_IsData(descr)) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + } + } + dictptr = _PyObject_GetDictPtr(obj); + if (dictptr != NULL && (dict = *dictptr) != NULL) { + Py_INCREF(dict); + attr = __Pyx_PyDict_GetItemStr(dict, name); + if (attr != NULL) { + Py_INCREF(attr); + Py_DECREF(dict); + Py_XDECREF(descr); + goto try_unpack; + } + Py_DECREF(dict); + } + if (meth_found) { + *method = descr; + return 1; + } + if (f != NULL) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + if (likely(descr != NULL)) { + *method = descr; + return 0; + } + type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return 0; +#else + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; +#endif +try_unpack: +#if CYTHON_UNPACK_METHODS + if (likely(attr) && PyMethod_Check(attr) && likely(PyMethod_GET_SELF(attr) == obj)) { + PyObject *function = PyMethod_GET_FUNCTION(attr); + Py_INCREF(function); + Py_DECREF(attr); + *method = function; + return 1; + } +#endif + *method = attr; + return 0; +} + +/* PyObjectCallMethod0 */ + static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name) { + PyObject *method = NULL, *result = NULL; + int is_method = __Pyx_PyObject_GetMethod(obj, method_name, &method); + if (likely(is_method)) { + result = __Pyx_PyObject_CallOneArg(method, obj); + Py_DECREF(method); + return result; + } + if (unlikely(!method)) goto bad; + result = __Pyx_PyObject_CallNoArg(method); + Py_DECREF(method); +bad: + return result; +} + +/* ValidateBasesTuple */ + #if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases) { + Py_ssize_t i, n; +#if CYTHON_ASSUME_SAFE_MACROS + n = PyTuple_GET_SIZE(bases); +#else + n = PyTuple_Size(bases); + if (n < 0) return -1; +#endif + for (i = 1; i < n; i++) + { +#if CYTHON_AVOID_BORROWED_REFS + PyObject *b0 = PySequence_GetItem(bases, i); + if (!b0) return -1; +#elif CYTHON_ASSUME_SAFE_MACROS + PyObject *b0 = PyTuple_GET_ITEM(bases, i); +#else + PyObject *b0 = PyTuple_GetItem(bases, i); + if (!b0) return -1; +#endif + PyTypeObject *b; +#if PY_MAJOR_VERSION < 3 + if (PyClass_Check(b0)) + { + PyErr_Format(PyExc_TypeError, "base class '%.200s' is an old-style class", + PyString_AS_STRING(((PyClassObject*)b0)->cl_name)); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } +#endif + b = (PyTypeObject*) b0; + if (!__Pyx_PyType_HasFeature(b, Py_TPFLAGS_HEAPTYPE)) + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "base class '" __Pyx_FMT_TYPENAME "' is not a heap type", b_name); + __Pyx_DECREF_TypeName(b_name); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + if (dictoffset == 0) + { + Py_ssize_t b_dictoffset = 0; +#if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + b_dictoffset = b->tp_dictoffset; +#else + PyObject *py_b_dictoffset = PyObject_GetAttrString((PyObject*)b, "__dictoffset__"); + if (!py_b_dictoffset) goto dictoffset_return; + b_dictoffset = PyLong_AsSsize_t(py_b_dictoffset); + Py_DECREF(py_b_dictoffset); + if (b_dictoffset == -1 && PyErr_Occurred()) goto dictoffset_return; +#endif + if (b_dictoffset) { + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "extension type '%.200s' has no __dict__ slot, " + "but base type '" __Pyx_FMT_TYPENAME "' has: " + "either add 'cdef dict __dict__' to the extension type " + "or add '__slots__ = [...]' to the base type", + type_name, b_name); + __Pyx_DECREF_TypeName(b_name); + } +#if !(CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY) + dictoffset_return: +#endif +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + } +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + } + return 0; +} +#endif + +/* PyType_Ready */ + static int __Pyx_PyType_Ready(PyTypeObject *t) { +#if CYTHON_USE_TYPE_SPECS || !(CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API) || defined(PYSTON_MAJOR_VERSION) + (void)__Pyx_PyObject_CallMethod0; +#if CYTHON_USE_TYPE_SPECS + (void)__Pyx_validate_bases_tuple; +#endif + return PyType_Ready(t); +#else + int r; + PyObject *bases = __Pyx_PyType_GetSlot(t, tp_bases, PyObject*); + if (bases && unlikely(__Pyx_validate_bases_tuple(t->tp_name, t->tp_dictoffset, bases) == -1)) + return -1; +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + { + int gc_was_enabled; + #if PY_VERSION_HEX >= 0x030A00b1 + gc_was_enabled = PyGC_Disable(); + (void)__Pyx_PyObject_CallMethod0; + #else + PyObject *ret, *py_status; + PyObject *gc = NULL; + #if PY_VERSION_HEX >= 0x030700a1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM+0 >= 0x07030400) + gc = PyImport_GetModule(__pyx_kp_u_gc); + #endif + if (unlikely(!gc)) gc = PyImport_Import(__pyx_kp_u_gc); + if (unlikely(!gc)) return -1; + py_status = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_isenabled); + if (unlikely(!py_status)) { + Py_DECREF(gc); + return -1; + } + gc_was_enabled = __Pyx_PyObject_IsTrue(py_status); + Py_DECREF(py_status); + if (gc_was_enabled > 0) { + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_disable); + if (unlikely(!ret)) { + Py_DECREF(gc); + return -1; + } + Py_DECREF(ret); + } else if (unlikely(gc_was_enabled == -1)) { + Py_DECREF(gc); + return -1; + } + #endif + t->tp_flags |= Py_TPFLAGS_HEAPTYPE; +#if PY_VERSION_HEX >= 0x030A0000 + t->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE; +#endif +#else + (void)__Pyx_PyObject_CallMethod0; +#endif + r = PyType_Ready(t); +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + t->tp_flags &= ~Py_TPFLAGS_HEAPTYPE; + #if PY_VERSION_HEX >= 0x030A00b1 + if (gc_was_enabled) + PyGC_Enable(); + #else + if (gc_was_enabled) { + PyObject *tp, *v, *tb; + PyErr_Fetch(&tp, &v, &tb); + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_enable); + if (likely(ret || r == -1)) { + Py_XDECREF(ret); + PyErr_Restore(tp, v, tb); + } else { + Py_XDECREF(tp); + Py_XDECREF(v); + Py_XDECREF(tb); + r = -1; + } + } + Py_DECREF(gc); + #endif + } +#endif + return r; +#endif +} + +/* SetVTable */ + static int __Pyx_SetVtable(PyTypeObject *type, void *vtable) { + PyObject *ob = PyCapsule_New(vtable, 0, 0); + if (unlikely(!ob)) + goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(PyObject_SetAttr((PyObject *) type, __pyx_n_s_pyx_vtable, ob) < 0)) +#else + if (unlikely(PyDict_SetItem(type->tp_dict, __pyx_n_s_pyx_vtable, ob) < 0)) +#endif + goto bad; + Py_DECREF(ob); + return 0; +bad: + Py_XDECREF(ob); + return -1; +} + +/* GetVTable */ + static void* __Pyx_GetVtable(PyTypeObject *type) { + void* ptr; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *ob = PyObject_GetAttr((PyObject *)type, __pyx_n_s_pyx_vtable); +#else + PyObject *ob = PyObject_GetItem(type->tp_dict, __pyx_n_s_pyx_vtable); +#endif + if (!ob) + goto bad; + ptr = PyCapsule_GetPointer(ob, 0); + if (!ptr && !PyErr_Occurred()) + PyErr_SetString(PyExc_RuntimeError, "invalid vtable found for imported type"); + Py_DECREF(ob); + return ptr; +bad: + Py_XDECREF(ob); + return NULL; +} + +/* MergeVTables */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type) { + int i; + void** base_vtables; + __Pyx_TypeName tp_base_name; + __Pyx_TypeName base_name; + void* unknown = (void*)-1; + PyObject* bases = type->tp_bases; + int base_depth = 0; + { + PyTypeObject* base = type->tp_base; + while (base) { + base_depth += 1; + base = base->tp_base; + } + } + base_vtables = (void**) malloc(sizeof(void*) * (size_t)(base_depth + 1)); + base_vtables[0] = unknown; + for (i = 1; i < PyTuple_GET_SIZE(bases); i++) { + void* base_vtable = __Pyx_GetVtable(((PyTypeObject*)PyTuple_GET_ITEM(bases, i))); + if (base_vtable != NULL) { + int j; + PyTypeObject* base = type->tp_base; + for (j = 0; j < base_depth; j++) { + if (base_vtables[j] == unknown) { + base_vtables[j] = __Pyx_GetVtable(base); + base_vtables[j + 1] = unknown; + } + if (base_vtables[j] == base_vtable) { + break; + } else if (base_vtables[j] == NULL) { + goto bad; + } + base = base->tp_base; + } + } + } + PyErr_Clear(); + free(base_vtables); + return 0; +bad: + tp_base_name = __Pyx_PyType_GetName(type->tp_base); + base_name = __Pyx_PyType_GetName((PyTypeObject*)PyTuple_GET_ITEM(bases, i)); + PyErr_Format(PyExc_TypeError, + "multiple bases have vtable conflict: '" __Pyx_FMT_TYPENAME "' and '" __Pyx_FMT_TYPENAME "'", tp_base_name, base_name); + __Pyx_DECREF_TypeName(tp_base_name); + __Pyx_DECREF_TypeName(base_name); + free(base_vtables); + return -1; +} +#endif + +/* SetupReduce */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce_is_named(PyObject* meth, PyObject* name) { + int ret; + PyObject *name_attr; + name_attr = __Pyx_PyObject_GetAttrStrNoError(meth, __pyx_n_s_name_2); + if (likely(name_attr)) { + ret = PyObject_RichCompareBool(name_attr, name, Py_EQ); + } else { + ret = -1; + } + if (unlikely(ret < 0)) { + PyErr_Clear(); + ret = 0; + } + Py_XDECREF(name_attr); + return ret; +} +static int __Pyx_setup_reduce(PyObject* type_obj) { + int ret = 0; + PyObject *object_reduce = NULL; + PyObject *object_getstate = NULL; + PyObject *object_reduce_ex = NULL; + PyObject *reduce = NULL; + PyObject *reduce_ex = NULL; + PyObject *reduce_cython = NULL; + PyObject *setstate = NULL; + PyObject *setstate_cython = NULL; + PyObject *getstate = NULL; +#if CYTHON_USE_PYTYPE_LOOKUP + getstate = _PyType_Lookup((PyTypeObject*)type_obj, __pyx_n_s_getstate); +#else + getstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_getstate); + if (!getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (getstate) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_getstate = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_getstate); +#else + object_getstate = __Pyx_PyObject_GetAttrStrNoError((PyObject*)&PyBaseObject_Type, __pyx_n_s_getstate); + if (!object_getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (object_getstate != getstate) { + goto __PYX_GOOD; + } + } +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce_ex = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#else + object_reduce_ex = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#endif + reduce_ex = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce_ex); if (unlikely(!reduce_ex)) goto __PYX_BAD; + if (reduce_ex == object_reduce_ex) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#else + object_reduce = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#endif + reduce = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce); if (unlikely(!reduce)) goto __PYX_BAD; + if (reduce == object_reduce || __Pyx_setup_reduce_is_named(reduce, __pyx_n_s_reduce_cython)) { + reduce_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_reduce_cython); + if (likely(reduce_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce, reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (reduce == object_reduce || PyErr_Occurred()) { + goto __PYX_BAD; + } + setstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate); + if (!setstate) PyErr_Clear(); + if (!setstate || __Pyx_setup_reduce_is_named(setstate, __pyx_n_s_setstate_cython)) { + setstate_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate_cython); + if (likely(setstate_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate, setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (!setstate || PyErr_Occurred()) { + goto __PYX_BAD; + } + } + PyType_Modified((PyTypeObject*)type_obj); + } + } + goto __PYX_GOOD; +__PYX_BAD: + if (!PyErr_Occurred()) { + __Pyx_TypeName type_obj_name = + __Pyx_PyType_GetName((PyTypeObject*)type_obj); + PyErr_Format(PyExc_RuntimeError, + "Unable to initialize pickling for " __Pyx_FMT_TYPENAME, type_obj_name); + __Pyx_DECREF_TypeName(type_obj_name); + } + ret = -1; +__PYX_GOOD: +#if !CYTHON_USE_PYTYPE_LOOKUP + Py_XDECREF(object_reduce); + Py_XDECREF(object_reduce_ex); + Py_XDECREF(object_getstate); + Py_XDECREF(getstate); +#endif + Py_XDECREF(reduce); + Py_XDECREF(reduce_ex); + Py_XDECREF(reduce_cython); + Py_XDECREF(setstate); + Py_XDECREF(setstate_cython); + return ret; +} +#endif + +/* TypeImport */ + #ifndef __PYX_HAVE_RT_ImportType_3_0_8 +#define __PYX_HAVE_RT_ImportType_3_0_8 +static PyTypeObject *__Pyx_ImportType_3_0_8(PyObject *module, const char *module_name, const char *class_name, + size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_8 check_size) +{ + PyObject *result = 0; + char warning[200]; + Py_ssize_t basicsize; + Py_ssize_t itemsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + PyObject *py_itemsize; +#endif + result = PyObject_GetAttrString(module, class_name); + if (!result) + goto bad; + if (!PyType_Check(result)) { + PyErr_Format(PyExc_TypeError, + "%.200s.%.200s is not a type object", + module_name, class_name); + goto bad; + } +#if !CYTHON_COMPILING_IN_LIMITED_API + basicsize = ((PyTypeObject *)result)->tp_basicsize; + itemsize = ((PyTypeObject *)result)->tp_itemsize; +#else + py_basicsize = PyObject_GetAttrString(result, "__basicsize__"); + if (!py_basicsize) + goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (basicsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; + py_itemsize = PyObject_GetAttrString(result, "__itemsize__"); + if (!py_itemsize) + goto bad; + itemsize = PyLong_AsSsize_t(py_itemsize); + Py_DECREF(py_itemsize); + py_itemsize = 0; + if (itemsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; +#endif + if (itemsize) { + if (size % alignment) { + alignment = size % alignment; + } + if (itemsize < (Py_ssize_t)alignment) + itemsize = (Py_ssize_t)alignment; + } + if ((size_t)(basicsize + itemsize) < size) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize+itemsize); + goto bad; + } + if (check_size == __Pyx_ImportType_CheckSize_Error_3_0_8 && + ((size_t)basicsize > size || (size_t)(basicsize + itemsize) < size)) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd-%zd from PyObject", + module_name, class_name, size, basicsize, basicsize+itemsize); + goto bad; + } + else if (check_size == __Pyx_ImportType_CheckSize_Warn_3_0_8 && (size_t)basicsize > size) { + PyOS_snprintf(warning, sizeof(warning), + "%s.%s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize); + if (PyErr_WarnEx(NULL, warning, 0) < 0) goto bad; + } + return (PyTypeObject *)result; +bad: + Py_XDECREF(result); + return NULL; +} +#endif + +/* FetchSharedCythonModule */ + static PyObject *__Pyx_FetchSharedCythonABIModule(void) { + return __Pyx_PyImport_AddModuleRef((char*) __PYX_ABI_MODULE_NAME); +} + +/* FetchCommonType */ + static int __Pyx_VerifyCachedType(PyObject *cached_type, + const char *name, + Py_ssize_t basicsize, + Py_ssize_t expected_basicsize) { + if (!PyType_Check(cached_type)) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s is not a type object", name); + return -1; + } + if (basicsize != expected_basicsize) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s has the wrong size, try recompiling", + name); + return -1; + } + return 0; +} +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type) { + PyObject* abi_module; + const char* object_name; + PyTypeObject *cached_type = NULL; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + object_name = strrchr(type->tp_name, '.'); + object_name = object_name ? object_name+1 : type->tp_name; + cached_type = (PyTypeObject*) PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + if (__Pyx_VerifyCachedType( + (PyObject *)cached_type, + object_name, + cached_type->tp_basicsize, + type->tp_basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + if (PyType_Ready(type) < 0) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, (PyObject *)type) < 0) + goto bad; + Py_INCREF(type); + cached_type = type; +done: + Py_DECREF(abi_module); + return cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#else +static PyTypeObject *__Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases) { + PyObject *abi_module, *cached_type = NULL; + const char* object_name = strrchr(spec->name, '.'); + object_name = object_name ? object_name+1 : spec->name; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + cached_type = PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + Py_ssize_t basicsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + py_basicsize = PyObject_GetAttrString(cached_type, "__basicsize__"); + if (unlikely(!py_basicsize)) goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (unlikely(basicsize == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad; +#else + basicsize = likely(PyType_Check(cached_type)) ? ((PyTypeObject*) cached_type)->tp_basicsize : -1; +#endif + if (__Pyx_VerifyCachedType( + cached_type, + object_name, + basicsize, + spec->basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + CYTHON_UNUSED_VAR(module); + cached_type = __Pyx_PyType_FromModuleAndSpec(abi_module, spec, bases); + if (unlikely(!cached_type)) goto bad; + if (unlikely(__Pyx_fix_up_extension_type_from_spec(spec, (PyTypeObject *) cached_type) < 0)) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, cached_type) < 0) goto bad; +done: + Py_DECREF(abi_module); + assert(cached_type == NULL || PyType_Check(cached_type)); + return (PyTypeObject *) cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#endif + +/* PyVectorcallFastCallDict */ + #if CYTHON_METH_FASTCALL +static PyObject *__Pyx_PyVectorcall_FastCallDict_kw(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + PyObject *res = NULL; + PyObject *kwnames; + PyObject **newargs; + PyObject **kwvalues; + Py_ssize_t i, pos; + size_t j; + PyObject *key, *value; + unsigned long keys_are_strings; + Py_ssize_t nkw = PyDict_GET_SIZE(kw); + newargs = (PyObject **)PyMem_Malloc((nargs + (size_t)nkw) * sizeof(args[0])); + if (unlikely(newargs == NULL)) { + PyErr_NoMemory(); + return NULL; + } + for (j = 0; j < nargs; j++) newargs[j] = args[j]; + kwnames = PyTuple_New(nkw); + if (unlikely(kwnames == NULL)) { + PyMem_Free(newargs); + return NULL; + } + kwvalues = newargs + nargs; + pos = i = 0; + keys_are_strings = Py_TPFLAGS_UNICODE_SUBCLASS; + while (PyDict_Next(kw, &pos, &key, &value)) { + keys_are_strings &= Py_TYPE(key)->tp_flags; + Py_INCREF(key); + Py_INCREF(value); + PyTuple_SET_ITEM(kwnames, i, key); + kwvalues[i] = value; + i++; + } + if (unlikely(!keys_are_strings)) { + PyErr_SetString(PyExc_TypeError, "keywords must be strings"); + goto cleanup; + } + res = vc(func, newargs, nargs, kwnames); +cleanup: + Py_DECREF(kwnames); + for (i = 0; i < nkw; i++) + Py_DECREF(kwvalues[i]); + PyMem_Free(newargs); + return res; +} +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + if (likely(kw == NULL) || PyDict_GET_SIZE(kw) == 0) { + return vc(func, args, nargs, NULL); + } + return __Pyx_PyVectorcall_FastCallDict_kw(func, vc, args, nargs, kw); +} +#endif + +/* CythonFunctionShared */ + #if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + if (__Pyx_CyFunction_Check(func)) { + return PyCFunction_GetFunction(((__pyx_CyFunctionObject*)func)->func) == (PyCFunction) cfunc; + } else if (PyCFunction_Check(func)) { + return PyCFunction_GetFunction(func) == (PyCFunction) cfunc; + } + return 0; +} +#else +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + return __Pyx_CyOrPyCFunction_Check(func) && __Pyx_CyOrPyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +} +#endif +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj) { +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + __Pyx_Py_XDECREF_SET( + __Pyx_CyFunction_GetClassObj(f), + ((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#else + __Pyx_Py_XDECREF_SET( + ((PyCMethodObject *) (f))->mm_class, + (PyTypeObject*)((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#endif +} +static PyObject * +__Pyx_CyFunction_get_doc(__pyx_CyFunctionObject *op, void *closure) +{ + CYTHON_UNUSED_VAR(closure); + if (unlikely(op->func_doc == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_doc = PyObject_GetAttrString(op->func, "__doc__"); + if (unlikely(!op->func_doc)) return NULL; +#else + if (((PyCFunctionObject*)op)->m_ml->ml_doc) { +#if PY_MAJOR_VERSION >= 3 + op->func_doc = PyUnicode_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#else + op->func_doc = PyString_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#endif + if (unlikely(op->func_doc == NULL)) + return NULL; + } else { + Py_INCREF(Py_None); + return Py_None; + } +#endif + } + Py_INCREF(op->func_doc); + return op->func_doc; +} +static int +__Pyx_CyFunction_set_doc(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (value == NULL) { + value = Py_None; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_doc, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_name(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_name == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_name = PyObject_GetAttrString(op->func, "__name__"); +#elif PY_MAJOR_VERSION >= 3 + op->func_name = PyUnicode_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#else + op->func_name = PyString_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#endif + if (unlikely(op->func_name == NULL)) + return NULL; + } + Py_INCREF(op->func_name); + return op->func_name; +} +static int +__Pyx_CyFunction_set_name(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__name__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_name, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_qualname(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_qualname); + return op->func_qualname; +} +static int +__Pyx_CyFunction_set_qualname(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__qualname__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_qualname, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_dict(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_dict == NULL)) { + op->func_dict = PyDict_New(); + if (unlikely(op->func_dict == NULL)) + return NULL; + } + Py_INCREF(op->func_dict); + return op->func_dict; +} +static int +__Pyx_CyFunction_set_dict(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(value == NULL)) { + PyErr_SetString(PyExc_TypeError, + "function's dictionary may not be deleted"); + return -1; + } + if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "setting function's dictionary to a non-dict"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_dict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_globals(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_globals); + return op->func_globals; +} +static PyObject * +__Pyx_CyFunction_get_closure(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(op); + CYTHON_UNUSED_VAR(context); + Py_INCREF(Py_None); + return Py_None; +} +static PyObject * +__Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op, void *context) +{ + PyObject* result = (op->func_code) ? op->func_code : Py_None; + CYTHON_UNUSED_VAR(context); + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_init_defaults(__pyx_CyFunctionObject *op) { + int result = 0; + PyObject *res = op->defaults_getter((PyObject *) op); + if (unlikely(!res)) + return -1; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + op->defaults_tuple = PyTuple_GET_ITEM(res, 0); + Py_INCREF(op->defaults_tuple); + op->defaults_kwdict = PyTuple_GET_ITEM(res, 1); + Py_INCREF(op->defaults_kwdict); + #else + op->defaults_tuple = __Pyx_PySequence_ITEM(res, 0); + if (unlikely(!op->defaults_tuple)) result = -1; + else { + op->defaults_kwdict = __Pyx_PySequence_ITEM(res, 1); + if (unlikely(!op->defaults_kwdict)) result = -1; + } + #endif + Py_DECREF(res); + return result; +} +static int +__Pyx_CyFunction_set_defaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyTuple_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__defaults__ must be set to a tuple object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__defaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_tuple, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_tuple; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_tuple; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_kwdefaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__kwdefaults__ must be set to a dict object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__kwdefaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_kwdict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_kwdefaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_kwdict; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_kwdict; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_annotations(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value || value == Py_None) { + value = NULL; + } else if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__annotations__ must be set to a dict object"); + return -1; + } + Py_XINCREF(value); + __Pyx_Py_XDECREF_SET(op->func_annotations, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_annotations(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->func_annotations; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + result = PyDict_New(); + if (unlikely(!result)) return NULL; + op->func_annotations = result; + } + Py_INCREF(result); + return result; +} +static PyObject * +__Pyx_CyFunction_get_is_coroutine(__pyx_CyFunctionObject *op, void *context) { + int is_coroutine; + CYTHON_UNUSED_VAR(context); + if (op->func_is_coroutine) { + return __Pyx_NewRef(op->func_is_coroutine); + } + is_coroutine = op->flags & __Pyx_CYFUNCTION_COROUTINE; +#if PY_VERSION_HEX >= 0x03050000 + if (is_coroutine) { + PyObject *module, *fromlist, *marker = __pyx_n_s_is_coroutine; + fromlist = PyList_New(1); + if (unlikely(!fromlist)) return NULL; + Py_INCREF(marker); +#if CYTHON_ASSUME_SAFE_MACROS + PyList_SET_ITEM(fromlist, 0, marker); +#else + if (unlikely(PyList_SetItem(fromlist, 0, marker) < 0)) { + Py_DECREF(marker); + Py_DECREF(fromlist); + return NULL; + } +#endif + module = PyImport_ImportModuleLevelObject(__pyx_n_s_asyncio_coroutines, NULL, NULL, fromlist, 0); + Py_DECREF(fromlist); + if (unlikely(!module)) goto ignore; + op->func_is_coroutine = __Pyx_PyObject_GetAttrStr(module, marker); + Py_DECREF(module); + if (likely(op->func_is_coroutine)) { + return __Pyx_NewRef(op->func_is_coroutine); + } +ignore: + PyErr_Clear(); + } +#endif + op->func_is_coroutine = __Pyx_PyBool_FromLong(is_coroutine); + return __Pyx_NewRef(op->func_is_coroutine); +} +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject * +__Pyx_CyFunction_get_module(__pyx_CyFunctionObject *op, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_GetAttrString(op->func, "__module__"); +} +static int +__Pyx_CyFunction_set_module(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_SetAttrString(op->func, "__module__", value); +} +#endif +static PyGetSetDef __pyx_CyFunction_getsets[] = { + {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "func_name", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__name__", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__qualname__", (getter)__Pyx_CyFunction_get_qualname, (setter)__Pyx_CyFunction_set_qualname, 0, 0}, + {(char *) "func_dict", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "__dict__", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "func_globals", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "__globals__", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "func_closure", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__kwdefaults__", (getter)__Pyx_CyFunction_get_kwdefaults, (setter)__Pyx_CyFunction_set_kwdefaults, 0, 0}, + {(char *) "__annotations__", (getter)__Pyx_CyFunction_get_annotations, (setter)__Pyx_CyFunction_set_annotations, 0, 0}, + {(char *) "_is_coroutine", (getter)__Pyx_CyFunction_get_is_coroutine, 0, 0, 0}, +#if CYTHON_COMPILING_IN_LIMITED_API + {"__module__", (getter)__Pyx_CyFunction_get_module, (setter)__Pyx_CyFunction_set_module, 0, 0}, +#endif + {0, 0, 0, 0, 0} +}; +static PyMemberDef __pyx_CyFunction_members[] = { +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__module__", T_OBJECT, offsetof(PyCFunctionObject, m_module), 0, 0}, +#endif +#if CYTHON_USE_TYPE_SPECS + {(char *) "__dictoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_dict), READONLY, 0}, +#if CYTHON_METH_FASTCALL +#if CYTHON_BACKPORT_VECTORCALL + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_vectorcall), READONLY, 0}, +#else +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(PyCFunctionObject, vectorcall), READONLY, 0}, +#endif +#endif +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_weakreflist), READONLY, 0}, +#else + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(PyCFunctionObject, m_weakreflist), READONLY, 0}, +#endif +#endif + {0, 0, 0, 0, 0} +}; +static PyObject * +__Pyx_CyFunction_reduce(__pyx_CyFunctionObject *m, PyObject *args) +{ + CYTHON_UNUSED_VAR(args); +#if PY_MAJOR_VERSION >= 3 + Py_INCREF(m->func_qualname); + return m->func_qualname; +#else + return PyString_FromString(((PyCFunctionObject*)m)->m_ml->ml_name); +#endif +} +static PyMethodDef __pyx_CyFunction_methods[] = { + {"__reduce__", (PyCFunction)__Pyx_CyFunction_reduce, METH_VARARGS, 0}, + {0, 0, 0, 0} +}; +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func_weakreflist) +#else +#define __Pyx_CyFunction_weakreflist(cyfunc) (((PyCFunctionObject*)cyfunc)->m_weakreflist) +#endif +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject *op, PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { +#if !CYTHON_COMPILING_IN_LIMITED_API + PyCFunctionObject *cf = (PyCFunctionObject*) op; +#endif + if (unlikely(op == NULL)) + return NULL; +#if CYTHON_COMPILING_IN_LIMITED_API + op->func = PyCFunction_NewEx(ml, (PyObject*)op, module); + if (unlikely(!op->func)) return NULL; +#endif + op->flags = flags; + __Pyx_CyFunction_weakreflist(op) = NULL; +#if !CYTHON_COMPILING_IN_LIMITED_API + cf->m_ml = ml; + cf->m_self = (PyObject *) op; +#endif + Py_XINCREF(closure); + op->func_closure = closure; +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_XINCREF(module); + cf->m_module = module; +#endif + op->func_dict = NULL; + op->func_name = NULL; + Py_INCREF(qualname); + op->func_qualname = qualname; + op->func_doc = NULL; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + op->func_classobj = NULL; +#else + ((PyCMethodObject*)op)->mm_class = NULL; +#endif + op->func_globals = globals; + Py_INCREF(op->func_globals); + Py_XINCREF(code); + op->func_code = code; + op->defaults_pyobjects = 0; + op->defaults_size = 0; + op->defaults = NULL; + op->defaults_tuple = NULL; + op->defaults_kwdict = NULL; + op->defaults_getter = NULL; + op->func_annotations = NULL; + op->func_is_coroutine = NULL; +#if CYTHON_METH_FASTCALL + switch (ml->ml_flags & (METH_VARARGS | METH_FASTCALL | METH_NOARGS | METH_O | METH_KEYWORDS | METH_METHOD)) { + case METH_NOARGS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_NOARGS; + break; + case METH_O: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_O; + break; + case METH_METHOD | METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD; + break; + case METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS; + break; + case METH_VARARGS | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = NULL; + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + Py_DECREF(op); + return NULL; + } +#endif + return (PyObject *) op; +} +static int +__Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) +{ + Py_CLEAR(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_CLEAR(m->func); +#else + Py_CLEAR(((PyCFunctionObject*)m)->m_module); +#endif + Py_CLEAR(m->func_dict); + Py_CLEAR(m->func_name); + Py_CLEAR(m->func_qualname); + Py_CLEAR(m->func_doc); + Py_CLEAR(m->func_globals); + Py_CLEAR(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API +#if PY_VERSION_HEX < 0x030900B1 + Py_CLEAR(__Pyx_CyFunction_GetClassObj(m)); +#else + { + PyObject *cls = (PyObject*) ((PyCMethodObject *) (m))->mm_class; + ((PyCMethodObject *) (m))->mm_class = NULL; + Py_XDECREF(cls); + } +#endif +#endif + Py_CLEAR(m->defaults_tuple); + Py_CLEAR(m->defaults_kwdict); + Py_CLEAR(m->func_annotations); + Py_CLEAR(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_XDECREF(pydefaults[i]); + PyObject_Free(m->defaults); + m->defaults = NULL; + } + return 0; +} +static void __Pyx__CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + if (__Pyx_CyFunction_weakreflist(m) != NULL) + PyObject_ClearWeakRefs((PyObject *) m); + __Pyx_CyFunction_clear(m); + __Pyx_PyHeapTypeObject_GC_Del(m); +} +static void __Pyx_CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + PyObject_GC_UnTrack(m); + __Pyx__CyFunction_dealloc(m); +} +static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, void *arg) +{ + Py_VISIT(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(m->func); +#else + Py_VISIT(((PyCFunctionObject*)m)->m_module); +#endif + Py_VISIT(m->func_dict); + Py_VISIT(m->func_name); + Py_VISIT(m->func_qualname); + Py_VISIT(m->func_doc); + Py_VISIT(m->func_globals); + Py_VISIT(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(__Pyx_CyFunction_GetClassObj(m)); +#endif + Py_VISIT(m->defaults_tuple); + Py_VISIT(m->defaults_kwdict); + Py_VISIT(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_VISIT(pydefaults[i]); + } + return 0; +} +static PyObject* +__Pyx_CyFunction_repr(__pyx_CyFunctionObject *op) +{ +#if PY_MAJOR_VERSION >= 3 + return PyUnicode_FromFormat("", + op->func_qualname, (void *)op); +#else + return PyString_FromFormat("", + PyString_AsString(op->func_qualname), (void *)op); +#endif +} +static PyObject * __Pyx_CyFunction_CallMethod(PyObject *func, PyObject *self, PyObject *arg, PyObject *kw) { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *f = ((__pyx_CyFunctionObject*)func)->func; + PyObject *py_name = NULL; + PyCFunction meth; + int flags; + meth = PyCFunction_GetFunction(f); + if (unlikely(!meth)) return NULL; + flags = PyCFunction_GetFlags(f); + if (unlikely(flags < 0)) return NULL; +#else + PyCFunctionObject* f = (PyCFunctionObject*)func; + PyCFunction meth = f->m_ml->ml_meth; + int flags = f->m_ml->ml_flags; +#endif + Py_ssize_t size; + switch (flags & (METH_VARARGS | METH_KEYWORDS | METH_NOARGS | METH_O)) { + case METH_VARARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) + return (*meth)(self, arg); + break; + case METH_VARARGS | METH_KEYWORDS: + return (*(PyCFunctionWithKeywords)(void*)meth)(self, arg, kw); + case METH_NOARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 0)) + return (*meth)(self, NULL); +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + case METH_O: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 1)) { + PyObject *result, *arg0; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + arg0 = PyTuple_GET_ITEM(arg, 0); + #else + arg0 = __Pyx_PySequence_ITEM(arg, 0); if (unlikely(!arg0)) return NULL; + #endif + result = (*meth)(self, arg0); + #if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(arg0); + #endif + return result; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + return NULL; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, "%.200S() takes no keyword arguments", + py_name); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, "%.200s() takes no keyword arguments", + f->m_ml->ml_name); +#endif + return NULL; +} +static CYTHON_INLINE PyObject *__Pyx_CyFunction_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *self, *result; +#if CYTHON_COMPILING_IN_LIMITED_API + self = PyCFunction_GetSelf(((__pyx_CyFunctionObject*)func)->func); + if (unlikely(!self) && PyErr_Occurred()) return NULL; +#else + self = ((PyCFunctionObject*)func)->m_self; +#endif + result = __Pyx_CyFunction_CallMethod(func, self, arg, kw); + return result; +} +static PyObject *__Pyx_CyFunction_CallAsMethod(PyObject *func, PyObject *args, PyObject *kw) { + PyObject *result; + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *) func; +#if CYTHON_METH_FASTCALL + __pyx_vectorcallfunc vc = __Pyx_CyFunction_func_vectorcall(cyfunc); + if (vc) { +#if CYTHON_ASSUME_SAFE_MACROS + return __Pyx_PyVectorcall_FastCallDict(func, vc, &PyTuple_GET_ITEM(args, 0), (size_t)PyTuple_GET_SIZE(args), kw); +#else + (void) &__Pyx_PyVectorcall_FastCallDict; + return PyVectorcall_Call(func, args, kw); +#endif + } +#endif + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + Py_ssize_t argc; + PyObject *new_args; + PyObject *self; +#if CYTHON_ASSUME_SAFE_MACROS + argc = PyTuple_GET_SIZE(args); +#else + argc = PyTuple_Size(args); + if (unlikely(!argc) < 0) return NULL; +#endif + new_args = PyTuple_GetSlice(args, 1, argc); + if (unlikely(!new_args)) + return NULL; + self = PyTuple_GetItem(args, 0); + if (unlikely(!self)) { + Py_DECREF(new_args); +#if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_TypeError, + "unbound method %.200S() needs an argument", + cyfunc->func_qualname); +#else + PyErr_SetString(PyExc_TypeError, + "unbound method needs an argument"); +#endif + return NULL; + } + result = __Pyx_CyFunction_CallMethod(func, self, new_args, kw); + Py_DECREF(new_args); + } else { + result = __Pyx_CyFunction_Call(func, args, kw); + } + return result; +} +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE int __Pyx_CyFunction_Vectorcall_CheckArgs(__pyx_CyFunctionObject *cyfunc, Py_ssize_t nargs, PyObject *kwnames) +{ + int ret = 0; + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + if (unlikely(nargs < 1)) { + PyErr_Format(PyExc_TypeError, "%.200s() needs an argument", + ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + ret = 1; + } + if (unlikely(kwnames) && unlikely(PyTuple_GET_SIZE(kwnames))) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no keyword arguments", ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + return ret; +} +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 0)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, NULL); +} +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 1)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, args[0]); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((_PyCFunctionFastWithKeywords)(void(*)(void))def->ml_meth)(self, args, nargs, kwnames); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; + PyTypeObject *cls = (PyTypeObject *) __Pyx_CyFunction_GetClassObj(cyfunc); +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((__Pyx_PyCMethod)(void(*)(void))def->ml_meth)(self, cls, args, (size_t)nargs, kwnames); +} +#endif +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_CyFunctionType_slots[] = { + {Py_tp_dealloc, (void *)__Pyx_CyFunction_dealloc}, + {Py_tp_repr, (void *)__Pyx_CyFunction_repr}, + {Py_tp_call, (void *)__Pyx_CyFunction_CallAsMethod}, + {Py_tp_traverse, (void *)__Pyx_CyFunction_traverse}, + {Py_tp_clear, (void *)__Pyx_CyFunction_clear}, + {Py_tp_methods, (void *)__pyx_CyFunction_methods}, + {Py_tp_members, (void *)__pyx_CyFunction_members}, + {Py_tp_getset, (void *)__pyx_CyFunction_getsets}, + {Py_tp_descr_get, (void *)__Pyx_PyMethod_New}, + {0, 0}, +}; +static PyType_Spec __pyx_CyFunctionType_spec = { + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if (defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL) + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + __pyx_CyFunctionType_slots +}; +#else +static PyTypeObject __pyx_CyFunctionType_type = { + PyVarObject_HEAD_INIT(0, 0) + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, + (destructor) __Pyx_CyFunction_dealloc, +#if !CYTHON_METH_FASTCALL + 0, +#elif CYTHON_BACKPORT_VECTORCALL + (printfunc)offsetof(__pyx_CyFunctionObject, func_vectorcall), +#else + offsetof(PyCFunctionObject, vectorcall), +#endif + 0, + 0, +#if PY_MAJOR_VERSION < 3 + 0, +#else + 0, +#endif + (reprfunc) __Pyx_CyFunction_repr, + 0, + 0, + 0, + 0, + __Pyx_CyFunction_CallAsMethod, + 0, + 0, + 0, + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + 0, + (traverseproc) __Pyx_CyFunction_traverse, + (inquiry) __Pyx_CyFunction_clear, + 0, +#if PY_VERSION_HEX < 0x030500A0 + offsetof(__pyx_CyFunctionObject, func_weakreflist), +#else + offsetof(PyCFunctionObject, m_weakreflist), +#endif + 0, + 0, + __pyx_CyFunction_methods, + __pyx_CyFunction_members, + __pyx_CyFunction_getsets, + 0, + 0, + __Pyx_PyMethod_New, + 0, + offsetof(__pyx_CyFunctionObject, func_dict), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, +#if PY_VERSION_HEX >= 0x030400a1 + 0, +#endif +#if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, +#endif +#if __PYX_NEED_TP_PRINT_SLOT + 0, +#endif +#if PY_VERSION_HEX >= 0x030C0000 + 0, +#endif +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, +#endif +}; +#endif +static int __pyx_CyFunction_init(PyObject *module) { +#if CYTHON_USE_TYPE_SPECS + __pyx_CyFunctionType = __Pyx_FetchCommonTypeFromSpec(module, &__pyx_CyFunctionType_spec, NULL); +#else + CYTHON_UNUSED_VAR(module); + __pyx_CyFunctionType = __Pyx_FetchCommonType(&__pyx_CyFunctionType_type); +#endif + if (unlikely(__pyx_CyFunctionType == NULL)) { + return -1; + } + return 0; +} +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults = PyObject_Malloc(size); + if (unlikely(!m->defaults)) + return PyErr_NoMemory(); + memset(m->defaults, 0, size); + m->defaults_pyobjects = pyobjects; + m->defaults_size = size; + return m->defaults; +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_tuple = tuple; + Py_INCREF(tuple); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_kwdict = dict; + Py_INCREF(dict); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->func_annotations = dict; + Py_INCREF(dict); +} + +/* CythonFunction */ + static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { + PyObject *op = __Pyx_CyFunction_Init( + PyObject_GC_New(__pyx_CyFunctionObject, __pyx_CyFunctionType), + ml, flags, qualname, closure, module, globals, code + ); + if (likely(op)) { + PyObject_GC_Track(op); + } + return op; +} + +/* CLineInTraceback */ + #ifndef CYTHON_CLINE_IN_TRACEBACK +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line) { + PyObject *use_cline; + PyObject *ptype, *pvalue, *ptraceback; +#if CYTHON_COMPILING_IN_CPYTHON + PyObject **cython_runtime_dict; +#endif + CYTHON_MAYBE_UNUSED_VAR(tstate); + if (unlikely(!__pyx_cython_runtime)) { + return c_line; + } + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); +#if CYTHON_COMPILING_IN_CPYTHON + cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime); + if (likely(cython_runtime_dict)) { + __PYX_PY_DICT_LOOKUP_IF_MODIFIED( + use_cline, *cython_runtime_dict, + __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback)) + } else +#endif + { + PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStrNoError(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback); + if (use_cline_obj) { + use_cline = PyObject_Not(use_cline_obj) ? Py_False : Py_True; + Py_DECREF(use_cline_obj); + } else { + PyErr_Clear(); + use_cline = NULL; + } + } + if (!use_cline) { + c_line = 0; + (void) PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False); + } + else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) { + c_line = 0; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + return c_line; +} +#endif + +/* CodeObjectCache */ + #if !CYTHON_COMPILING_IN_LIMITED_API +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) { + int start = 0, mid = 0, end = count - 1; + if (end >= 0 && code_line > entries[end].code_line) { + return count; + } + while (start < end) { + mid = start + (end - start) / 2; + if (code_line < entries[mid].code_line) { + end = mid; + } else if (code_line > entries[mid].code_line) { + start = mid + 1; + } else { + return mid; + } + } + if (code_line <= entries[mid].code_line) { + return mid; + } else { + return mid + 1; + } +} +static PyCodeObject *__pyx_find_code_object(int code_line) { + PyCodeObject* code_object; + int pos; + if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) { + return NULL; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) { + return NULL; + } + code_object = __pyx_code_cache.entries[pos].code_object; + Py_INCREF(code_object); + return code_object; +} +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) { + int pos, i; + __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries; + if (unlikely(!code_line)) { + return; + } + if (unlikely(!entries)) { + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry)); + if (likely(entries)) { + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = 64; + __pyx_code_cache.count = 1; + entries[0].code_line = code_line; + entries[0].code_object = code_object; + Py_INCREF(code_object); + } + return; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) { + PyCodeObject* tmp = entries[pos].code_object; + entries[pos].code_object = code_object; + Py_DECREF(tmp); + return; + } + if (__pyx_code_cache.count == __pyx_code_cache.max_count) { + int new_max = __pyx_code_cache.max_count + 64; + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc( + __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry)); + if (unlikely(!entries)) { + return; + } + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = new_max; + } + for (i=__pyx_code_cache.count; i>pos; i--) { + entries[i] = entries[i-1]; + } + entries[pos].code_line = code_line; + entries[pos].code_object = code_object; + __pyx_code_cache.count++; + Py_INCREF(code_object); +} +#endif + +/* AddTraceback */ + #include "compile.h" +#include "frameobject.h" +#include "traceback.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyCode_Replace_For_AddTraceback(PyObject *code, PyObject *scratch_dict, + PyObject *firstlineno, PyObject *name) { + PyObject *replace = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_firstlineno", firstlineno))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_name", name))) return NULL; + replace = PyObject_GetAttrString(code, "replace"); + if (likely(replace)) { + PyObject *result; + result = PyObject_Call(replace, __pyx_empty_tuple, scratch_dict); + Py_DECREF(replace); + return result; + } + PyErr_Clear(); + #if __PYX_LIMITED_VERSION_HEX < 0x030780000 + { + PyObject *compiled = NULL, *result = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "code", code))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "type", (PyObject*)(&PyType_Type)))) return NULL; + compiled = Py_CompileString( + "out = type(code)(\n" + " code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize,\n" + " code.co_flags, code.co_code, code.co_consts, code.co_names,\n" + " code.co_varnames, code.co_filename, co_name, co_firstlineno,\n" + " code.co_lnotab)\n", "", Py_file_input); + if (!compiled) return NULL; + result = PyEval_EvalCode(compiled, scratch_dict, scratch_dict); + Py_DECREF(compiled); + if (!result) PyErr_Print(); + Py_DECREF(result); + result = PyDict_GetItemString(scratch_dict, "out"); + if (result) Py_INCREF(result); + return result; + } + #else + return NULL; + #endif +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyObject *code_object = NULL, *py_py_line = NULL, *py_funcname = NULL, *dict = NULL; + PyObject *replace = NULL, *getframe = NULL, *frame = NULL; + PyObject *exc_type, *exc_value, *exc_traceback; + int success = 0; + if (c_line) { + (void) __pyx_cfilenm; + (void) __Pyx_CLineForTraceback(__Pyx_PyThreadState_Current, c_line); + } + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + code_object = Py_CompileString("_getframe()", filename, Py_eval_input); + if (unlikely(!code_object)) goto bad; + py_py_line = PyLong_FromLong(py_line); + if (unlikely(!py_py_line)) goto bad; + py_funcname = PyUnicode_FromString(funcname); + if (unlikely(!py_funcname)) goto bad; + dict = PyDict_New(); + if (unlikely(!dict)) goto bad; + { + PyObject *old_code_object = code_object; + code_object = __Pyx_PyCode_Replace_For_AddTraceback(code_object, dict, py_py_line, py_funcname); + Py_DECREF(old_code_object); + } + if (unlikely(!code_object)) goto bad; + getframe = PySys_GetObject("_getframe"); + if (unlikely(!getframe)) goto bad; + if (unlikely(PyDict_SetItemString(dict, "_getframe", getframe))) goto bad; + frame = PyEval_EvalCode(code_object, dict, dict); + if (unlikely(!frame) || frame == Py_None) goto bad; + success = 1; + bad: + PyErr_Restore(exc_type, exc_value, exc_traceback); + Py_XDECREF(code_object); + Py_XDECREF(py_py_line); + Py_XDECREF(py_funcname); + Py_XDECREF(dict); + Py_XDECREF(replace); + if (success) { + PyTraceBack_Here( + (struct _frame*)frame); + } + Py_XDECREF(frame); +} +#else +static PyCodeObject* __Pyx_CreateCodeObjectForTraceback( + const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = NULL; + PyObject *py_funcname = NULL; + #if PY_MAJOR_VERSION < 3 + PyObject *py_srcfile = NULL; + py_srcfile = PyString_FromString(filename); + if (!py_srcfile) goto bad; + #endif + if (c_line) { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + #else + py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + funcname = PyUnicode_AsUTF8(py_funcname); + if (!funcname) goto bad; + #endif + } + else { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromString(funcname); + if (!py_funcname) goto bad; + #endif + } + #if PY_MAJOR_VERSION < 3 + py_code = __Pyx_PyCode_New( + 0, + 0, + 0, + 0, + 0, + 0, + __pyx_empty_bytes, /*PyObject *code,*/ + __pyx_empty_tuple, /*PyObject *consts,*/ + __pyx_empty_tuple, /*PyObject *names,*/ + __pyx_empty_tuple, /*PyObject *varnames,*/ + __pyx_empty_tuple, /*PyObject *freevars,*/ + __pyx_empty_tuple, /*PyObject *cellvars,*/ + py_srcfile, /*PyObject *filename,*/ + py_funcname, /*PyObject *name,*/ + py_line, + __pyx_empty_bytes /*PyObject *lnotab*/ + ); + Py_DECREF(py_srcfile); + #else + py_code = PyCode_NewEmpty(filename, funcname, py_line); + #endif + Py_XDECREF(py_funcname); + return py_code; +bad: + Py_XDECREF(py_funcname); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_srcfile); + #endif + return NULL; +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = 0; + PyFrameObject *py_frame = 0; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject *ptype, *pvalue, *ptraceback; + if (c_line) { + c_line = __Pyx_CLineForTraceback(tstate, c_line); + } + py_code = __pyx_find_code_object(c_line ? -c_line : py_line); + if (!py_code) { + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); + py_code = __Pyx_CreateCodeObjectForTraceback( + funcname, c_line, py_line, filename); + if (!py_code) { + /* If the code object creation fails, then we should clear the + fetched exception references and propagate the new exception */ + Py_XDECREF(ptype); + Py_XDECREF(pvalue); + Py_XDECREF(ptraceback); + goto bad; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + __pyx_insert_code_object(c_line ? -c_line : py_line, py_code); + } + py_frame = PyFrame_New( + tstate, /*PyThreadState *tstate,*/ + py_code, /*PyCodeObject *code,*/ + __pyx_d, /*PyObject *globals,*/ + 0 /*PyObject *locals*/ + ); + if (!py_frame) goto bad; + __Pyx_PyFrame_SetLineNumber(py_frame, py_line); + PyTraceBack_Here(py_frame); +bad: + Py_XDECREF(py_code); + Py_XDECREF(py_frame); +} +#endif + +#if PY_MAJOR_VERSION < 3 +static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) { + __Pyx_TypeName obj_type_name; + if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_array_type)) return __pyx_array_getbuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_memoryview_type)) return __pyx_memoryview_getbuffer(obj, view, flags); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' does not have the buffer interface", + obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return -1; +} +static void __Pyx_ReleaseBuffer(Py_buffer *view) { + PyObject *obj = view->obj; + if (!obj) return; + if (PyObject_CheckBuffer(obj)) { + PyBuffer_Release(view); + return; + } + if ((0)) {} + view->obj = NULL; + Py_DECREF(obj); +} +#endif + + + /* MemviewSliceIsContig */ + static int +__pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim) +{ + int i, index, step, start; + Py_ssize_t itemsize = mvs.memview->view.itemsize; + if (order == 'F') { + step = 1; + start = 0; + } else { + step = -1; + start = ndim - 1; + } + for (i = 0; i < ndim; i++) { + index = start + step * i; + if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize) + return 0; + itemsize *= mvs.shape[index]; + } + return 1; +} + +/* OverlappingSlices */ + static void +__pyx_get_array_memory_extents(__Pyx_memviewslice *slice, + void **out_start, void **out_end, + int ndim, size_t itemsize) +{ + char *start, *end; + int i; + start = end = slice->data; + for (i = 0; i < ndim; i++) { + Py_ssize_t stride = slice->strides[i]; + Py_ssize_t extent = slice->shape[i]; + if (extent == 0) { + *out_start = *out_end = start; + return; + } else { + if (stride > 0) + end += stride * (extent - 1); + else + start += stride * (extent - 1); + } + } + *out_start = start; + *out_end = end + itemsize; +} +static int +__pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize) +{ + void *start1, *end1, *start2, *end2; + __pyx_get_array_memory_extents(slice1, &start1, &end1, ndim, itemsize); + __pyx_get_array_memory_extents(slice2, &start2, &end2, ndim, itemsize); + return (start1 < end2) && (start2 < end1); +} + +/* CIntFromPyVerify */ + #define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0) +#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1) +#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\ + {\ + func_type value = func_value;\ + if (sizeof(target_type) < sizeof(func_type)) {\ + if (unlikely(value != (func_type) (target_type) value)) {\ + func_type zero = 0;\ + if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\ + return (target_type) -1;\ + if (is_unsigned && unlikely(value < zero))\ + goto raise_neg_overflow;\ + else\ + goto raise_overflow;\ + }\ + }\ + return (target_type) value;\ + } + +/* MemviewDtypeToObject */ + static CYTHON_INLINE PyObject *__pyx_memview_get_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(const char *itemp) { + return (PyObject *) __Pyx_PyInt_From_int64_t(*(__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) itemp); +} +static CYTHON_INLINE int __pyx_memview_set_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(const char *itemp, PyObject *obj) { + __pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t value = __Pyx_PyInt_As_int64_t(obj); + if (unlikely((value == ((int64_t)-1)) && PyErr_Occurred())) + return 0; + *(__pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t *) itemp = value; + return 1; +} + +/* TypeInfoCompare */ + static int +__pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b) +{ + int i; + if (!a || !b) + return 0; + if (a == b) + return 1; + if (a->size != b->size || a->typegroup != b->typegroup || + a->is_unsigned != b->is_unsigned || a->ndim != b->ndim) { + if (a->typegroup == 'H' || b->typegroup == 'H') { + return a->size == b->size; + } else { + return 0; + } + } + if (a->ndim) { + for (i = 0; i < a->ndim; i++) + if (a->arraysize[i] != b->arraysize[i]) + return 0; + } + if (a->typegroup == 'S') { + if (a->flags != b->flags) + return 0; + if (a->fields || b->fields) { + if (!(a->fields && b->fields)) + return 0; + for (i = 0; a->fields[i].type && b->fields[i].type; i++) { + __Pyx_StructField *field_a = a->fields + i; + __Pyx_StructField *field_b = b->fields + i; + if (field_a->offset != field_b->offset || + !__pyx_typeinfo_cmp(field_a->type, field_b->type)) + return 0; + } + return !a->fields[i].type && !b->fields[i].type; + } + } + return 1; +} + +/* MemviewSliceValidateAndInit */ + static int +__pyx_check_strides(Py_buffer *buf, int dim, int ndim, int spec) +{ + if (buf->shape[dim] <= 1) + return 1; + if (buf->strides) { + if (spec & __Pyx_MEMVIEW_CONTIG) { + if (spec & (__Pyx_MEMVIEW_PTR|__Pyx_MEMVIEW_FULL)) { + if (unlikely(buf->strides[dim] != sizeof(void *))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly contiguous " + "in dimension %d.", dim); + goto fail; + } + } else if (unlikely(buf->strides[dim] != buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_FOLLOW) { + Py_ssize_t stride = buf->strides[dim]; + if (stride < 0) + stride = -stride; + if (unlikely(stride < buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + } else { + if (unlikely(spec & __Pyx_MEMVIEW_CONTIG && dim != ndim - 1)) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not contiguous in " + "dimension %d", dim); + goto fail; + } else if (unlikely(spec & (__Pyx_MEMVIEW_PTR))) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not indirect in " + "dimension %d", dim); + goto fail; + } else if (unlikely(buf->suboffsets)) { + PyErr_SetString(PyExc_ValueError, + "Buffer exposes suboffsets but no strides"); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_check_suboffsets(Py_buffer *buf, int dim, int ndim, int spec) +{ + CYTHON_UNUSED_VAR(ndim); + if (spec & __Pyx_MEMVIEW_DIRECT) { + if (unlikely(buf->suboffsets && buf->suboffsets[dim] >= 0)) { + PyErr_Format(PyExc_ValueError, + "Buffer not compatible with direct access " + "in dimension %d.", dim); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_PTR) { + if (unlikely(!buf->suboffsets || (buf->suboffsets[dim] < 0))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly accessible " + "in dimension %d.", dim); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_verify_contig(Py_buffer *buf, int ndim, int c_or_f_flag) +{ + int i; + if (c_or_f_flag & __Pyx_IS_F_CONTIG) { + Py_ssize_t stride = 1; + for (i = 0; i < ndim; i++) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not fortran contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } else if (c_or_f_flag & __Pyx_IS_C_CONTIG) { + Py_ssize_t stride = 1; + for (i = ndim - 1; i >- 1; i--) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not C contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } + return 1; +fail: + return 0; +} +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj) +{ + struct __pyx_memoryview_obj *memview, *new_memview; + __Pyx_RefNannyDeclarations + Py_buffer *buf; + int i, spec = 0, retval = -1; + __Pyx_BufFmt_Context ctx; + int from_memoryview = __pyx_memoryview_check(original_obj); + __Pyx_RefNannySetupContext("ValidateAndInit_memviewslice", 0); + if (from_memoryview && __pyx_typeinfo_cmp(dtype, ((struct __pyx_memoryview_obj *) + original_obj)->typeinfo)) { + memview = (struct __pyx_memoryview_obj *) original_obj; + new_memview = NULL; + } else { + memview = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + original_obj, buf_flags, 0, dtype); + new_memview = memview; + if (unlikely(!memview)) + goto fail; + } + buf = &memview->view; + if (unlikely(buf->ndim != ndim)) { + PyErr_Format(PyExc_ValueError, + "Buffer has wrong number of dimensions (expected %d, got %d)", + ndim, buf->ndim); + goto fail; + } + if (new_memview) { + __Pyx_BufFmt_Init(&ctx, stack, dtype); + if (unlikely(!__Pyx_BufFmt_CheckString(&ctx, buf->format))) goto fail; + } + if (unlikely((unsigned) buf->itemsize != dtype->size)) { + PyErr_Format(PyExc_ValueError, + "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "u byte%s) " + "does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "u byte%s)", + buf->itemsize, + (buf->itemsize > 1) ? "s" : "", + dtype->name, + dtype->size, + (dtype->size > 1) ? "s" : ""); + goto fail; + } + if (buf->len > 0) { + for (i = 0; i < ndim; i++) { + spec = axes_specs[i]; + if (unlikely(!__pyx_check_strides(buf, i, ndim, spec))) + goto fail; + if (unlikely(!__pyx_check_suboffsets(buf, i, ndim, spec))) + goto fail; + } + if (unlikely(buf->strides && !__pyx_verify_contig(buf, ndim, c_or_f_flag))) + goto fail; + } + if (unlikely(__Pyx_init_memviewslice(memview, ndim, memviewslice, + new_memview != NULL) == -1)) { + goto fail; + } + retval = 0; + goto no_fail; +fail: + Py_XDECREF(new_memview); + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_ds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 1, + &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dsds_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_STRIDED) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, 0, + PyBUF_RECORDS_RO | writable_flag, 2, + &__Pyx_TypeInfo_nn___pyx_t_7fairseq_4data_22token_block_utils_fast_DTYPE_t, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return ::std::complex< float >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return x + y*(__pyx_t_float_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + __pyx_t_float_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabsf(b.real) >= fabsf(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + float r = b.imag / b.real; + float s = (float)(1.0) / (b.real + b.imag * r); + return __pyx_t_float_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + float r = b.real / b.imag; + float s = (float)(1.0) / (b.imag + b.real * r); + return __pyx_t_float_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + float denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_float_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrtf(z.real*z.real + z.imag*z.imag); + #else + return hypotf(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + float r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + float denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_float(a, a); + case 3: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, a); + case 4: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = powf(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2f(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_float(a); + theta = atan2f(a.imag, a.real); + } + lnr = logf(r); + z_r = expf(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cosf(z_theta); + z.imag = z_r * sinf(z_theta); + return z; + } + #endif +#endif + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return ::std::complex< double >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return x + y*(__pyx_t_double_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + __pyx_t_double_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabs(b.real) >= fabs(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + double r = b.imag / b.real; + double s = (double)(1.0) / (b.real + b.imag * r); + return __pyx_t_double_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + double r = b.real / b.imag; + double s = (double)(1.0) / (b.imag + b.real * r); + return __pyx_t_double_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + double denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_double_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrt(z.real*z.real + z.imag*z.imag); + #else + return hypot(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + double r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + double denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_double(a, a); + case 3: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, a); + case 4: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = pow(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_double(a); + theta = atan2(a.imag, a.real); + } + lnr = log(r); + z_r = exp(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cos(z_theta); + z.imag = z_r * sin(z_theta); + return z; + } + #endif +#endif + +/* MemviewSliceCopyTemplate */ + static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object) +{ + __Pyx_RefNannyDeclarations + int i; + __Pyx_memviewslice new_mvs = { 0, 0, { 0 }, { 0 }, { 0 } }; + struct __pyx_memoryview_obj *from_memview = from_mvs->memview; + Py_buffer *buf = &from_memview->view; + PyObject *shape_tuple = NULL; + PyObject *temp_int = NULL; + struct __pyx_array_obj *array_obj = NULL; + struct __pyx_memoryview_obj *memview_obj = NULL; + __Pyx_RefNannySetupContext("__pyx_memoryview_copy_new_contig", 0); + for (i = 0; i < ndim; i++) { + if (unlikely(from_mvs->suboffsets[i] >= 0)) { + PyErr_Format(PyExc_ValueError, "Cannot copy memoryview slice with " + "indirect dimensions (axis %d)", i); + goto fail; + } + } + shape_tuple = PyTuple_New(ndim); + if (unlikely(!shape_tuple)) { + goto fail; + } + __Pyx_GOTREF(shape_tuple); + for(i = 0; i < ndim; i++) { + temp_int = PyInt_FromSsize_t(from_mvs->shape[i]); + if(unlikely(!temp_int)) { + goto fail; + } else { + PyTuple_SET_ITEM(shape_tuple, i, temp_int); + temp_int = NULL; + } + } + array_obj = __pyx_array_new(shape_tuple, sizeof_dtype, buf->format, (char *) mode, NULL); + if (unlikely(!array_obj)) { + goto fail; + } + __Pyx_GOTREF(array_obj); + memview_obj = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + (PyObject *) array_obj, contig_flag, + dtype_is_object, + from_mvs->memview->typeinfo); + if (unlikely(!memview_obj)) + goto fail; + if (unlikely(__Pyx_init_memviewslice(memview_obj, ndim, &new_mvs, 1) < 0)) + goto fail; + if (unlikely(__pyx_memoryview_copy_contents(*from_mvs, new_mvs, ndim, ndim, + dtype_is_object) < 0)) + goto fail; + goto no_fail; +fail: + __Pyx_XDECREF(new_mvs.memview); + new_mvs.memview = NULL; + new_mvs.data = NULL; +no_fail: + __Pyx_XDECREF(shape_tuple); + __Pyx_XDECREF(temp_int); + __Pyx_XDECREF(array_obj); + __Pyx_RefNannyFinishContext(); + return new_mvs; +} + +/* MemviewSliceInit */ + static int +__Pyx_init_memviewslice(struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference) +{ + __Pyx_RefNannyDeclarations + int i, retval=-1; + Py_buffer *buf = &memview->view; + __Pyx_RefNannySetupContext("init_memviewslice", 0); + if (unlikely(memviewslice->memview || memviewslice->data)) { + PyErr_SetString(PyExc_ValueError, + "memviewslice is already initialized!"); + goto fail; + } + if (buf->strides) { + for (i = 0; i < ndim; i++) { + memviewslice->strides[i] = buf->strides[i]; + } + } else { + Py_ssize_t stride = buf->itemsize; + for (i = ndim - 1; i >= 0; i--) { + memviewslice->strides[i] = stride; + stride *= buf->shape[i]; + } + } + for (i = 0; i < ndim; i++) { + memviewslice->shape[i] = buf->shape[i]; + if (buf->suboffsets) { + memviewslice->suboffsets[i] = buf->suboffsets[i]; + } else { + memviewslice->suboffsets[i] = -1; + } + } + memviewslice->memview = memview; + memviewslice->data = (char *)buf->buf; + if (__pyx_add_acquisition_count(memview) == 0 && !memview_is_new_reference) { + Py_INCREF(memview); + } + retval = 0; + goto no_fail; +fail: + memviewslice->memview = 0; + memviewslice->data = 0; + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} +#ifndef Py_NO_RETURN +#define Py_NO_RETURN +#endif +static void __pyx_fatalerror(const char *fmt, ...) Py_NO_RETURN { + va_list vargs; + char msg[200]; +#if PY_VERSION_HEX >= 0x030A0000 || defined(HAVE_STDARG_PROTOTYPES) + va_start(vargs, fmt); +#else + va_start(vargs); +#endif + vsnprintf(msg, 200, fmt, vargs); + va_end(vargs); + Py_FatalError(msg); +} +static CYTHON_INLINE int +__pyx_add_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)++; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE int +__pyx_sub_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)--; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE void +__Pyx_INC_MEMVIEW(__Pyx_memviewslice *memslice, int have_gil, int lineno) +{ + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + return; + } + old_acquisition_count = __pyx_add_acquisition_count(memview); + if (unlikely(old_acquisition_count <= 0)) { + if (likely(old_acquisition_count == 0)) { + if (have_gil) { + Py_INCREF((PyObject *) memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_INCREF((PyObject *) memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count+1, lineno); + } + } +} +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *memslice, + int have_gil, int lineno) { + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + memslice->memview = NULL; + return; + } + old_acquisition_count = __pyx_sub_acquisition_count(memview); + memslice->data = NULL; + if (likely(old_acquisition_count > 1)) { + memslice->memview = NULL; + } else if (likely(old_acquisition_count == 1)) { + if (have_gil) { + Py_CLEAR(memslice->memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_CLEAR(memslice->memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count-1, lineno); + } +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int64_t(int64_t value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int64_t neg_one = (int64_t) -1, const_zero = (int64_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int64_t) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int64_t) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int64_t) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int64_t) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int64_t) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(int64_t), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int64_t)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE int64_t __Pyx_PyInt_As_int64_t(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int64_t neg_one = (int64_t) -1, const_zero = (int64_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int64_t) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int64_t, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int64_t) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int64_t, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int64_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 2 * PyLong_SHIFT)) { + return (int64_t) (((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int64_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 3 * PyLong_SHIFT)) { + return (int64_t) (((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int64_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) >= 4 * PyLong_SHIFT)) { + return (int64_t) (((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int64_t) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int64_t) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int64_t) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int64_t, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int64_t) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int64_t) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + return (int64_t) ((((((int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int64_t) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int64_t) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + return (int64_t) ((((((((int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int64_t) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 4 * PyLong_SHIFT)) { + return (int64_t) (((int64_t)-1)*(((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int64_t) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int64_t, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int64_t) - 1 > 4 * PyLong_SHIFT)) { + return (int64_t) ((((((((((int64_t)digits[3]) << PyLong_SHIFT) | (int64_t)digits[2]) << PyLong_SHIFT) | (int64_t)digits[1]) << PyLong_SHIFT) | (int64_t)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int64_t) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int64_t) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int64_t, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int64_t val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (int64_t) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (int64_t) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int64_t) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (int64_t) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (int64_t) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int64_t) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((int64_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int64_t) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int64_t) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((int64_t) 1) << (sizeof(int64_t) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (int64_t) -1; + } + } else { + int64_t val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int64_t) -1; + val = __Pyx_PyInt_As_int64_t(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int64_t"); + return (int64_t) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int64_t"); + return (int64_t) -1; +} + +/* CIntFromPy */ + static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) { + return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) { + return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) { + return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (int) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (int) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (int) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (int) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((int) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (int) -1; + } + } else { + int val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int) -1; + val = __Pyx_PyInt_As_int(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int"); + return (int) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int"); + return (int) -1; +} + +/* CIntFromPy */ + static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(long) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (long) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) { + return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) { + return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) { + return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(long) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(long) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + long val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (long) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (long) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (long) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (long) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((long) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((long) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (long) -1; + } + } else { + long val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (long) -1; + val = __Pyx_PyInt_As_long(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to long"); + return (long) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to long"); + return (long) -1; +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(long) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(long) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(long) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(long), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(long)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + int one = 1; int little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + return _PyLong_FromByteArray(bytes, sizeof(int), + little, !is_unsigned); +#else + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const char neg_one = (char) -1, const_zero = (char) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(char) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(char, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (char) val; + } + } else +#endif + if (likely(PyLong_Check(x))) { + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 2 * PyLong_SHIFT)) { + return (char) (((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 3 * PyLong_SHIFT)) { + return (char) (((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 4 * PyLong_SHIFT)) { + return (char) (((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (char) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(char) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(char) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) ((((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) ((((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) ((((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(char) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + char val; + PyObject *v = __Pyx_PyNumber_IntOrLong(x); +#if PY_MAJOR_VERSION < 3 + if (likely(v) && !PyLong_Check(v)) { + PyObject *tmp = v; + v = PyNumber_Long(tmp); + Py_DECREF(tmp); + } +#endif + if (likely(v)) { + int ret = -1; +#if PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)v, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + long idigit; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (unlikely(!PyLong_CheckExact(v))) { + PyObject *tmp = v; + v = PyNumber_Long(v); + assert(PyLong_CheckExact(v)); + Py_DECREF(tmp); + if (unlikely(!v)) return (char) -1; + } +#if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(x) == 0) + return (char) 0; + is_negative = Py_SIZE(x) < 0; +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (char) -1; + is_negative = result == 1; + } +#endif + if (is_unsigned && unlikely(is_negative)) { + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + if (unlikely(!stepval)) + return (char) -1; + } else { + stepval = __Pyx_NewRef(v); + } + val = (char) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(char) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + val |= ((char) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + if (Py_SIZE(stepval) == 0) + goto unpacking_done; + #endif + } + idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(char) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((char) idigit) << bits; + #if CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030B0000 + unpacking_done: + #endif + if (!is_unsigned) { + if (unlikely(val & (((char) 1) << (sizeof(char) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + Py_DECREF(v); + if (likely(!ret)) + return val; + } + return (char) -1; + } + } else { + char val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (char) -1; + val = __Pyx_PyInt_As_char(tmp); + Py_DECREF(tmp); + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to char"); + return (char) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to char"); + return (char) -1; +} + +/* FormatTypeName */ + #if CYTHON_COMPILING_IN_LIMITED_API +static __Pyx_TypeName +__Pyx_PyType_GetName(PyTypeObject* tp) +{ + PyObject *name = __Pyx_PyObject_GetAttrStr((PyObject *)tp, + __pyx_n_s_name_2); + if (unlikely(name == NULL) || unlikely(!PyUnicode_Check(name))) { + PyErr_Clear(); + Py_XDECREF(name); + name = __Pyx_NewRef(__pyx_n_s__35); + } + return name; +} +#endif + +/* CheckBinaryVersion */ + static unsigned long __Pyx_get_runtime_version(void) { +#if __PYX_LIMITED_VERSION_HEX >= 0x030B00A4 + return Py_Version & ~0xFFUL; +#else + const char* rt_version = Py_GetVersion(); + unsigned long version = 0; + unsigned long factor = 0x01000000UL; + unsigned int digit = 0; + int i = 0; + while (factor) { + while ('0' <= rt_version[i] && rt_version[i] <= '9') { + digit = digit * 10 + (unsigned int) (rt_version[i] - '0'); + ++i; + } + version += factor * digit; + if (rt_version[i] != '.') + break; + digit = 0; + factor >>= 8; + ++i; + } + return version; +#endif +} +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer) { + const unsigned long MAJOR_MINOR = 0xFFFF0000UL; + if ((rt_version & MAJOR_MINOR) == (ct_version & MAJOR_MINOR)) + return 0; + if (likely(allow_newer && (rt_version & MAJOR_MINOR) > (ct_version & MAJOR_MINOR))) + return 1; + { + char message[200]; + PyOS_snprintf(message, sizeof(message), + "compile time Python version %d.%d " + "of module '%.100s' " + "%s " + "runtime version %d.%d", + (int) (ct_version >> 24), (int) ((ct_version >> 16) & 0xFF), + __Pyx_MODULE_NAME, + (allow_newer) ? "was newer than" : "does not match", + (int) (rt_version >> 24), (int) ((rt_version >> 16) & 0xFF) + ); + return PyErr_WarnEx(NULL, message, 1); + } +} + +/* InitStrings */ + #if PY_MAJOR_VERSION >= 3 +static int __Pyx_InitString(__Pyx_StringTabEntry t, PyObject **str) { + if (t.is_unicode | t.is_str) { + if (t.intern) { + *str = PyUnicode_InternFromString(t.s); + } else if (t.encoding) { + *str = PyUnicode_Decode(t.s, t.n - 1, t.encoding, NULL); + } else { + *str = PyUnicode_FromStringAndSize(t.s, t.n - 1); + } + } else { + *str = PyBytes_FromStringAndSize(t.s, t.n - 1); + } + if (!*str) + return -1; + if (PyObject_Hash(*str) == -1) + return -1; + return 0; +} +#endif +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t) { + while (t->p) { + #if PY_MAJOR_VERSION >= 3 + __Pyx_InitString(*t, t->p); + #else + if (t->is_unicode) { + *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL); + } else if (t->intern) { + *t->p = PyString_InternFromString(t->s); + } else { + *t->p = PyString_FromStringAndSize(t->s, t->n - 1); + } + if (!*t->p) + return -1; + if (PyObject_Hash(*t->p) == -1) + return -1; + #endif + ++t; + } + return 0; +} + +#include +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s) { + size_t len = strlen(s); + if (unlikely(len > (size_t) PY_SSIZE_T_MAX)) { + PyErr_SetString(PyExc_OverflowError, "byte string is too long"); + return -1; + } + return (Py_ssize_t) len; +} +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return __Pyx_PyUnicode_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return PyByteArray_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) { + Py_ssize_t ignore; + return __Pyx_PyObject_AsStringAndSize(o, &ignore); +} +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#if !CYTHON_PEP393_ENABLED +static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + char* defenc_c; + PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL); + if (!defenc) return NULL; + defenc_c = PyBytes_AS_STRING(defenc); +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + { + char* end = defenc_c + PyBytes_GET_SIZE(defenc); + char* c; + for (c = defenc_c; c < end; c++) { + if ((unsigned char) (*c) >= 128) { + PyUnicode_AsASCIIString(o); + return NULL; + } + } + } +#endif + *length = PyBytes_GET_SIZE(defenc); + return defenc_c; +} +#else +static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL; +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + if (likely(PyUnicode_IS_ASCII(o))) { + *length = PyUnicode_GET_LENGTH(o); + return PyUnicode_AsUTF8(o); + } else { + PyUnicode_AsASCIIString(o); + return NULL; + } +#else + return PyUnicode_AsUTF8AndSize(o, length); +#endif +} +#endif +#endif +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) { +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT + if ( +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + __Pyx_sys_getdefaultencoding_not_ascii && +#endif + PyUnicode_Check(o)) { + return __Pyx_PyUnicode_AsStringAndSize(o, length); + } else +#endif +#if (!CYTHON_COMPILING_IN_PYPY && !CYTHON_COMPILING_IN_LIMITED_API) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE)) + if (PyByteArray_Check(o)) { + *length = PyByteArray_GET_SIZE(o); + return PyByteArray_AS_STRING(o); + } else +#endif + { + char* result; + int r = PyBytes_AsStringAndSize(o, &result, length); + if (unlikely(r < 0)) { + return NULL; + } else { + return result; + } + } +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) { + int is_true = x == Py_True; + if (is_true | (x == Py_False) | (x == Py_None)) return is_true; + else return PyObject_IsTrue(x); +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject* x) { + int retval; + if (unlikely(!x)) return -1; + retval = __Pyx_PyObject_IsTrue(x); + Py_DECREF(x); + return retval; +} +static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) { + __Pyx_TypeName result_type_name = __Pyx_PyType_GetName(Py_TYPE(result)); +#if PY_MAJOR_VERSION >= 3 + if (PyLong_Check(result)) { + if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, + "__int__ returned non-int (type " __Pyx_FMT_TYPENAME "). " + "The ability to return an instance of a strict subclass of int is deprecated, " + "and may be removed in a future version of Python.", + result_type_name)) { + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; + } + __Pyx_DECREF_TypeName(result_type_name); + return result; + } +#endif + PyErr_Format(PyExc_TypeError, + "__%.4s__ returned non-%.4s (type " __Pyx_FMT_TYPENAME ")", + type_name, type_name, result_type_name); + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) { +#if CYTHON_USE_TYPE_SLOTS + PyNumberMethods *m; +#endif + const char *name = NULL; + PyObject *res = NULL; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x) || PyLong_Check(x))) +#else + if (likely(PyLong_Check(x))) +#endif + return __Pyx_NewRef(x); +#if CYTHON_USE_TYPE_SLOTS + m = Py_TYPE(x)->tp_as_number; + #if PY_MAJOR_VERSION < 3 + if (m && m->nb_int) { + name = "int"; + res = m->nb_int(x); + } + else if (m && m->nb_long) { + name = "long"; + res = m->nb_long(x); + } + #else + if (likely(m && m->nb_int)) { + name = "int"; + res = m->nb_int(x); + } + #endif +#else + if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) { + res = PyNumber_Int(x); + } +#endif + if (likely(res)) { +#if PY_MAJOR_VERSION < 3 + if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) { +#else + if (unlikely(!PyLong_CheckExact(res))) { +#endif + return __Pyx_PyNumber_IntOrLongWrongResultType(res, name); + } + } + else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "an integer is required"); + } + return res; +} +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) { + Py_ssize_t ival; + PyObject *x; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_CheckExact(b))) { + if (sizeof(Py_ssize_t) >= sizeof(long)) + return PyInt_AS_LONG(b); + else + return PyInt_AsSsize_t(b); + } +#endif + if (likely(PyLong_CheckExact(b))) { + #if CYTHON_USE_PYLONG_INTERNALS + if (likely(__Pyx_PyLong_IsCompact(b))) { + return __Pyx_PyLong_CompactValue(b); + } else { + const digit* digits = __Pyx_PyLong_Digits(b); + const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(b); + switch (size) { + case 2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + } + } + #endif + return PyLong_AsSsize_t(b); + } + x = PyNumber_Index(b); + if (!x) return -1; + ival = PyInt_AsSsize_t(x); + Py_DECREF(x); + return ival; +} +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject* o) { + if (sizeof(Py_hash_t) == sizeof(Py_ssize_t)) { + return (Py_hash_t) __Pyx_PyIndex_AsSsize_t(o); +#if PY_MAJOR_VERSION < 3 + } else if (likely(PyInt_CheckExact(o))) { + return PyInt_AS_LONG(o); +#endif + } else { + Py_ssize_t ival; + PyObject *x; + x = PyNumber_Index(o); + if (!x) return -1; + ival = PyInt_AsLong(x); + Py_DECREF(x); + return ival; + } +} +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) { + return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False); +} +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { + return PyInt_FromSize_t(ival); +} + + +/* #### Code section: utility_code_pragmas_end ### */ +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + + + +/* #### Code section: end ### */ +#endif /* Py_PYTHON_H */ diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx new file mode 100644 index 0000000000000000000000000000000000000000..08af4f30613a7b6ffa965a7c7084acabec8f8749 --- /dev/null +++ b/fairseq/data/token_block_utils_fast.pyx @@ -0,0 +1,187 @@ +# cython: language_level=3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from itertools import chain +from libc.math cimport ceil + +cimport cython +cimport numpy as np + +from libc.stdint cimport int32_t, int64_t + +DTYPE = np.int64 +ctypedef int64_t DTYPE_t + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_none_mode(np.ndarray[DTYPE_t, ndim=1] sizes, int block_size): + cdef DTYPE_t total_size = sizes.sum() + cdef DTYPE_t length = ceil(total_size / block_size) + cdef np.ndarray[DTYPE_t, ndim=2] slice_indices = np.zeros([length, 2], dtype=DTYPE) + cdef DTYPE_t[:, :] slice_indices_view = slice_indices + cdef DTYPE_t i + cdef DTYPE_t start + cdef DTYPE_t end + for i in range(length): + start = i * block_size + end = min(start + block_size, total_size) + slice_indices_view[i][0] = start + slice_indices_view[i][1] = end + return slice_indices + + +cdef np.ndarray[DTYPE_t, ndim=2] _fast_convert_to_np_array(list list_of_list): + """ + Faster function to convert DTYPE_t list of list. + Only fast when there are huge number of rows and low number of columns. + """ + cdef np.ndarray[DTYPE_t, ndim=1] flat = np.fromiter(chain.from_iterable(list_of_list), DTYPE, -1) + return flat.reshape((len(list_of_list), -1)) + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cpdef np.ndarray[DTYPE_t, ndim=2] _get_slice_indices_fast(np.ndarray[DTYPE_t, ndim=1] sizes, str break_mode, int block_size, int document_sep_len): + cdef DTYPE_t tok_idx = 0 + cdef DTYPE_t sz_idx = 0 + cdef DTYPE_t curr_size = 0 + cdef DTYPE_t i = 0 + cdef DTYPE_t length + cdef DTYPE_t total_size + cdef DTYPE_t[:] sizes_view = sizes + cdef np.ndarray[DTYPE_t, ndim=2] slice_indices + cdef list slice_indices_list = [] + + if break_mode is None or break_mode == 'none': + slice_indices = _get_slice_indices_none_mode(sizes, block_size) + elif break_mode == 'complete': + while sz_idx < len(sizes_view): + if curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0: + curr_size += sizes_view[sz_idx] + sz_idx += 1 + else: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + tok_idx += curr_size + curr_size = 0 + if curr_size > 0: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + slice_indices = _fast_convert_to_np_array(slice_indices_list) + elif break_mode == 'complete_doc': + while sz_idx < len(sizes_view): + if ( + (curr_size + sizes_view[sz_idx] <= block_size or curr_size == 0) + # an empty sentence indicates end-of-document: + and sizes_view[sz_idx] != document_sep_len + ): + curr_size += sizes_view[sz_idx] + sz_idx += 1 + else: + # Only keep non-empty documents. + if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + tok_idx += curr_size + curr_size = 0 + if sizes_view[sz_idx] == document_sep_len: + tok_idx += sizes_view[sz_idx] + sz_idx += 1 + if curr_size > 1: + slice_indices_list.append((tok_idx, tok_idx + curr_size)) + slice_indices = _fast_convert_to_np_array(slice_indices_list) + elif break_mode == 'eos': + slice_indices = np.zeros((len(sizes), 2), dtype=DTYPE) + cumsum = sizes.cumsum(axis=0) + slice_indices[1:, 0] = cumsum[:cumsum.shape[0] - 1] + slice_indices[:, 1] = cumsum + else: + raise ValueError('Invalid break_mode: ' + break_mode) + return slice_indices + + +@cython.boundscheck(False) +@cython.wraparound(False) +@cython.nonecheck(False) +cpdef np.ndarray[DTYPE_t, ndim=2] _get_block_to_dataset_index_fast(np.ndarray[DTYPE_t, ndim=1] sizes, np.ndarray[DTYPE_t, ndim=2] slice_indices): + cdef DTYPE_t start_ds_idx + cdef DTYPE_t start_offset + cdef DTYPE_t end_ds_idx + cdef DTYPE_t i + cdef DTYPE_t s + cdef DTYPE_t e + cdef DatasetSearcher ds = DatasetSearcher(sizes) + cdef np.ndarray[DTYPE_t, ndim=2] block_to_dataset_index = np.zeros([len(slice_indices), 3], dtype=DTYPE) + cdef DTYPE_t[:, :] block_to_dataset_index_view = block_to_dataset_index + cdef DTYPE_t[:, :] slice_indices_view = slice_indices + cdef Py_ssize_t x_max = slice_indices.shape[0] + + for i in range(x_max): + s = slice_indices_view[i][0] + e = slice_indices_view[i][1] + ds.seek(s) + start_ds_idx = ds.current_index + start_offset = ds.current_offset + if e <= s: + end_ds_idx = start_ds_idx + else: + ds.seek(e - 1) + end_ds_idx = ds.current_index + block_to_dataset_index_view[i][0] = start_ds_idx # starting index in dataset + block_to_dataset_index_view[i][1] = start_offset # starting offset within starting index + block_to_dataset_index_view[i][2] = end_ds_idx # ending index in dataset + return block_to_dataset_index + + +cdef class DatasetSearcher(object): + """Helper for mapping "flat" indices to indices and offsets in an + underlying dataset.""" + cdef DTYPE_t current_i + cdef DTYPE_t current_offset + cdef DTYPE_t current_index + cdef DTYPE_t[:] sizes + + def __init__(self, DTYPE_t[:] sizes): + self.sizes = sizes + self.reset() + + cdef reset(self): + self.current_offset = 0 # offset within current index in underlying dataset + self.current_i = 0 # "flat" index + self.current_index = 0 # index in underlying dataset + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef int step(self, DTYPE_t i): + cdef DTYPE_t to_consume + cdef DTYPE_t remaining + if i < self.current_i: + self.reset() + if i > self.current_i: + to_consume = i - self.current_i + remaining = self.sizes[self.current_index] - self.current_offset + if remaining > to_consume: + self.current_offset += to_consume + self.current_i += to_consume + else: + assert remaining >= 0 + self.current_i += remaining + self.current_index += 1 + self.current_offset = 0 + return 1 + return 0 + + @cython.boundscheck(False) + @cython.wraparound(False) + @cython.nonecheck(False) + cdef seek(self, DTYPE_t i): + cdef int not_done = 1 + while not_done == 1: + not_done = self.step(i) + assert self.current_i == i diff --git a/fairseq/data/transform_eos_concat_langpair_dataset.py b/fairseq/data/transform_eos_concat_langpair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..effa127d50c63546c7eeac952053930dd0a4f2b1 --- /dev/null +++ b/fairseq/data/transform_eos_concat_langpair_dataset.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torch.utils.data.dataloader import default_collate + +from fairseq.data import ConcatDataset + +logger = logging.getLogger(__name__) + + +class TransformEosConcatLangPairDataset(ConcatDataset): + """ + It is a combination of TransformEosLangPairDataset and ConcatDataset for multiple LangPairDataset datasets. + Assume all datasets share the same src_eos, tgt_bos, left_pad_source and left_pad_target + """ + + def __init__( + self, + datasets, + src_eos, + tgt_bos, + new_src_eos=None, + new_tgt_bos=None, + ): + super().__init__(datasets) + if new_src_eos is not None and new_src_eos != []: + assert len(new_src_eos) == len(datasets) + else: + new_src_eos = [] + if new_tgt_bos is not None and new_tgt_bos != []: + assert len(new_tgt_bos) == len(datasets) + else: + new_tgt_bos = [] + self.src_eos = src_eos + self.tgt_bos = tgt_bos + self.new_src_eos = ( + torch.LongTensor(new_src_eos).cpu() if len(new_src_eos) > 0 else [] + ) + self.new_tgt_bos = ( + torch.LongTensor(new_tgt_bos).cpu() if len(new_tgt_bos) > 0 else [] + ) + self.left_pad_source = self.is_left_pad_source(datasets) + self.left_pad_target = self.is_left_pad_target(datasets) + self.pad_idx = self.src_dict_pad() + + def src_dict_pad(self): + if hasattr(self.datasets[0], "src_dict"): + return self.datasets[0].src_dict.pad() + if hasattr(self.datasets[0], "dataset"): + return self.datasets[0].dataset.src_dict.pad() + raise NotImplementedError("No src_dict is found") + + def __getitem__(self, idx): + dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx) + return dataset_idx, self.datasets[dataset_idx][sample_idx] + + def is_left_pad_source(self, datasets): + def _left_pad_source(ds): + if hasattr(ds, "left_pad_source"): + return ds.left_pad_source + if hasattr(ds, "dataset"): + return _left_pad_source(ds.dataset) + logger.warn(f"{type(ds)} has no left_pad_source, using default True") + return True + + left_pad_source = _left_pad_source(datasets[0]) + for ds in datasets: + if left_pad_source != _left_pad_source(ds): + raise ValueError("Different left_pad_source setting detected!") + return left_pad_source + + def is_left_pad_target(self, datasets): + def _left_pad_target(ds): + if hasattr(ds, "left_pad_target"): + return ds.left_pad_target + if hasattr(ds, "dataset"): + return _left_pad_target(ds.dataset) + logger.warn(f"{type(ds)} has no left_pad_target, using default False") + return False + + left_pad_target = _left_pad_target(datasets[0]) + for ds in datasets: + if left_pad_target != _left_pad_target(ds): + raise ValueError("Different left_pad_target setting detected!") + return left_pad_target + + def collater(self, samples, **extra_args): + if len(samples) == 0: + return samples + + dataset_ids = [s[0] for s in samples] + samples = [s[1] for s in samples] + + if hasattr(self.datasets[0], "collater"): + samples = self.datasets[0].collater(samples, **extra_args) + else: + samples = default_collate(samples, **extra_args) + + if len(self.new_src_eos) > 0: + if self.left_pad_source: + assert ( + samples["net_input"]["src_tokens"][:, -1] != self.src_eos + ).sum() == 0 + samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos[ + dataset_ids + ] + + else: + eos_idx = samples["net_input"]["src_lengths"] - 1 + assert ( + samples["net_input"]["src_tokens"][ + torch.arange(eos_idx.size(0)), eos_idx + ] + != self.src_eos + ).sum() == 0 + samples["net_input"]["src_tokens"].scatter_( + 1, eos_idx.view(-1, 1), self.new_src_eos[dataset_ids].view(-1, 1) + ) + + if len(self.new_tgt_bos) > 0 and "prev_output_tokens" in samples["net_input"]: + if self.left_pad_target: + # TODO: support different padding direction on target side + raise NotImplementedError( + "TransformEosLangPairDataset does not implement --left-pad-target True option" + ) + else: + assert ( + samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos + ).sum() == 0 + samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos[ + dataset_ids + ] + + return samples diff --git a/fairseq/data/transform_eos_dataset.py b/fairseq/data/transform_eos_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fb14ff018edf13b20f5d0e486692dfb0a37ec6d1 --- /dev/null +++ b/fairseq/data/transform_eos_dataset.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from . import FairseqDataset + + +class TransformEosDataset(FairseqDataset): + """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS. + + Note that the transformation is applied in :func:`collater`. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to wrap + eos (int): index of the end-of-sentence symbol + append_eos_to_src (bool, optional): append EOS to the end of src + remove_eos_from_src (bool, optional): remove EOS from the end of src + append_eos_to_tgt (bool, optional): append EOS to the end of tgt + remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt + """ + + def __init__( + self, + dataset, + eos, + append_eos_to_src=False, + remove_eos_from_src=False, + append_eos_to_tgt=False, + remove_eos_from_tgt=False, + has_target=True, + ): + if not isinstance(dataset, FairseqDataset): + raise ValueError("dataset must be an instance of FairseqDataset") + if append_eos_to_src and remove_eos_from_src: + raise ValueError("cannot combine append_eos_to_src and remove_eos_from_src") + if append_eos_to_tgt and remove_eos_from_tgt: + raise ValueError("cannot combine append_eos_to_tgt and remove_eos_from_tgt") + + self.dataset = dataset + self.eos = torch.LongTensor([eos]) + self.append_eos_to_src = append_eos_to_src + self.remove_eos_from_src = remove_eos_from_src + self.append_eos_to_tgt = append_eos_to_tgt + self.remove_eos_from_tgt = remove_eos_from_tgt + self.has_target = has_target + + # precompute how we should adjust the reported sizes + self._src_delta = 0 + self._src_delta += 1 if append_eos_to_src else 0 + self._src_delta -= 1 if remove_eos_from_src else 0 + self._tgt_delta = 0 + self._tgt_delta += 1 if append_eos_to_tgt else 0 + self._tgt_delta -= 1 if remove_eos_from_tgt else 0 + + self._checked_src = False + self._checked_tgt = False + + def _check_src(self, src, expect_eos): + if not self._checked_src: + assert (src[-1] == self.eos[0]) == expect_eos + self._checked_src = True + + def _check_tgt(self, tgt, expect_eos): + if self.has_target and not self._checked_tgt: + assert (tgt[-1] == self.eos[0]) == expect_eos + self._checked_tgt = True + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + def transform(item): + if self.append_eos_to_src: + self.eos = self.eos.to(device=item["source"].device) + self._check_src(item["source"], expect_eos=False) + item["source"] = torch.cat([item["source"], self.eos]) + if self.remove_eos_from_src: + self.eos = self.eos.to(device=item["source"].device) + self._check_src(item["source"], expect_eos=True) + item["source"] = item["source"][:-1] + if self.append_eos_to_tgt: + self.eos = self.eos.to(device=item["target"].device) + self._check_tgt(item["target"], expect_eos=False) + item["target"] = torch.cat([item["target"], self.eos]) + if self.remove_eos_from_tgt: + self.eos = self.eos.to(device=item["target"].device) + self._check_tgt(item["target"], expect_eos=True) + item["target"] = item["target"][:-1] + return item + + samples = list(map(transform, samples)) + return self.dataset.collater(samples) + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + if self.has_target: + src_len, tgt_len = self.dataset.size(index) + return (src_len + self._src_delta, tgt_len + self._tgt_delta) + else: + return self.dataset.size(index) + + def ordered_indices(self): + # NOTE: we assume that the ordering does not change based on the + # addition or removal of eos + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b21090144bfc975d4d5a3ee2c21b2e8acde03d --- /dev/null +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import torch + +from . import FairseqDataset + + +class TransformEosLangPairDataset(FairseqDataset): + """A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on + collated samples of language pair dataset. + + Note that the transformation is applied in :func:`collater`. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset that collates sample into + LanguagePairDataset schema + src_eos (int): original source end-of-sentence symbol index to be replaced + new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol + tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced + new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the + beginning of 'prev_output_tokens' + """ + + def __init__( + self, + dataset: FairseqDataset, + src_eos: int, + new_src_eos: Optional[int] = None, + tgt_bos: Optional[int] = None, + new_tgt_bos: Optional[int] = None, + ): + self.dataset = dataset + self.src_eos = src_eos + self.new_src_eos = new_src_eos + self.tgt_bos = tgt_bos + self.new_tgt_bos = new_tgt_bos + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples, **extra_args): + samples = self.dataset.collater(samples, **extra_args) + if len(samples) == 0: + return samples + + if "net_input" not in samples: + return samples + + if self.new_src_eos is not None: + if self.dataset.left_pad_source: + assert ( + samples["net_input"]["src_tokens"][:, -1] != self.src_eos + ).sum() == 0 + samples["net_input"]["src_tokens"][:, -1] = self.new_src_eos + else: + eos_idx = samples["net_input"]["src_lengths"] - 1 + assert ( + samples["net_input"]["src_tokens"][ + torch.arange(eos_idx.size(0)), eos_idx + ] + != self.src_eos + ).sum() == 0 + eos_idx = eos_idx.resize_(len(samples["net_input"]["src_lengths"]), 1) + samples["net_input"]["src_tokens"].scatter_( + 1, eos_idx, self.new_src_eos + ) + + if ( + self.new_tgt_bos is not None + and "prev_output_tokens" in samples["net_input"] + ): + if self.dataset.left_pad_target: + # TODO: support different padding direction on target side + raise NotImplementedError( + "TransformEosLangPairDataset does not implement --left-pad-target True option" + ) + else: + assert ( + samples["net_input"]["prev_output_tokens"][:, 0] != self.tgt_bos + ).sum() == 0 + samples["net_input"]["prev_output_tokens"][:, 0] = self.new_tgt_bos + + return samples + + def num_tokens(self, index): + return self.dataset.num_tokens(index) + + def size(self, index): + return self.dataset.size(index) + + @property + def sizes(self): + # dataset.sizes can be a dynamically computed sizes: + return self.dataset.sizes + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def prefetch(self, indices): + return self.dataset.prefetch(indices) diff --git a/fairseq/dataclass/__init__.py b/fairseq/dataclass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25408d28ec44cee56eb5fb3ab0c817dc04159e95 --- /dev/null +++ b/fairseq/dataclass/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .configs import FairseqDataclass +from .constants import ChoiceEnum + + +__all__ = [ + "FairseqDataclass", + "ChoiceEnum", +] diff --git a/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc b/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b8a9504c66bc8ffe270ebbcb1794a1aaf96d3c Binary files /dev/null and b/fairseq/dataclass/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/dataclass/__pycache__/__init__.cpython-311.pyc b/fairseq/dataclass/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a0816e9ff4faa654934bc7edaf05c5aae2967cd Binary files /dev/null and b/fairseq/dataclass/__pycache__/__init__.cpython-311.pyc differ diff --git a/fairseq/dataclass/__pycache__/configs.cpython-310.pyc b/fairseq/dataclass/__pycache__/configs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9bc2254f51ed5e81955692ca875d8629ee043b2 Binary files /dev/null and b/fairseq/dataclass/__pycache__/configs.cpython-310.pyc differ diff --git a/fairseq/dataclass/__pycache__/configs.cpython-311.pyc b/fairseq/dataclass/__pycache__/configs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a0df00baeec0cbf3dfde8459c1db31d63a0d69 Binary files /dev/null and b/fairseq/dataclass/__pycache__/configs.cpython-311.pyc differ diff --git a/fairseq/dataclass/__pycache__/constants.cpython-310.pyc b/fairseq/dataclass/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6a100a4b6d496516f90bc3bffcdb1c3792b8f82 Binary files /dev/null and b/fairseq/dataclass/__pycache__/constants.cpython-310.pyc differ diff --git a/fairseq/dataclass/__pycache__/constants.cpython-311.pyc b/fairseq/dataclass/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bfda252c203e7a4e53f63a7e6c973f12ae3584d Binary files /dev/null and b/fairseq/dataclass/__pycache__/constants.cpython-311.pyc differ diff --git a/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc b/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..640301859e48f64a2ce4794220b2cc67fc87da0d Binary files /dev/null and b/fairseq/dataclass/__pycache__/initialize.cpython-310.pyc differ diff --git a/fairseq/dataclass/__pycache__/utils.cpython-310.pyc b/fairseq/dataclass/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a380dc27c76890736b0af30c2347437410cb2c5 Binary files /dev/null and b/fairseq/dataclass/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..af957fec64711c697da6840969da305e412783df --- /dev/null +++ b/fairseq/dataclass/configs.py @@ -0,0 +1,1147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +from dataclasses import _MISSING_TYPE, dataclass, field +from typing import Any, List, Optional + +import torch +from omegaconf import II, MISSING + +from fairseq.dataclass.constants import ( + DATASET_IMPL_CHOICES, + DDP_BACKEND_CHOICES, + DDP_COMM_HOOK_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, + LOG_FORMAT_CHOICES, + PIPELINE_CHECKPOINT_CHOICES, + PRINT_ALIGNMENT_CHOICES, + ZERO_SHARDING_CHOICES, +) + + +@dataclass +class FairseqDataclass: + """fairseq base dataclass that supported fetching attributes and metas""" + + _name: Optional[str] = None + + @staticmethod + def name(): + return None + + def _get_all_attributes(self) -> List[str]: + return [k for k in self.__dataclass_fields__.keys()] + + def _get_meta( + self, attribute_name: str, meta: str, default: Optional[Any] = None + ) -> Any: + return self.__dataclass_fields__[attribute_name].metadata.get(meta, default) + + def _get_name(self, attribute_name: str) -> str: + return self.__dataclass_fields__[attribute_name].name + + def _get_default(self, attribute_name: str) -> Any: + if hasattr(self, attribute_name): + if str(getattr(self, attribute_name)).startswith("${"): + return str(getattr(self, attribute_name)) + elif str(self.__dataclass_fields__[attribute_name].default).startswith( + "${" + ): + return str(self.__dataclass_fields__[attribute_name].default) + elif ( + getattr(self, attribute_name) + != self.__dataclass_fields__[attribute_name].default + ): + return getattr(self, attribute_name) + + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + def _get_type(self, attribute_name: str) -> Any: + return self.__dataclass_fields__[attribute_name].type + + def _get_help(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "help") + + def _get_argparse_const(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_const") + + def _get_argparse_alias(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "argparse_alias") + + def _get_choices(self, attribute_name: str) -> Any: + return self._get_meta(attribute_name, "choices") + + @classmethod + def from_namespace(cls, args): + if isinstance(args, cls): + return args + else: + config = cls() + for k in config.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(config, k, getattr(args, k)) + + return config + + +@dataclass +class CommonConfig(FairseqDataclass): + # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were + # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. + no_progress_bar: bool = field( + default=False, metadata={"help": "disable progress bar"} + ) + log_interval: int = field( + default=100, + metadata={ + "help": "log progress every N batches (when progress bar is disabled)" + }, + ) + log_format: Optional[LOG_FORMAT_CHOICES] = field( + default=None, metadata={"help": "log format to use"} + ) + log_file: Optional[str] = field( + default=None, metadata={"help": "log file to copy metrics to."} + ) + aim_repo: Optional[str] = field( + default=None, + metadata={"help": "path to Aim repository"}, + ) + aim_run_hash: Optional[str] = field( + default=None, + metadata={ + "help": "Aim run hash. If skipped, creates or continues run " + "based on save_dir" + }, + ) + tensorboard_logdir: Optional[str] = field( + default=None, + metadata={ + "help": "path to save logs for tensorboard, should match --logdir " + "of running tensorboard (default: no tensorboard logging)" + }, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": "Weights and Biases project name to use for logging"}, + ) + azureml_logging: Optional[bool] = field( + default=False, + metadata={"help": "Log scalars to AzureML context"}, + ) + seed: int = field( + default=1, metadata={"help": "pseudo random number generator seed"} + ) + cpu: bool = field(default=False, metadata={"help": "use CPU instead of CUDA"}) + tpu: bool = field(default=False, metadata={"help": "use TPU instead of CUDA"}) + bf16: bool = field(default=False, metadata={"help": "use bfloat16; implies --tpu"}) + memory_efficient_bf16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of BF16 training; implies --bf16" + }, + ) + fp16: bool = field(default=False, metadata={"help": "use FP16"}) + memory_efficient_fp16: bool = field( + default=False, + metadata={ + "help": "use a memory-efficient version of FP16 training; implies --fp16" + }, + ) + fp16_no_flatten_grads: bool = field( + default=False, metadata={"help": "don't flatten FP16 grads tensor"} + ) + fp16_init_scale: int = field( + default=2**7, metadata={"help": "default FP16 loss scale"} + ) + fp16_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing loss scale"}, + ) + fp16_scale_tolerance: float = field( + default=0.0, + metadata={ + "help": "pct of updates that can overflow before decreasing the loss scale" + }, + ) + on_cpu_convert_precision: bool = field( + default=False, + metadata={ + "help": "if set, the floating point conversion to fp16/bf16 runs on CPU. " + "This reduces bus transfer time and GPU memory usage." + }, + ) + min_loss_scale: float = field( + default=1e-4, + metadata={ + "help": "minimum FP16/AMP loss scale, after which training is stopped" + }, + ) + threshold_loss_scale: Optional[float] = field( + default=None, metadata={"help": "threshold FP16 loss scale from below"} + ) + amp: bool = field(default=False, metadata={"help": "use automatic mixed precision"}) + amp_batch_retries: int = field( + default=2, + metadata={ + "help": "number of retries of same batch after reducing loss scale with AMP" + }, + ) + amp_init_scale: int = field( + default=2**7, metadata={"help": "default AMP loss scale"} + ) + amp_scale_window: Optional[int] = field( + default=None, + metadata={"help": "number of updates before increasing AMP loss scale"}, + ) + user_dir: Optional[str] = field( + default=None, + metadata={ + "help": "path to a python module containing custom extensions (tasks and/or architectures)" + }, + ) + empty_cache_freq: int = field( + default=0, + metadata={"help": "how often to clear the PyTorch CUDA cache (0 to disable)"}, + ) + all_gather_list_size: int = field( + default=16384, + metadata={"help": "number of bytes reserved for gathering stats from workers"}, + ) + model_parallel_size: int = field( + default=1, metadata={"help": "total number of GPUs to parallelize model over"} + ) + quantization_config_path: Optional[str] = field( + default=None, metadata={"help": "path to quantization config file"} + ) + profile: bool = field( + default=False, metadata={"help": "enable autograd profiler emit_nvtx"} + ) + reset_logging: bool = field( + default=False, + metadata={ + "help": "when using Hydra, reset the logging at the beginning of training" + }, + ) + suppress_crashes: bool = field( + default=False, + metadata={ + "help": "suppress crashes when training with the hydra_train entry point so that the " + "main method can return a value (useful for sweeps)" + }, + ) + use_plasma_view: bool = field( + default=False, metadata={"help": "Store indices and sizes in shared memory"} + ) + plasma_path: Optional[str] = field( + default="/tmp/plasma", + metadata={ + "help": "path to run plasma_store, defaults to /tmp/plasma. Paths outside /tmp tend to fail." + }, + ) + + +@dataclass +class DistributedTrainingConfig(FairseqDataclass): + distributed_world_size: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of GPUs across all nodes (default: all visible GPUs)" + }, + ) + distributed_num_procs: Optional[int] = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "total number of processes to fork (default: all visible GPUs)" + }, + ) + distributed_rank: Optional[int] = field( + default=0, metadata={"help": "rank of the current worker"} + ) + distributed_backend: str = field( + default="nccl", metadata={"help": "distributed backend"} + ) + distributed_init_method: Optional[str] = field( + default=None, + metadata={ + "help": "typically tcp://hostname:port that will be used to " + "establish initial connetion" + }, + ) + distributed_port: int = field( + default=-1, + metadata={ + "help": "port number (not required if using --distributed-init-method)" + }, + ) + device_id: int = field( + default=os.getenv("LOCAL_RANK", 0), + metadata={ + "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)", + "argparse_alias": "--local_rank", + }, + ) + distributed_no_spawn: bool = field( + default=False, + metadata={ + "help": "do not spawn multiple processes even if multiple GPUs are visible" + }, + ) + ddp_backend: DDP_BACKEND_CHOICES = field( + default="pytorch_ddp", metadata={"help": "DistributedDataParallel backend"} + ) + ddp_comm_hook: DDP_COMM_HOOK_CHOICES = field( + default="none", metadata={"help": "communication hook"} + ) + bucket_cap_mb: int = field( + default=25, metadata={"help": "bucket size for reduction"} + ) + fix_batches_to_gpus: bool = field( + default=False, + metadata={ + "help": "don't shuffle batches between GPUs; this reduces overall " + "randomness and may affect precision but avoids the cost of re-reading the data" + }, + ) + find_unused_parameters: bool = field( + default=False, + metadata={ + "help": "disable unused parameter detection (not applicable to " + "--ddp-backend=legacy_ddp)" + }, + ) + gradient_as_bucket_view: bool = field( + default=False, + metadata={ + "help": "when set to True, gradients will be views pointing to different offsets of allreduce communication buckets. This can reduce peak memory usage, where the saved memory size will be equal to the total gradients size. " + "--gradient-as-bucket-view=gradient_as_bucket_view)" + }, + ) + fast_stat_sync: bool = field( + default=False, + metadata={"help": "[deprecated] this is now defined per Criterion"}, + ) + heartbeat_timeout: int = field( + default=-1, + metadata={ + "help": "kill the job if no progress is made in N seconds; " + "set to -1 to disable" + }, + ) + broadcast_buffers: bool = field( + default=False, + metadata={ + "help": "Copy non-trainable parameters between GPUs, such as " + "batchnorm population statistics" + }, + ) + slowmo_momentum: Optional[float] = field( + default=None, + metadata={ + "help": "SlowMo momentum term; by default use 0.0 for 16 GPUs, " + "0.2 for 32 GPUs; 0.5 for 64 GPUs, 0.6 for > 64 GPUs" + }, + ) + slowmo_base_algorithm: str = field( + default="localsgd", + metadata={ + "help": "Base algorithm. Either 'localsgd' or 'sgp'. Please refer " + "to the documentation of 'slowmo_base_algorithm' parameter in " + "https://fairscale.readthedocs.io/en/latest/api/experimental/nn/slowmo_ddp.html " + "for more details" + }, + ) + localsgd_frequency: int = field( + default=3, metadata={"help": "Local SGD allreduce frequency"} + ) + nprocs_per_node: int = field( + default=max(1, torch.cuda.device_count()), + metadata={ + "help": "number of GPUs in each node. An allreduce operation across GPUs in " + "a node is very fast. Hence, we do allreduce across GPUs in a node, " + "and gossip across different nodes" + }, + ) + pipeline_model_parallel: bool = field( + default=False, + metadata={"help": "if set, use pipeline model parallelism across GPUs"}, + ) + pipeline_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the model into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_balance) " + "should equal the total number of layers in the model" + }, + ) + pipeline_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-balance argument" + }, + ) + pipeline_chunks: Optional[int] = field( + default=0, metadata={"help": "microbatch count for pipeline model parallelism"} + ) + pipeline_encoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_encoder_balance) " + "should equal the total number of encoder layers in the model" + }, + ) + pipeline_encoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-encoder-balance argument" + }, + ) + pipeline_decoder_balance: Optional[str] = field( + default=None, + metadata={ + "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " + "contains N_i layers. The sum(args.pipeline_decoder_balance) " + "should equal the total number of decoder layers in the model" + }, + ) + pipeline_decoder_devices: Optional[str] = field( + default=None, + metadata={ + "help": "a list of device indices indicating which device to place " + "each of the N_K partitions. The length of this list should " + "equal the length of the --pipeline-decoder-balance argument" + }, + ) + pipeline_checkpoint: PIPELINE_CHECKPOINT_CHOICES = field( + default="never", + metadata={"help": "checkpointing mode for pipeline model parallelism"}, + ) + zero_sharding: ZERO_SHARDING_CHOICES = field( + default="none", metadata={"help": "ZeRO sharding"} + ) + fp16: bool = II("common.fp16") + memory_efficient_fp16: bool = II("common.memory_efficient_fp16") + tpu: bool = II("common.tpu") + # configuration for --ddp-backend=fully_sharded + no_reshard_after_forward: bool = field( + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, + ) + fp32_reduce_scatter: bool = field( + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, + ) + cpu_offload: bool = field( + default=False, metadata={"help": "offload FP32 params to CPU"} + ) + use_sharded_state: bool = field( + default=False, + metadata={"help": "use sharded checkpoint files"}, + ) + not_fsdp_flatten_parameters: bool = field( + default=False, + metadata={"help": "not flatten parameter param for fsdp"}, + ) + + +@dataclass +class DatasetConfig(FairseqDataclass): + num_workers: int = field( + default=1, metadata={"help": "how many subprocesses to use for data loading"} + ) + skip_invalid_size_inputs_valid_test: bool = field( + default=False, + metadata={"help": "ignore too long or too short lines in valid and test set"}, + ) + max_tokens: Optional[int] = field( + default=None, metadata={"help": "maximum number of tokens in a batch"} + ) + batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "number of examples in a batch", + "argparse_alias": "--max-sentences", + }, + ) + required_batch_size_multiple: int = field( + default=8, metadata={"help": "batch size will be a multiplier of this value"} + ) + required_seq_len_multiple: int = field( + default=1, + metadata={ + "help": "maximum sequence length in batch will be a multiplier of this value" + }, + ) + dataset_impl: Optional[DATASET_IMPL_CHOICES] = field( + default=None, metadata={"help": "output dataset implementation"} + ) + data_buffer_size: int = field( + default=10, metadata={"help": "Number of batches to preload"} + ) + train_subset: str = field( + default="train", + metadata={"help": "data subset to use for training (e.g. train, valid, test)"}, + ) + valid_subset: str = field( + default="valid", + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. train, valid, test)" + }, + ) + combine_valid_subsets: Optional[bool] = field( + default=None, + metadata={ + "help": "comma separated list of data subsets to use for validation" + " (e.g. train, valid, test)", + "argparse_alias": "--combine-val", + }, + ) + ignore_unused_valid_subsets: Optional[bool] = field( + default=False, + metadata={"help": "do not raise error if valid subsets are ignored"}, + ) + + validate_interval: int = field( + default=1, metadata={"help": "validate every N epochs"} + ) + validate_interval_updates: int = field( + default=0, metadata={"help": "validate every N updates"} + ) + validate_after_updates: int = field( + default=0, metadata={"help": "dont validate until reaching this many updates"} + ) + fixed_validation_seed: Optional[int] = field( + default=None, metadata={"help": "specified random seed for validation"} + ) + disable_validation: bool = field( + default=False, metadata={"help": "disable validation"} + ) + max_tokens_valid: Optional[int] = field( + default=II("dataset.max_tokens"), + metadata={ + "help": "maximum number of tokens in a validation batch" + " (defaults to --max-tokens)" + }, + ) + batch_size_valid: Optional[int] = field( + default=II("dataset.batch_size"), + metadata={ + "help": "batch size of the validation batch (defaults to --batch-size)", + "argparse_alias": "--max-sentences-valid", + }, + ) + max_valid_steps: Optional[int] = field( + default=None, + metadata={"help": "How many batches to evaluate", "argparse_alias": "--nval"}, + ) + curriculum: int = field( + default=0, metadata={"help": "don't shuffle batches for first N epochs"} + ) + gen_subset: str = field( + default="test", + metadata={"help": "data subset to generate (train, valid, test)"}, + ) + num_shards: int = field( + default=1, metadata={"help": "shard generation over N shards"} + ) + shard_id: int = field( + default=0, metadata={"help": "id of the shard to generate (id < num_shards)"} + ) + grouped_shuffling: bool = field( + default=False, + metadata={ + "help": "shuffle batches in groups of num_shards to enable similar sequence lengths on each GPU worker when batches are sorted by length", + }, + ) + update_epoch_batch_itr: bool = field( + default=II("dataset.grouped_shuffling"), + metadata={ + "help": "if true then prevents the reuse the epoch batch iterator by setting can_reuse_epoch_itr to false, defaults to --grouped-shuffling )", + }, + ) + update_ordered_indices_seed: bool = field( + default=False, + metadata={ + "help": "if true then increment seed with epoch for getting batch iterators, defautls to False.", + }, + ) + + +@dataclass +class OptimizationConfig(FairseqDataclass): + max_epoch: int = field( + default=0, metadata={"help": "force stop training at specified epoch"} + ) + max_update: int = field( + default=0, metadata={"help": "force stop training at specified update"} + ) + stop_time_hours: float = field( + default=0, + metadata={ + "help": "force stop training after specified cumulative time (if >0)" + }, + ) + clip_norm: float = field( + default=0.0, metadata={"help": "clip threshold of gradients"} + ) + sentence_avg: bool = field( + default=False, + metadata={ + "help": "normalize gradients by the number of sentences in a batch" + " (default is to normalize by number of tokens)" + }, + ) + update_freq: List[int] = field( + default_factory=lambda: [1], + metadata={"help": "update parameters every N_i batches, when in epoch i"}, + ) + lr: List[float] = field( + default_factory=lambda: [0.25], + metadata={ + "help": "learning rate for the first N epochs; all epochs >N using LR_N" + " (note: this may be interpreted differently depending on --lr-scheduler)" + }, + ) + stop_min_lr: float = field( + default=-1.0, + metadata={"help": "stop training when the learning rate reaches this minimum"}, + ) + use_bmuf: bool = field( + default=False, + metadata={ + "help": "specify global optimizer for syncing models on different GPUs/shards" + }, + ) + skip_remainder_batch: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, include the last (partial) batch of each epoch in training" + " (default is to skip it)." + }, + ) + debug_param_names: bool = False + + +@dataclass +class CheckpointConfig(FairseqDataclass): + save_dir: str = field( + default="checkpoints", metadata={"help": "path to save checkpoints"} + ) + restore_file: str = field( + default="checkpoint_last.pt", + metadata={ + "help": "filename from which to load checkpoint " + "(default: /checkpoint_last.pt" + }, + ) + continue_once: Optional[str] = field( + default=None, + metadata={ + "help": "continues from this checkpoint, unless a checkpoint indicated in 'restore_file' option is present" + }, + ) + finetune_from_model: Optional[str] = field( + default=None, + metadata={ + "help": "finetune from a pretrained model; note that meters and lr scheduler will be reset" + }, + ) + reset_dataloader: bool = field( + default=False, + metadata={ + "help": "if set, does not reload dataloader state from the checkpoint" + }, + ) + reset_lr_scheduler: bool = field( + default=False, + metadata={ + "help": "if set, does not load lr scheduler state from the checkpoint" + }, + ) + reset_meters: bool = field( + default=False, + metadata={"help": "if set, does not load meters from the checkpoint"}, + ) + reset_optimizer: bool = field( + default=False, + metadata={"help": "if set, does not load optimizer state from the checkpoint"}, + ) + optimizer_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override optimizer args when loading a checkpoint" + }, + ) + save_interval: int = field( + default=1, metadata={"help": "save a checkpoint every N epochs"} + ) + save_interval_updates: int = field( + default=0, metadata={"help": "save a checkpoint (and validate) every N updates"} + ) + keep_interval_updates: int = field( + default=-1, + metadata={ + "help": "keep the last N checkpoints saved with --save-interval-updates" + }, + ) + keep_interval_updates_pattern: int = field( + default=-1, + metadata={ + "help": "when used with --keep-interval-updates, skips deleting " + "any checkpoints with update X where " + "X %% keep_interval_updates_pattern == 0" + }, + ) + keep_last_epochs: int = field( + default=-1, metadata={"help": "keep last N epoch checkpoints"} + ) + keep_best_checkpoints: int = field( + default=-1, metadata={"help": "keep best N checkpoints based on scores"} + ) + no_save: bool = field( + default=False, metadata={"help": "don't save models or checkpoints"} + ) + no_epoch_checkpoints: bool = field( + default=False, metadata={"help": "only store last and best checkpoints"} + ) + no_last_checkpoints: bool = field( + default=False, metadata={"help": "don't store last checkpoints"} + ) + no_save_optimizer_state: bool = field( + default=False, + metadata={"help": "don't save optimizer-state as part of checkpoint"}, + ) + best_checkpoint_metric: str = field( + default="loss", metadata={"help": 'metric to use for saving "best" checkpoints'} + ) + maximize_best_checkpoint_metric: bool = field( + default=False, + metadata={ + "help": 'select the largest metric value for saving "best" checkpoints' + }, + ) + patience: int = field( + default=-1, + metadata={ + "help": ( + "early stop training if valid performance doesn't " + "improve for N consecutive validation runs; note " + "that this is influenced by --validate-interval" + ) + }, + ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + load_checkpoint_on_all_dp_ranks: bool = field( + default=False, + metadata={ + "help": "load checkpoints on all data parallel devices " + "(default: only load on rank 0 and broadcast to other devices)" + }, + ) + write_checkpoints_asynchronously: bool = field( + default=False, + metadata={ + "help": ( + "Write checkpoints asynchronously in a separate " + "thread. NOTE: This feature is currently being tested." + ), + "argparse_alias": "--save-async", + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + + +@dataclass +class FairseqBMUFConfig(FairseqDataclass): + block_lr: float = field( + default=1, metadata={"help": "block learning rate for bmuf"} + ) + block_momentum: float = field( + default=0.875, metadata={"help": "block momentum for bmuf"} + ) + global_sync_iter: int = field( + default=50, metadata={"help": "Iteration for syncing global model"} + ) + warmup_iterations: int = field( + default=500, metadata={"help": "warmup iterations for model to broadcast"} + ) + use_nbm: bool = field( + default=False, + metadata={"help": "Specify whether you want to use classical BM / Nesterov BM"}, + ) + average_sync: bool = field( + default=False, + metadata={ + "help": "Specify whether you want to average the local momentum after each sync" + }, + ) + distributed_world_size: int = II("distributed_training.distributed_world_size") + + +@dataclass +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + beam_mt: int = field( + default=0, + metadata={"help": "beam size for the first-pass decoder"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_a_mt: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) + max_len_b_mt: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + lenpen_mt: float = field( + default=1, + metadata={ + "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( + default=None, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens " + "(valid options are: hard, soft, otherwise treated as hard alignment)", + "argparse_const": "hard", + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed + # retain_dropout_modules: Optional[List[str]] = field( + retain_dropout_modules: Any = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. + decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + eos_token: Optional[str] = field( + default=None, + metadata={"help": "EOS token"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): + path: Optional[str] = field( + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, + ) + post_process: Optional[str] = field( + default=None, + metadata={ + "help": ( + "post-process text by removing BPE, letter segmentation, etc. " + "Valid options can be found in fairseq.data.utils.post_process." + ), + "argparse_const": "subword_nmt", + "argparse_alias": "--remove-bpe", + }, + ) + quiet: bool = field(default=False, metadata={"help": "only print final scores"}) + model_overrides: str = field( + default="{}", + metadata={ + "help": "a dictionary used to override model args at generation that were used during model training" + }, + ) + results_path: Optional[str] = field( + default=None, metadata={"help": "path to save eval results (optional)"} + ) + + +@dataclass +class EvalLMConfig(FairseqDataclass): + output_word_probs: bool = field( + default=False, + metadata={ + "help": "if set, outputs words and their predicted log probabilities to standard output" + }, + ) + output_word_stats: bool = field( + default=False, + metadata={ + "help": "if set, outputs word statistics such as word count, average probability, etc" + }, + ) + context_window: int = field( + default=0, + metadata={ + "help": "ensures that every evaluated token has access to a context of at least this size, if possible" + }, + ) + softmax_batch: int = field( + default=sys.maxsize, + metadata={ + "help": "if BxT is more than this, will batch the softmax over vocab to this amount of tokens, in order to fit into GPU memory" + }, + ) + + +@dataclass +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) + + +@dataclass +class EMAConfig(FairseqDataclass): + store_ema: bool = field( + default=False, metadata={help: "store exponential moving average shadow model"} + ) + ema_decay: float = field( + default=0.9999, metadata={"help": "decay for exponential moving average model"} + ) + ema_start_update: int = field( + default=0, metadata={"help": "start EMA update after this many model updates"} + ) + ema_seed_model: Optional[str] = field( + default=None, + metadata={ + "help": "Seed to load EMA model from. " + "Used to load EMA model separately from the actual model." + }, + ) + ema_update_freq: int = field( + default=1, metadata={"help": "Do EMA update every this many model updates"} + ) + ema_fp32: bool = field( + default=False, + metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, + ) + + +@dataclass +class FairseqConfig(FairseqDataclass): + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + optimization: OptimizationConfig = OptimizationConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + bmuf: FairseqBMUFConfig = FairseqBMUFConfig() + generation: GenerationConfig = GenerationConfig() + eval_lm: EvalLMConfig = EvalLMConfig() + interactive: InteractiveConfig = InteractiveConfig() + model: Any = MISSING + task: Any = None + criterion: Any = None + optimizer: Any = None + lr_scheduler: Any = None + scoring: Any = None + bpe: Any = None + tokenizer: Any = None + ema: EMAConfig = EMAConfig() diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..5af92f2b3aa51e460f0b045a348d3766f93eb90b --- /dev/null +++ b/fairseq/dataclass/constants.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum, EnumMeta +from typing import List + + +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): + def __str__(self): + return self.value + + def __eq__(self, other: str): + return self.value == other + + def __repr__(self): + return self.value + + def __hash__(self): + return hash(str(self)) + + +def ChoiceEnum(choices: List[str]): + """return the Enum class used to enforce list of choices""" + return StrEnum("Choices", {k: k for k in choices}) + + +LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) +DDP_BACKEND_CHOICES = ChoiceEnum( + [ + "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale + "legacy_ddp", + "no_c10d", # alias for legacy_ddp + "pytorch_ddp", + "slowmo", + ] +) +DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) +DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum( + ["unigram", "ensemble", "vote", "dp", "bs"] +) +ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) +PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) +PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"]) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7784bad194761b6d60ccfa5aed2fc01aa123c0 --- /dev/null +++ b/fairseq/dataclass/initialize.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import logging +from hydra.core.config_store import ConfigStore +from fairseq.dataclass.configs import FairseqConfig +from omegaconf import DictConfig, OmegaConf + + +logger = logging.getLogger(__name__) + + +def hydra_init(cfg_name="config") -> None: + + cs = ConfigStore.instance() + cs.store(name=f"{cfg_name}", node=FairseqConfig) + + for k in FairseqConfig.__dataclass_fields__: + v = FairseqConfig.__dataclass_fields__[k].default + try: + cs.store(name=k, node=v) + except BaseException: + logger.error(f"{k} - {v}") + raise + + +def add_defaults(cfg: DictConfig) -> None: + """This function adds default values that are stored in dataclasses that hydra doesn't know about""" + + from fairseq.registry import REGISTRIES + from fairseq.tasks import TASK_DATACLASS_REGISTRY + from fairseq.models import ARCH_MODEL_NAME_REGISTRY, MODEL_DATACLASS_REGISTRY + from fairseq.dataclass.utils import merge_with_parent + from typing import Any + + OmegaConf.set_struct(cfg, False) + + for k, v in FairseqConfig.__dataclass_fields__.items(): + field_cfg = cfg.get(k) + if field_cfg is not None and v.type == Any: + dc = None + + if isinstance(field_cfg, str): + field_cfg = DictConfig({"_name": field_cfg}) + field_cfg.__dict__["_parent"] = field_cfg.__dict__["_parent"] + + name = getattr(field_cfg, "_name", None) + + if k == "task": + dc = TASK_DATACLASS_REGISTRY.get(name) + elif k == "model": + name = ARCH_MODEL_NAME_REGISTRY.get(name, name) + dc = MODEL_DATACLASS_REGISTRY.get(name) + elif k in REGISTRIES: + dc = REGISTRIES[k]["dataclass_registry"].get(name) + + if dc is not None: + cfg[k] = merge_with_parent(dc, field_cfg) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f6467d5f402f3904dd2adf67101a248e89bba887 --- /dev/null +++ b/fairseq/dataclass/utils.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import inspect +import logging +import os +import re +from argparse import ArgumentError, ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, is_dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, Type + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import FairseqConfig +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict, _utils + +logger = logging.getLogger(__name__) + + +def eval_str_list(x, x_type=float): + if x is None: + return None + if isinstance(x, str): + if len(x) == 0: + return [] + x = ast.literal_eval(x) + try: + return list(map(x_type, x)) + except TypeError: + return [x_type(x)] + + +def interpret_dc_type(field_type): + if isinstance(field_type, str): + raise RuntimeError("field should be a type") + + if field_type == Any: + return str + + typestring = str(field_type) + if re.match( + r"(typing.|^)Union\[(.*), NoneType\]$", typestring + ) or typestring.startswith("typing.Optional"): + return field_type.__args__[0] + return field_type + + +def gen_parser_from_dataclass( + parser: ArgumentParser, + dataclass_instance: FairseqDataclass, + delete_default: bool = False, + with_prefix: Optional[str] = None, +) -> None: + """ + convert a dataclass instance to tailing parser arguments. + + If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are + building a flat namespace from a structured dataclass (see transformer_config.py for example). + """ + + def argparse_name(name: str): + if name == "data" and (with_prefix is None or with_prefix == ""): + # normally data is positional args, so we don't add the -- nor the prefix + return name + if name == "_name": + # private member, skip + return None + full_name = "--" + name.replace("_", "-") + if with_prefix is not None and with_prefix != "": + # if a prefix is specified, construct the prefixed arg name + full_name = with_prefix + "-" + full_name[2:] # strip -- when composing + return full_name + + def get_kwargs_from_dc( + dataclass_instance: FairseqDataclass, k: str + ) -> Dict[str, Any]: + """k: dataclass attributes""" + + kwargs = {} + + field_type = dataclass_instance._get_type(k) + inter_type = interpret_dc_type(field_type) + + field_default = dataclass_instance._get_default(k) + + if isinstance(inter_type, type) and issubclass(inter_type, Enum): + field_choices = [t.value for t in list(inter_type)] + else: + field_choices = None + + field_help = dataclass_instance._get_help(k) + field_const = dataclass_instance._get_argparse_const(k) + + if isinstance(field_default, str) and field_default.startswith("${"): + kwargs["default"] = field_default + else: + if field_default is MISSING: + kwargs["required"] = True + if field_choices is not None: + kwargs["choices"] = field_choices + if ( + isinstance(inter_type, type) + and (issubclass(inter_type, List) or issubclass(inter_type, Tuple)) + ) or ("List" in str(inter_type) or "Tuple" in str(inter_type)): + if "int" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, int) + elif "float" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, float) + elif "str" in str(inter_type): + kwargs["type"] = lambda x: eval_str_list(x, str) + else: + raise NotImplementedError( + "parsing of type " + str(inter_type) + " is not implemented" + ) + if field_default is not MISSING: + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) + elif ( + isinstance(inter_type, type) and issubclass(inter_type, Enum) + ) or "Enum" in str(inter_type): + kwargs["type"] = str + if field_default is not MISSING: + if isinstance(field_default, Enum): + kwargs["default"] = field_default.value + else: + kwargs["default"] = field_default + elif inter_type is bool: + kwargs["action"] = ( + "store_false" if field_default is True else "store_true" + ) + kwargs["default"] = field_default + else: + kwargs["type"] = inter_type + if field_default is not MISSING: + kwargs["default"] = field_default + + # build the help with the hierarchical prefix + if with_prefix is not None and with_prefix != "" and field_help is not None: + field_help = with_prefix[2:] + ": " + field_help + + kwargs["help"] = field_help + if field_const is not None: + kwargs["const"] = field_const + kwargs["nargs"] = "?" + + return kwargs + + for k in dataclass_instance._get_all_attributes(): + field_name = argparse_name(dataclass_instance._get_name(k)) + field_type = dataclass_instance._get_type(k) + if field_name is None: + continue + elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): + # for fields that are of type FairseqDataclass, we can recursively + # add their fields to the namespace (so we add the args from model, task, etc. to the root namespace) + prefix = None + if with_prefix is not None: + # if a prefix is specified, then we don't want to copy the subfields directly to the root namespace + # but we prefix them with the name of the current field. + prefix = field_name + gen_parser_from_dataclass(parser, field_type(), delete_default, prefix) + continue + + kwargs = get_kwargs_from_dc(dataclass_instance, k) + + field_args = [field_name] + alias = dataclass_instance._get_argparse_alias(k) + if alias is not None: + field_args.append(alias) + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + if kwargs["help"] is None: + # this is a field with a name that will be added elsewhere + continue + else: + del kwargs["default"] + if delete_default and "default" in kwargs: + del kwargs["default"] + try: + parser.add_argument(*field_args, **kwargs) + except ArgumentError: + pass + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def _override_attr( + sub_node: str, data_class: Type[FairseqDataclass], args: Namespace +) -> List[str]: + overrides = [] + + if not inspect.isclass(data_class) or not issubclass(data_class, FairseqDataclass): + return overrides + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): + # private member, skip + continue + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + field_type = interpret_dc_type(v.type) + if ( + isinstance(val, str) + and not val.startswith("${") # not interpolation + and field_type != str + and ( + not inspect.isclass(field_type) or not issubclass(field_type, Enum) + ) # not choices enum + ): + # upgrade old models that stored complex parameters as string + val = ast.literal_eval(val) + + if isinstance(val, tuple): + val = list(val) + + v_type = getattr(v.type, "__origin__", None) + if ( + (v_type is List or v_type is list or v_type is Optional) + # skip interpolation + and not (isinstance(val, str) and val.startswith("${")) + ): + # if type is int but val is float, then we will crash later - try to convert here + if hasattr(v.type, "__args__"): + t_args = v.type.__args__ + if len(t_args) == 1 and (t_args[0] is float or t_args[0] is int): + val = list(map(t_args[0], val)) + elif val is not None and ( + field_type is int or field_type is bool or field_type is float + ): + try: + val = field_type(val) + except: + pass # ignore errors here, they are often from interpolation args + + if val is None: + overrides.append("{}.{}=null".format(sub_node, k)) + elif val == "": + overrides.append("{}.{}=''".format(sub_node, k)) + elif isinstance(val, str): + val = val.replace("'", r"\'") + overrides.append("{}.{}='{}'".format(sub_node, k, val)) + elif isinstance(val, FairseqDataclass): + overrides += _override_attr(f"{sub_node}.{k}", type(val), args) + elif isinstance(val, Namespace): + sub_overrides, _ = override_module_args(val) + for so in sub_overrides: + overrides.append(f"{sub_node}.{k}.{so}") + else: + overrides.append("{}.{}={}".format(sub_node, k, val)) + + return overrides + + +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) + + +def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: + """use the field in args to overrides those in cfg""" + overrides = [] + deletes = [] + + for k in FairseqConfig.__dataclass_fields__.keys(): + overrides.extend( + _override_attr(k, FairseqConfig.__dataclass_fields__[k].type, args) + ) + + if args is not None: + if hasattr(args, "task"): + from fairseq.tasks import TASK_DATACLASS_REGISTRY + + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes + ) + else: + deletes.append("task") + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + from fairseq.registry import REGISTRIES + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], + args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, + ) + else: + deletes.append(k) + + no_dc = True + if hasattr(args, "arch"): + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_MODEL_NAME_REGISTRY + + if args.arch in ARCH_MODEL_REGISTRY: + m_cls = ARCH_MODEL_REGISTRY[args.arch] + dc = getattr(m_cls, "__dataclass", None) + if dc is not None: + m_name = ARCH_MODEL_NAME_REGISTRY[args.arch] + overrides.append("model={}".format(m_name)) + overrides.append("model._name={}".format(args.arch)) + # override model params with those exist in args + overrides.extend(_override_attr("model", dc, args)) + no_dc = False + if no_dc: + deletes.append("model") + + return overrides, deletes + + +class omegaconf_no_object_check: + def __init__(self): + # Changed in https://github.com/omry/omegaconf/pull/911 - both are kept for back compat. + if hasattr(_utils, "is_primitive_type"): + self.old_is_primitive = _utils.is_primitive_type + else: + self.old_is_primitive = _utils.is_primitive_type_annotation + + def __enter__(self): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = lambda _: True + else: + _utils.is_primitive_type_annotation = lambda _: True + + def __exit__(self, type, value, traceback): + if hasattr(_utils, "is_primitive_type"): + _utils.is_primitive_type = self.old_is_primitive + else: + _utils.is_primitive_type_annotation = self.old_is_primitive + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + """Convert a flat argparse.Namespace to a structured DictConfig.""" + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + # configs will be in fairseq/config after installation + config_path = os.path.join("..", "config") + + GlobalHydra.instance().clear() + + with initialize(config_path=config_path): + try: + composed_cfg = compose("config", overrides=overrides, strict=False) + except: + logger.error("Error when composing. Overrides: " + str(overrides)) + raise + + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + with omegaconf_no_object_check(): + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults( + cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler] + ) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + OmegaConf.set_struct(cfg, True) + return cfg + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + from fairseq.registry import REGISTRIES + + with open_dict(cfg): + for k in cfg.keys(): + # "k in cfg" will return false if its a "mandatory value (e.g. ???)" + if k in cfg and isinstance(cfg[k], DictConfig): + if k in overrides and isinstance(overrides[k], dict): + for ok, ov in overrides[k].items(): + if isinstance(ov, dict) and cfg[k][ok] is not None: + overwrite_args_by_name(cfg[k][ok], ov) + else: + cfg[k][ok] = ov + else: + overwrite_args_by_name(cfg[k], overrides) + elif k in cfg and isinstance(cfg[k], Namespace): + for override_key, val in overrides.items(): + setattr(cfg[k], override_key, val) + elif k in overrides: + if ( + k in REGISTRIES + and overrides[k] in REGISTRIES[k]["dataclass_registry"] + ): + cfg[k] = DictConfig( + REGISTRIES[k]["dataclass_registry"][overrides[k]] + ) + overwrite_args_by_name(cfg[k], overrides) + cfg[k]._name = overrides[k] + else: + cfg[k] = overrides[k] + + +def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig, remove_missing=False): + if remove_missing: + + def remove_missing_rec(src_keys, target_cfg): + if is_dataclass(target_cfg): + target_keys = set(target_cfg.__dataclass_fields__.keys()) + else: + target_keys = set(target_cfg.keys()) + + for k in list(src_keys.keys()): + if k not in target_keys: + del src_keys[k] + elif OmegaConf.is_config(src_keys[k]): + tgt = getattr(target_cfg, k) + if tgt is not None and (is_dataclass(tgt) or hasattr(tgt, "keys")): + remove_missing_rec(src_keys[k], tgt) + + with open_dict(cfg): + remove_missing_rec(cfg, dc) + + merged_cfg = OmegaConf.merge(dc, cfg) + merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9130db8f5d039519d663ee16c7ff2c102f5481f5 --- /dev/null +++ b/fairseq/distributed/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .distributed_timeout_wrapper import DistributedTimeoutWrapper +from .fully_sharded_data_parallel import ( + fsdp_enable_wrap, + fsdp_wrap, + FullyShardedDataParallel, +) +from .legacy_distributed_data_parallel import LegacyDistributedDataParallel +from .module_proxy_wrapper import ModuleProxyWrapper +from .tpu_distributed_data_parallel import TPUDistributedDataParallel + + +__all__ = [ + "DistributedTimeoutWrapper", + "fsdp_enable_wrap", + "fsdp_wrap", + "FullyShardedDataParallel", + "LegacyDistributedDataParallel", + "ModuleProxyWrapper", + "TPUDistributedDataParallel", +] diff --git a/fairseq/distributed/__pycache__/__init__.cpython-310.pyc b/fairseq/distributed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0beca33476de82a9322afad7b3c34e0708098123 Binary files /dev/null and b/fairseq/distributed/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/__init__.cpython-311.pyc b/fairseq/distributed/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..659d03a9db3bf5005b7ac04de4e6412c50ce4fd3 Binary files /dev/null and b/fairseq/distributed/__pycache__/__init__.cpython-311.pyc differ diff --git a/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc b/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00efdb3ac09212f8076cffc5470cbee39e89b5db Binary files /dev/null and b/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-311.pyc b/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fc4a5db1f4aa93b51c76fa31a7d00ee683dc18e Binary files /dev/null and b/fairseq/distributed/__pycache__/distributed_timeout_wrapper.cpython-311.pyc differ diff --git a/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc b/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667f948dfd0167d6455d3a27ba555e5a1af895a6 Binary files /dev/null and b/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-311.pyc b/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0994aa2645e92d5345dd49a7521650d8029e3a6f Binary files /dev/null and b/fairseq/distributed/__pycache__/fully_sharded_data_parallel.cpython-311.pyc differ diff --git a/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc b/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d794cf65cdd7832780bfb8cdf257e33ca29a1d0 Binary files /dev/null and b/fairseq/distributed/__pycache__/legacy_distributed_data_parallel.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc b/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b259af47728da4c47e7102e2fd281e803d083d2d Binary files /dev/null and b/fairseq/distributed/__pycache__/module_proxy_wrapper.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc b/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0686999bc8ad24dced671767588b30750c651eac Binary files /dev/null and b/fairseq/distributed/__pycache__/tpu_distributed_data_parallel.cpython-310.pyc differ diff --git a/fairseq/distributed/__pycache__/utils.cpython-310.pyc b/fairseq/distributed/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..263549c43dd66352cbc388c7eb914c10189e1f88 Binary files /dev/null and b/fairseq/distributed/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/distributed/distributed_timeout_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e06b4b6dd9a5fedd5d72bde02ceb7aaf74833d7 --- /dev/null +++ b/fairseq/distributed/distributed_timeout_wrapper.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import signal +import threading + +from torch import nn + + +logger = logging.getLogger(__name__) + + +class DistributedTimeoutWrapper(nn.Module): + """ + A wrapper that kills the process if no progress is made within a given + *timeout*. The timer is reset every time :func:`forward` is called. + + Usage:: + + module = DistributedTimeoutWrapper(module, timeout=30) + x = module(input) + time.sleep(20) # safe + x = module(input) + time.sleep(45) # job will be killed before this returns + + Args: + module (nn.Module): module to wrap + timeout (int): number of seconds before killing the process + (set to a value <= 0 to disable the timeout) + signal (Optional): signal to send once timeout is triggered + """ + + def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT): + super().__init__() + self.module = module + self.timeout = timeout + self.signal = signal + + if timeout > 0: + self._heartbeat = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._check_heartbeat, + args=(os.getpid(),), + daemon=True, + ) + self._heartbeat_thread.start() + self._terminated = False + else: + self._heartbeat = None + self._heartbeat_thread = None + + def __del__(self): + self.stop_timeout() + + def __getattr__(self, name): + """Forward missing attributes to wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self.module, name) + + def stop_timeout(self): + if self._heartbeat_thread is not None: + self._terminated = True + self._heartbeat_thread.join() + + def state_dict(self, *args, **kwargs): + return self.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + if self._heartbeat is not None: + self._heartbeat.set() + return self.module(*args, **kwargs) + + def _check_heartbeat(self, parent_pid): + self._heartbeat.wait() # wait for the first forward pass + while True: + self._heartbeat.clear() + success = self._heartbeat.wait(timeout=self.timeout) + if self._terminated: + break + elif not success: + logger.error( + ( + "Killing job for not making progress in {} seconds. " + "Set --heartbeat-timeout=-1 to disable this timeout." + ).format(int(self.timeout)) + ) + os.kill(parent_pid, self.signal) + return diff --git a/fairseq/distributed/fully_sharded_data_parallel.py b/fairseq/distributed/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..1c508b05dd2c5aa4a3aa586a6998e04dbbbbb918 --- /dev/null +++ b/fairseq/distributed/fully_sharded_data_parallel.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +from typing import Optional + +import torch +from fairseq.dataclass.configs import DistributedTrainingConfig +from fairseq.distributed import utils as dist_utils + + +try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + + has_FSDP = True +except ImportError: + FSDP = torch.nn.Module + has_FSDP = False + + +class FullyShardedDataParallel(FSDP): + """ + A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some + fairseq-specific checkpoint saving/loading logic. + + Args: + use_sharded_state (bool): if True, then ``state_dict`` will return + ``FSDP.local_state_dict`` and ``load_state_dict`` will call + ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will + return the full model weights on data parallel rank 0 (empty on + other ranks) and ``load_state_dict`` will broadcast model weights + from rank 0 to other ranks. + """ + + def __init__(self, *args, use_sharded_state: bool = False, **kwargs): + if not has_FSDP: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + super().__init__(*args, **kwargs) + self.use_sharded_state = use_sharded_state + + @property + def unwrapped_module(self) -> torch.nn.Module: + if self.flatten_parameters: + return self.module.module + else: + return self.module + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if self.use_sharded_state: + return super().local_state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + if self.rank == 0: + return super().state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + else: + # We must call state_dict() due to use of communication + # primitives. But we don't use the result. + super().state_dict() + return destination or {} + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + if self.use_sharded_state: + return super().load_local_state_dict(state_dict, strict=strict) + else: + state_dict = dist_utils.broadcast_object( + state_dict, src_rank=0, group=self.process_group + ) + return super().load_state_dict(state_dict, strict=strict) + + +class DummyProcessGroup: + def __init__(self, rank: int, size: int): + self._rank = rank + self._size = size + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._size + + +@contextlib.contextmanager +def fsdp_enable_wrap(cfg: DistributedTrainingConfig): + try: + from fairscale.nn import enable_wrap + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + if cfg.memory_efficient_fp16: + assert cfg.fp16 # memory_efficient_fp16 should imply fp16 + group = dist_utils.get_data_parallel_group() + if group is None and cfg.distributed_world_size == 1: + group = DummyProcessGroup(rank=0, size=1) + fsdp_config = { + "process_group": group, + "reshard_after_forward": not cfg.no_reshard_after_forward, + "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16, + "fp32_reduce_scatter": cfg.fp32_reduce_scatter, + "flatten_parameters": not cfg.not_fsdp_flatten_parameters, + "cpu_offload": cfg.cpu_offload, + "compute_dtype": torch.float16 if cfg.fp16 else torch.float32, + "bucket_cap_mb": cfg.bucket_cap_mb, + "state_dict_device": torch.device("cpu"), # reduce GPU mem usage + } + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + use_sharded_state=cfg.use_sharded_state, + **fsdp_config, + ): + yield + + +def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): + """ + Helper to wrap layers/modules in FSDP. This falls back to a no-op if + fairscale is not available. + + Args: + module (nn.Module): module to (maybe) wrap + min_num_params (int, Optional): minimum number of layer params to wrap + """ + try: + from fairscale.nn import wrap + + if min_num_params is not None: + num_params = sum(p.numel() for p in module.parameters()) + if num_params >= min_num_params: + return wrap(module, **kwargs) + else: + return module + else: + return wrap(module, **kwargs) + except ImportError: + return module diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..cd434c7372ba30ea0e6f87e084230448f53480e9 --- /dev/null +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -0,0 +1,165 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A modified version of the legacy DistributedDataParallel module that uses c10d +communication primitives. This version is simpler than the latest PyTorch +version and is useful for debugging. Notably it does not overlap gradient +communication with the backward pass, which makes it slower but more robust +than the PyTorch version. + +This version also supports the *no_sync* context manager, which allows faster +training with `--update-freq`. +""" + +from collections import OrderedDict +from contextlib import contextmanager + +import torch +from torch import nn + +from fairseq.distributed import utils + + +class LegacyDistributedDataParallel(nn.Module): + """Implements distributed data parallelism at the module level. + + A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. + This version uses a c10d process group for communication and does not + broadcast buffers. + + Args: + module (~torch.nn.Module): module to be parallelized + process_group: the c10d process group to be used for distributed data + parallel all-reduction. + buffer_size (int, optional): number of elements to buffer before + performing all-reduce (default: 256M). + """ + + def __init__(self, module, process_group, buffer_size=2**28): + super().__init__() + + self.module = module + self.process_group = process_group + self.world_size = utils.get_world_size(self.process_group) + + # Never use a bigger buffer than the number of model params + self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) + self.buffer = None + + # We can also forcibly accumulate grads locally and only do the + # all-reduce at some later time + self.accumulate_grads = False + + # make per-device lists of parameters + paramlists = OrderedDict() + for param in self.module.parameters(): + device = param.device + if paramlists.get(device) is None: + paramlists[device] = [] + paramlists[device] += [param] + self.per_device_params = list(paramlists.values()) + + @contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization.""" + old_accumulate_grads = self.accumulate_grads + self.accumulate_grads = True + yield + self.accumulate_grads = old_accumulate_grads + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + """ + This function must be called explicitly after backward to reduce + gradients. There is no automatic hook like c10d. + """ + + def all_reduce_params(params): + buffer = self.buffer + nonzero_buffer = False + if len(params) > 1: + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) + nonzero_buffer = True + else: + buffer[offset : offset + sz].zero_() + offset += sz + else: + # we only have a single grad to all-reduce + p = params[0] + if p.grad is not None: + buffer = p.grad.data + nonzero_buffer = True + elif p.numel() <= self.buffer.numel(): + buffer = buffer[: p.numel()] + buffer.zero_() + else: + buffer = torch.zeros_like(p) + + if nonzero_buffer: + buffer.div_(self.world_size) + + utils.all_reduce(buffer, self.process_group) + + # copy all-reduced grads back into their original place + offset = 0 + for p in params: + sz = p.numel() + if p.grad is not None: + p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) + else: + p.grad = buffer[offset : offset + sz].view_as(p).clone() + offset += sz + + def reduction_fn(): + # This function only needs to be called once + if self.accumulate_grads: + return + + if self.buffer is None: + self.buffer = next(self.module.parameters()).new(self.buffer_size) + + for params in self.per_device_params: + # All-reduce the gradients in buckets + offset = 0 + buffered_params = [] + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = torch.zeros_like(param) + + if hasattr(param, "expert"): + # Skip gradient sync for unshared parameters + continue + + if param.grad.requires_grad: + raise RuntimeError( + "DistributedDataParallel only works " + "with gradients that don't require " + "grad" + ) + sz = param.numel() + if sz > self.buffer.numel(): + # all-reduce big params directly + all_reduce_params([param]) + else: + if offset + sz > self.buffer.numel(): + all_reduce_params(buffered_params) + offset = 0 + buffered_params.clear() + buffered_params.append(param) + offset += sz + + if len(buffered_params) > 0: + all_reduce_params(buffered_params) + + reduction_fn() diff --git a/fairseq/distributed/module_proxy_wrapper.py b/fairseq/distributed/module_proxy_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..904dc0c202e09db244518836c0f061e0850cad61 --- /dev/null +++ b/fairseq/distributed/module_proxy_wrapper.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + + +class ModuleProxyWrapper(nn.Module): + """ + Wrap a DistributedDataParallel module and forward requests for missing + attributes to the module wrapped by DDP (the twice-wrapped module). + Also forward calls to :func:`state_dict` and :func:`load_state_dict`. + + Usage:: + + module.xyz = "hello world" + wrapped_module = DistributedDataParallel(module, **ddp_args) + wrapped_module = ModuleProxyWrapper(wrapped_module) + assert wrapped_module.xyz == "hello world" + assert wrapped_module.state_dict().keys() == module.state_dict().keys() + + Args: + module (nn.Module): module to wrap + """ + + def __init__(self, module: nn.Module): + super().__init__() + assert hasattr( + module, "module" + ), "ModuleProxyWrapper expects input to wrap another module" + self.module = module + + def __getattr__(self, name): + """Forward missing attributes to twice-wrapped module.""" + try: + # defer to nn.Module's logic + return super().__getattr__(name) + except AttributeError: + try: + # forward to the once-wrapped module + return getattr(self.module, name) + except AttributeError: + # forward to the twice-wrapped module + return getattr(self.module.module, name) + + def state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + """Forward to the twice-wrapped module.""" + return self.module.module.load_state_dict(*args, **kwargs) + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/distributed/tpu_distributed_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9e1033011db87100c64ec39845e81228a26381 --- /dev/null +++ b/fairseq/distributed/tpu_distributed_data_parallel.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from fairseq.distributed import utils + + +class TPUDistributedDataParallel(nn.Module): + def __init__(self, module, process_group): + super().__init__() + self.module = module + self.process_group = process_group + self.world_size = utils.get_world_size(self.process_group) + + def forward(self, *inputs, **kwargs): + return self.module(*inputs, **kwargs) + + def all_reduce_grads(self): + gradients = [] + for p in self.parameters(): + if not p.requires_grad: + continue + if p.grad is None: + p.grad = torch.zeros_like(p) + if p.grad.requires_grad: + raise RuntimeError( + "TPUDistributedDataParallel only works with gradients that don't " + "require grad" + ) + gradients.append(p.grad) + + import torch_xla.core.xla_model as xm + + xm.all_reduce( + "sum", + gradients, + scale=1.0 / self.world_size, + groups=self.process_group[1], + ) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..968830d58582e436386111d90896bf95889c736e --- /dev/null +++ b/fairseq/distributed/utils.py @@ -0,0 +1,843 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import io +import logging +import os +import pickle +import random +import socket +import struct +import subprocess +import warnings +from argparse import Namespace +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional + +import torch +import torch.distributed as dist +from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig +from omegaconf import open_dict + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +# Flag to indicate if we're using Megatron +# NOTE: this is a temporary hack until we move away from Megatron's model parallel init +_USE_MEGATRON = False + +# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops. +_USE_XLA = False + + +logger = logging.getLogger(__name__) + + +def is_master(cfg: DistributedTrainingConfig): + return cfg.distributed_rank == 0 + + +def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: + return + + num_pipelines_per_node = None + if cfg.pipeline_model_parallel: + num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg) + + if cfg.distributed_world_size == 1: + return + if all( + key in os.environ + for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] + ): + # support torch.distributed.launch + _infer_torch_distributed_launch_init(cfg) + else: + # we can determine the init method automatically for Slurm + if not _infer_slurm_init(cfg, num_pipelines_per_node): + if cfg.distributed_port <= 0 or force_distributed: + _infer_single_node_init(cfg) + elif cfg.distributed_port <= 0: + _infer_single_node_init(cfg) + + if cfg.pipeline_model_parallel: + _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size + ) + else: + if cfg.device_id > 0: + logger.info( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.device_id) + + +def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig): + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) + cfg.device_id = cfg.distributed_rank % torch.cuda.device_count() + # processes are created by torch.distributed.launch + cfg.distributed_no_spawn = True + + +def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node): + node_list = os.environ.get("SLURM_STEP_NODELIST") + if node_list is None: + node_list = os.environ.get("SLURM_JOB_NODELIST") + if node_list is not None: + try: + hostnames = subprocess.check_output( + ["scontrol", "show", "hostnames", node_list] + ) + cfg.distributed_init_method = "tcp://{host}:{port}".format( + host=hostnames.split()[0].decode("utf-8"), + port=cfg.distributed_port, + ) + nnodes = int(os.environ.get("SLURM_NNODES")) + ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") + if ntasks_per_node is not None: + ntasks_per_node = int(ntasks_per_node) + else: + ntasks = int(os.environ.get("SLURM_NTASKS")) + nnodes = int(os.environ.get("SLURM_NNODES")) + assert ntasks % nnodes == 0 + ntasks_per_node = int(ntasks / nnodes) + if ntasks_per_node == 1: + gpus_per_node = torch.cuda.device_count() + node_id = int(os.environ.get("SLURM_NODEID")) + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: + assert ntasks_per_node == num_pipelines_per_node, ( + "SLURM --ntasks-per-node must match number of pipelines per " + "node (={})".format(num_pipelines_per_node) + ) + cfg.distributed_no_spawn = True + # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on + # the first node, [1, 2] on the second node, etc. This + # matches torch.distributed.launch. + node_id = int(os.environ.get("SLURM_NODEID")) + local_id = int(os.environ.get("SLURM_LOCALID")) + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id + # In the above example, device_id will always be in [0, 1], + # which also matches torch.distributed.launch. + cfg.device_id = local_id + # We also want to set distributed_world_size to be the total + # number of pipelines across all nodes. + cfg.distributed_world_size = nnodes * num_pipelines_per_node + else: + assert ( + ntasks_per_node == cfg.distributed_world_size // nnodes + ), f"{ntasks_per_node}, {cfg.distributed_world_size}, {nnodes}" + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) + logger.info(f"Rank {cfg.distributed_rank}, device_id: {cfg.device_id}") + return True + except subprocess.CalledProcessError as e: # scontrol failed + raise e + except FileNotFoundError: # Slurm is not installed + pass + + return False + + +def _infer_single_node_init(cfg: DistributedTrainingConfig): + assert ( + cfg.distributed_world_size <= torch.cuda.device_count() + ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" + + if cfg.distributed_port <= 0: + jobid = os.environ.get("SLURM_JOB_ID") + task_id = os.environ.get("SLURM_ARRAY_TASK_ID") + + if jobid is not None: + if task_id is not None: + jobid += str(task_id) + jobid = int(jobid) + rng = random.Random(jobid) + port = rng.randint(10000, 60000) + else: + port = random.randint(10000, 60000) + + cfg.distributed_port = port + cfg.distributed_init_method = "tcp://localhost:{port}".format( + port=cfg.distributed_port + ) + + +def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig): + from fairseq import utils + + balance_exists = ( + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None + ) + devices_exist = ( + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None + ) + if not balance_exists: + raise ValueError( + "--pipeline-balance is currently required for pipeline model parallelism" + ) + if not devices_exist: + raise ValueError( + "--pipeline-devices is currently required for pipeline model parallelism" + ) + + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) + else: + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int + ) + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int + ) + num_pipeline_devices = len( + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) + ) + gpus_per_node = torch.cuda.device_count() + assert ( + gpus_per_node >= num_pipeline_devices + and gpus_per_node % num_pipeline_devices == 0 + ), ( + "the number of unique device IDs in --pipeline-devices must evenly divide " + "the number of GPUs per node (multi-node pipelining is not yet supported)" + ) + num_pipelines_per_node = gpus_per_node // num_pipeline_devices + return num_pipeline_devices, num_pipelines_per_node + + +def _pipeline_parallel_post_init( + cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node +): + if not cfg.distributed_no_spawn: + # When distributed_no_spawn is False, we expect distributed_rank and + # distributed_world_size to be based on the total number of GPUs, so + # we need to correct them to be based on the number of pipelines. + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices + # In the case of 4-way MP on nodes with 8 GPUs, we want + # distributed_rank to be the starting GPU index for each pipeline + # i.e., 0, 2, ... + gpus_per_node = torch.cuda.device_count() + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node + + # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 + # and 4, indicating the starting device IDs for each pipeline + cfg.device_id *= num_pipeline_devices + + if cfg.device_id > 0: + # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 + # GPU node), we need to adjust pipeline_devices accordingly + logger.debug( + "setting CUDA device={} on rank {}".format( + cfg.device_id, cfg.distributed_rank + ) + ) + torch.cuda.set_device(cfg.device_id) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] + logger.info( + "setting pipeline_devices={} on rank {}".format( + cfg.pipeline_devices, cfg.distributed_rank + ) + ) + + +def distributed_init(cfg: FairseqConfig): + if isinstance(cfg, Namespace): + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + + cfg = convert_namespace_to_omegaconf(cfg) + + if not cfg.common.tpu: + if torch.distributed.is_available() and torch.distributed.is_initialized(): + warnings.warn( + "Distributed is already initialized, cannot initialize twice!" + ) + else: + logger.info( + "distributed init (rank {}): {}".format( + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, + ) + ) + dist.init_process_group( + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, + ) + logger.info( + "initialized host {} as rank {}".format( + socket.gethostname(), + cfg.distributed_training.distributed_rank, + ) + ) + + # perform a dummy all-reduce to initialize the NCCL communicator + if torch.cuda.is_available(): + dist.all_reduce(torch.zeros(1).cuda()) + + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() + else: + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + global _USE_XLA + _USE_XLA = True + cfg.distributed_training.device_id = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() + xm.rendezvous("distributed_init") # wait for all workers + + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: + logging.getLogger().setLevel(logging.WARNING) + + if cfg.common.model_parallel_size > 1: + try: + from fairseq.model_parallel.megatron.mpu import ( + initialize_model_parallel, + model_parallel_cuda_manual_seed, + ) + except ImportError: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + global _USE_MEGATRON + _USE_MEGATRON = True + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) + model_part_number = get_model_parallel_rank() + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + + if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: + cfg.checkpoint.checkpoint_suffix = ( + f"-rank-{cfg.distributed_training.distributed_rank}" + ) + + return cfg.distributed_training.distributed_rank + + +def distributed_main(i, main, cfg: FairseqConfig, kwargs): + cfg.distributed_training.device_id = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i + + cfg.distributed_training.distributed_rank = distributed_init(cfg) + + after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) + if after_distributed_init_fn: + cfg = after_distributed_init_fn(cfg) + + main(cfg, **kwargs) + + if torch.distributed.is_initialized(): + torch.distributed.barrier(get_global_group()) + + +def call_main(cfg: FairseqConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) + + if cfg.distributed_training.distributed_init_method is not None: + # distributed training + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically + kwargs["start_rank"] = start_rank + + torch.multiprocessing.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), + join=True, + ) + else: + distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: + import torch_xla.distributed.xla_multiprocessing as xmp + + torch.multiprocessing.set_sharing_strategy("file_system") + xmp.spawn( + fn=distributed_main, + args=(main, cfg, kwargs), + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(cfg.distributed_training.distributed_world_size, 8), + ) + else: + # single GPU main + main(cfg, **kwargs) + + +def use_xla(): + global _USE_XLA + return _USE_XLA + + +def new_groups(grouped_ranks: List[List[int]]): + if use_xla(): + return ("tpu", grouped_ranks) + else: + groups = [dist.new_group(g) for g in grouped_ranks] + my_group_idx = _find_my_group_index(grouped_ranks) + return groups[my_group_idx] + + +def _find_my_group_index(grouped_ranks): + my_rank = get_global_rank() + for i, group in enumerate(grouped_ranks): + if my_rank in group: + return i + raise RuntimeError + + +def _find_my_group(grouped_ranks): + index = _find_my_group_index(grouped_ranks) + return grouped_ranks[index] + + +def get_rank(group): + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return my_group.index(get_global_rank()) + else: + return dist.get_rank(group=group) + + +def get_world_size(group): + if use_xla(): + assert group[0] == "tpu" + my_group = _find_my_group(group[1]) + return len(my_group) + elif torch.distributed.is_initialized(): + return dist.get_world_size(group=group) + else: + return 1 + + +def get_global_group(): + if use_xla(): + return new_groups([list(range(get_global_world_size()))]) + elif torch.distributed.is_initialized(): + if not hasattr(get_global_group, "_global_group"): + # ideally we could use torch.distributed.group.WORLD, but it seems + # to cause random NCCL hangs in some cases + get_global_group._global_group = dist.new_group() + return get_global_group._global_group + else: + return None + + +def get_global_rank(): + if use_xla(): + return xm.get_ordinal() + elif torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def get_global_world_size(): + if use_xla(): + return xm.xrt_world_size() + elif torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + + return mpu.get_data_parallel_group() + else: + return get_global_group() + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return get_rank(get_data_parallel_group()) + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return get_world_size(get_data_parallel_group()) + + +def get_model_parallel_group(): + global _USE_MEGATRON + if _USE_MEGATRON: + from fairseq.model_parallel.megatron import mpu + + return mpu.get_model_parallel_group() + else: + return None + + +def get_model_parallel_rank(): + """Return my rank for the model parallel group.""" + return get_rank(get_model_parallel_group()) + + +def get_model_parallel_world_size(): + """Return world size for the model parallel group.""" + return get_world_size(get_model_parallel_group()) + + +def all_reduce(tensor, group, op="sum"): + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + tensor = [tensor] # wrap in a list to make xm.all_reduce in-place + return xm.all_reduce(op, tensor, groups=group[1])[0] + else: + if op == "sum": + op = dist.ReduceOp.SUM + elif op == "max": + op = dist.ReduceOp.MAX + else: + raise NotImplementedError + dist.all_reduce(tensor, op=op, group=group) + return tensor + + +def broadcast(tensor, src, group): + if use_xla(): + # XLA doesn't support broadcast, hack it with all_reduce + if get_rank(group) != src: + tensor.zero_() + all_reduce(tensor, group) + else: + dist.broadcast(tensor, src=src, group=group) + + +def all_to_all(tensor, group): + """Perform an all-to-all operation on a 1D Tensor.""" + assert tensor.dim() == 1 + split_count = get_world_size(group=group) + assert tensor.numel() % split_count == 0 + if use_xla(): + assert isinstance(group, tuple) and group[0] == "tpu" + return xm.all_to_all( + tensor, + split_dimension=0, + concat_dimension=0, + split_count=split_count, + groups=group[1], + ) + else: + output = torch.zeros_like(tensor) + dist.all_to_all_single(output, tensor, group=group) + return output + + +def all_gather(tensor, group, return_tensor=False): + """Perform an all-gather operation.""" + if use_xla(): + result = xm.all_gather(tensor, groups=group[1]) + world_size = get_world_size(group=group) + result = result.view(world_size, *tensor.size()) + if return_tensor: + return result + else: + return [result[i] for i in range(world_size)] + else: + world_size = get_world_size(group=group) + rank = get_rank(group=group) + tensor_list = [ + tensor if i == rank else torch.empty_like(tensor) for i in range(world_size) + ] + dist.all_gather(tensor_list, tensor, group=group) + if return_tensor: + return torch.stack(tensor_list, dim=0) + else: + return tensor_list + + +def all_gather_list(data, group=None, max_size=16384): + """Gathers arbitrary data from all nodes into a list. + + Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python + data. Note that *data* must be picklable and any CUDA tensors will be moved + to CPU and returned on CPU as well. + + Args: + data (Any): data from the local worker to be gathered on other workers + group: group of the collective + max_size (int, optional): maximum size of the data to be gathered + across workers + """ + from fairseq import utils + + if group is None: + group = get_global_group() + rank = get_rank(group=group) + world_size = get_world_size(group=group) + + buffer_size = max_size * world_size + if ( + not hasattr(all_gather_list, "_buffer") + or all_gather_list._buffer.numel() < buffer_size + ): + all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) + all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() + buffer = all_gather_list._buffer + buffer.zero_() + cpu_buffer = all_gather_list._cpu_buffer + + data = utils.move_to_cpu(data) + enc = pickle.dumps(data) + enc_size = len(enc) + header_size = 4 # size of header that contains the length of the encoded data + size = header_size + enc_size + if size > max_size: + raise ValueError( + "encoded data size ({}) exceeds max_size ({})".format(size, max_size) + ) + + header = struct.pack(">I", enc_size) + cpu_buffer[:size] = torch.ByteTensor(list(header + enc)) + start = rank * max_size + buffer[start : start + size].copy_(cpu_buffer[:size]) + + all_reduce(buffer, group=group) + + buffer = buffer.cpu() + try: + result = [] + for i in range(world_size): + out_buffer = buffer[i * max_size : (i + 1) * max_size] + (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) + if enc_size > 0: + result.append( + pickle.loads( + bytes(out_buffer[header_size : header_size + enc_size].tolist()) + ) + ) + return result + except pickle.UnpicklingError: + raise Exception( + "Unable to unpickle data from other workers. all_gather_list requires all " + "workers to enter the function together, so this error usually indicates " + "that the workers have fallen out of sync somehow. Workers can fall out of " + "sync if one of them runs out of memory, or if there are other conditions " + "in your training script that can cause one worker to finish an epoch " + "while other workers are still iterating over their portions of the data. " + "Try rerunning with --ddp-backend=legacy_ddp and see if that helps." + ) + + +def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]: + """ + AllReduce a dictionary of values across workers. We separately + reduce items that are already on the device and items on CPU for + better performance. + + Args: + data (Mapping[str, Any]): dictionary of data to all-reduce, but + cannot be a nested dictionary + device (torch.device): device for the reduction + group: group of the collective + """ + data_keys = list(data.keys()) + + # We want to separately reduce items that are already on the + # device and items on CPU for performance reasons. + cpu_data = OrderedDict() + device_data = OrderedDict() + for k in data_keys: + t = data[k] + if not torch.is_tensor(t): + cpu_data[k] = torch.tensor(t, dtype=torch.double) + elif t.device.type != device.type: + cpu_data[k] = t.to(dtype=torch.double) + else: + device_data[k] = t.to(dtype=torch.double) + + def _all_reduce_dict(data: OrderedDict): + if len(data) == 0: + return data + buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device) + all_reduce(buf, group=group) + split_buf = torch.split(buf.clone(), [t.numel() for t in data.values()]) + reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())] + return OrderedDict(zip(data.keys(), reduced_data)) + + cpu_data = _all_reduce_dict(cpu_data) + device_data = _all_reduce_dict(device_data) + + def get_from_stack(key): + if key in cpu_data: + return cpu_data[key] + elif key in device_data: + return device_data[key] + raise KeyError + + return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) + + +def broadcast_tensors( + tensors: Optional[List[torch.Tensor]], + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + Broadcasts a list of tensors without other (non-src) ranks needing to know + the dtypes/shapes of the tensors. + """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + # share metadata first to simplify transfer + is_src_rank = get_rank(group) == src_rank + if is_src_rank: + metadata = [ + {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors + ] + metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + else: + metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + out_tensors = [] + for i, meta in enumerate(metadata): + if is_src_rank: + tensor = tensors[i] + broadcast(tensors[i].to(dist_device), src=src_rank, group=group) + else: + tensor = torch.zeros( + [meta["size"].numel()], dtype=meta["dtype"], device=dist_device + ) + broadcast(tensor, src=src_rank, group=group) + tensor = tensor.view(meta["size"]).to(meta["device"]) + out_tensors.append(tensor) + return out_tensors + + +def broadcast_object( + obj: Any, + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> Any: + """Broadcast an arbitrary Python object to other workers.""" + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + if get_rank(group) == src_rank: + # split the tensors from the non-tensors so we can broadcast them + # directly, avoiding unnecessary serialization/deserialization + tensors = [] + obj = _split_tensors_from_obj(obj, tensors) + obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + else: + obj = _broadcast_object_slow(None, src_rank, group, dist_device) + tensors = broadcast_tensors(None, src_rank, group, dist_device) + return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, + src_rank: int, + group: object, + dist_device: torch.device, +) -> Any: + if get_rank(group) == src_rank: + # Emit data + buffer = io.BytesIO() + torch.save(obj, buffer) + buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) + length = torch.LongTensor([len(buffer)]).to(dist_device) + broadcast(length, src=src_rank, group=group) + broadcast(buffer, src=src_rank, group=group) + else: + # Fetch from the source + length = torch.LongTensor([0]).to(dist_device) + broadcast(length, src=src_rank, group=group) + buffer = torch.ByteTensor(int(length.item())).to(dist_device) + broadcast(buffer, src=src_rank, group=group) + buffer = io.BytesIO(buffer.cpu().numpy()) + obj = torch.load(buffer, map_location="cpu") + return obj + + +@dataclass(frozen=True) +class _TensorPlaceholder: + index: int + + +def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if torch.is_tensor(obj): + placeholder = _TensorPlaceholder(index=len(tensors)) + tensors.append(obj) + return placeholder + elif isinstance(obj, dict): + return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_split_tensors_from_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_split_tensors_from_obj(v, tensors) for v in obj} + else: + return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if isinstance(obj, _TensorPlaceholder): + return tensors[obj.index] + elif isinstance(obj, dict): + return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_put_tensors_in_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_put_tensors_in_obj(v, tensors) for v in obj} + else: + return obj diff --git a/fairseq/file_chunker_utils.py b/fairseq/file_chunker_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f275490993dbbcd05d990c050ae6c7b4c9568c9 --- /dev/null +++ b/fairseq/file_chunker_utils.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import typing as tp + + +def _safe_readline(fd) -> str: + pos = fd.tell() + while True: + try: + return fd.readline() + except UnicodeDecodeError: + pos -= 1 + fd.seek(pos) # search where this character begins + + +def find_offsets(filename: str, num_chunks: int) -> tp.List[int]: + """ + given a file and a number of chuncks, find the offsets in the file + to be able to chunk around full lines. + """ + with open(filename, "r", encoding="utf-8") as f: + size = os.fstat(f.fileno()).st_size + chunk_size = size // num_chunks + offsets = [0 for _ in range(num_chunks + 1)] + for i in range(1, num_chunks): + f.seek(chunk_size * i) + _safe_readline(f) + offsets[i] = f.tell() + offsets[-1] = size + return offsets + + +class ChunkLineIterator: + """ + Iterator to properly iterate over lines of a file chunck. + """ + + def __init__(self, fd, start_offset: int, end_offset: int): + self._fd = fd + self._start_offset = start_offset + self._end_offset = end_offset + + def __iter__(self) -> tp.Iterable[str]: + self._fd.seek(self._start_offset) + # next(f) breaks f.tell(), hence readline() must be used + line = _safe_readline(self._fd) + while line: + pos = self._fd.tell() + # f.tell() does not always give the byte position in the file + # sometimes it skips to a very large number + # it is unlikely that through a normal read we go from + # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely + # that the procedure breaks by the undeterministic behavior of + # f.tell() + if ( + self._end_offset > 0 + and pos > self._end_offset + and pos < self._end_offset + 2**32 + ): + break + yield line + line = self._fd.readline() + + +class Chunker: + """ + contextmanager to read a chunck of a file line by line. + """ + + def __init__(self, path: str, start_offset: int, end_offset: int): + self.path = path + self.start_offset = start_offset + self.end_offset = end_offset + + def __enter__(self) -> ChunkLineIterator: + self.fd = open(self.path, "r", encoding="utf-8") + return ChunkLineIterator(self.fd, self.start_offset, self.end_offset) + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.fd.close() diff --git a/fairseq/file_io.py b/fairseq/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..8eca70a0668d09e211c06b5b432e4d0d2125ca72 --- /dev/null +++ b/fairseq/file_io.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import shutil +from typing import List, Optional + + +logger = logging.getLogger(__file__) + + +try: + from iopath.common.file_io import g_pathmgr as IOPathManager + + try: + # [FB only - for now] AWS PathHandler for PathManager + from .fb_pathhandlers import S3PathHandler + + IOPathManager.register_handler(S3PathHandler()) + except KeyError: + logging.warning("S3PathHandler already registered.") + except ImportError: + logging.debug( + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." + ) + +except ImportError: + IOPathManager = None + + +class PathManager: + """ + Wrapper for insulating OSS I/O (using Python builtin operations) from + iopath's PathManager abstraction (for transparently handling various + internal backends). + """ + + @staticmethod + def open( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + if IOPathManager: + return IOPathManager.open( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + return open( + path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def copy(src_path: str, dst_path: str, overwrite: bool = False) -> bool: + if IOPathManager: + return IOPathManager.copy( + src_path=src_path, dst_path=dst_path, overwrite=overwrite + ) + return shutil.copyfile(src_path, dst_path) + + @staticmethod + def get_local_path(path: str, **kwargs) -> str: + if IOPathManager: + return IOPathManager.get_local_path(path, **kwargs) + return path + + @staticmethod + def exists(path: str) -> bool: + if IOPathManager: + return IOPathManager.exists(path) + return os.path.exists(path) + + @staticmethod + def isfile(path: str) -> bool: + if IOPathManager: + return IOPathManager.isfile(path) + return os.path.isfile(path) + + @staticmethod + def ls(path: str) -> List[str]: + if IOPathManager: + return IOPathManager.ls(path) + return os.listdir(path) + + @staticmethod + def mkdirs(path: str) -> None: + if IOPathManager: + return IOPathManager.mkdirs(path) + os.makedirs(path, exist_ok=True) + + @staticmethod + def rm(path: str) -> None: + if IOPathManager: + return IOPathManager.rm(path) + os.remove(path) + + @staticmethod + def chmod(path: str, mode: int) -> None: + if not PathManager.path_requires_pathmanager(path): + os.chmod(path, mode) + + @staticmethod + def register_handler(handler) -> None: + if IOPathManager: + return IOPathManager.register_handler(handler=handler) + + @staticmethod + def copy_from_local( + local_path: str, dst_path: str, overwrite: bool = False, **kwargs + ) -> None: + if IOPathManager: + return IOPathManager.copy_from_local( + local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs + ) + return shutil.copyfile(local_path, dst_path) + + @staticmethod + def path_requires_pathmanager(path: str) -> bool: + """Do we require PathManager to access given path?""" + if IOPathManager: + for p in IOPathManager._path_handlers.keys(): + if path.startswith(p): + return True + return False + + @staticmethod + def supports_rename(path: str) -> bool: + # PathManager doesn't yet support renames + return not PathManager.path_requires_pathmanager(path) + + @staticmethod + def rename(src: str, dst: str): + os.rename(src, dst) + + """ + ioPath async PathManager methods: + """ + + @staticmethod + def opena( + path: str, + mode: str = "r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + ): + """ + Return file descriptor with asynchronous write operations. + """ + global IOPathManager + if not IOPathManager: + logging.info("ioPath is initializing PathManager.") + try: + from iopath.common.file_io import PathManager + + IOPathManager = PathManager() + except Exception: + logging.exception("Failed to initialize ioPath PathManager object.") + return IOPathManager.opena( + path=path, + mode=mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, + ) + + @staticmethod + def async_close() -> bool: + """ + Wait for files to be written and clean up asynchronous PathManager. + NOTE: `PathManager.async_close()` must be called at the end of any + script that uses `PathManager.opena(...)`. + """ + global IOPathManager + if IOPathManager: + return IOPathManager.async_close() + return False diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b99da2e8cd82a7f4e419fc0abdbc00d617efc611 --- /dev/null +++ b/fairseq/file_utils.py @@ -0,0 +1,370 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for working with the local dataset cache. +This file is adapted from `AllenNLP `_. +and `huggingface `_. +""" + +import fnmatch +import json +import logging +import os +import shutil +import tarfile +import tempfile +from functools import partial, wraps +from hashlib import sha256 +from io import open + + +try: + from torch.hub import _get_torch_home + + torch_cache_home = _get_torch_home() +except ImportError: + torch_cache_home = os.path.expanduser( + os.getenv( + "TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch") + ) + ) +default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq") + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + + PYTORCH_FAIRSEQ_CACHE = Path(os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)) +except (AttributeError, ImportError): + PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path) + +CONFIG_NAME = "config.json" +WEIGHTS_NAME = "pytorch_model.bin" + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def load_archive_file(archive_file): + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=None) + except EnvironmentError: + logger.info( + "Archive name '{}' was not found in archive name list. " + "We assumed '{}' was a path or URL but couldn't find any file " + "associated to this path or URL.".format( + archive_file, + archive_file, + ) + ) + return None + + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info( + "loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) + + # Extract archive to temp dir and replace .tar.bz2 if necessary + tempdir = None + if not os.path.isdir(resolved_archive_file): + tempdir = tempfile.mkdtemp() + logger.info( + "extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir + ) + ) + ext = os.path.splitext(archive_file)[1][1:] + with tarfile.open(resolved_archive_file, "r:" + ext) as archive: + top_dir = os.path.commonprefix(archive.getnames()) + archive.extractall(tempdir) + os.remove(resolved_archive_file) + shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file) + shutil.rmtree(tempdir) + + return resolved_archive_file + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the URL's, delimited + by a period. + """ + url_bytes = url.encode("utf-8") + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = sha256(etag_bytes) + filename += "." + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + ".json" + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata["url"] + etag = metadata["etag"] + + return url, etag + + +def cached_path_from_pm(url_or_filename): + """ + Tries to cache the specified URL using PathManager class. + Returns the cached path if success otherwise failure. + """ + try: + from fairseq.file_io import PathManager + + local_path = PathManager.get_local_path(url_or_filename) + return local_path + except Exception: + return None + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https", "s3"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == "": + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + cached_path = cached_path_from_pm(url_or_filename) + if cached_path: + return cached_path + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + from botocore.exceptions import ClientError + + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + import boto3 + + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def request_wrap_timeout(func, url): + import requests + + for attempt, timeout in enumerate([10, 20, 40, 60, 60]): + try: + return func(timeout=timeout) + except requests.exceptions.Timeout as e: + logger.warning( + "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs", + url, + attempt, + timeout, + exc_info=e, + ) + continue + raise RuntimeError(f"Unable to fetch file {url}") + + +def http_get(url, temp_file): + import requests + from tqdm import tqdm + + req = request_wrap_timeout(partial(requests.get, url, stream=True), url) + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_FAIRSEQ_CACHE + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + import requests + + response = request_wrap_timeout( + partial(requests.head, url, allow_redirects=True), url + ) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except RuntimeError: + etag = None + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*") + matching_files = list(filter(lambda s: not s.endswith(".json"), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, "wb") as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {"url": url, "etag": etag} + meta_path = cache_path + ".json" + with open(meta_path, "w") as meta_file: + output_string = json.dumps(meta) + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + """ + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + """ + collection = set() + with open(filename, "r", encoding="utf-8") as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c2da15bf2484a1109871d36c7a16d60219c42d --- /dev/null +++ b/fairseq/hub_utils.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import copy +import logging +import os +from typing import Any, Dict, Iterator, List + +import torch +from omegaconf import open_dict +from torch import nn + +from fairseq import utils +from fairseq.data import encoders + +logger = logging.getLogger(__name__) + + +def from_pretrained( + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + archive_map=None, + **kwargs +): + from fairseq import checkpoint_utils, file_utils + + if archive_map is not None: + if model_name_or_path in archive_map: + model_name_or_path = archive_map[model_name_or_path] + if data_name_or_path is not None and data_name_or_path in archive_map: + data_name_or_path = archive_map[data_name_or_path] + + # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe) + # for each model + if isinstance(model_name_or_path, dict): + for k, v in model_name_or_path.items(): + if k == "checkpoint_file": + checkpoint_file = v + elif ( + k != "path" + # only set kwargs that don't already have overrides + and k not in kwargs + ): + kwargs[k] = v + model_name_or_path = model_name_or_path["path"] + + model_path = file_utils.load_archive_file(model_name_or_path) + + # convenience hack for loading data and BPE codes from model archive + if data_name_or_path.startswith("."): + kwargs["data"] = os.path.abspath(os.path.join(model_path, data_name_or_path)) + else: + kwargs["data"] = file_utils.load_archive_file(data_name_or_path) + for file, arg in { + "code": "bpe_codes", + "bpecodes": "bpe_codes", + "sentencepiece.bpe.model": "sentencepiece_model", + "merges.txt": "bpe_merges", + "vocab.json": "bpe_vocab", + }.items(): + path = os.path.join(model_path, file) + if os.path.exists(path): + kwargs[arg] = path + + if "user_dir" in kwargs: + utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"])) + + model_path = [ + os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep) + ] + + if "is_vocoder" in kwargs: + args = {"data": kwargs["data"], "model_path": model_path} + task = None + models = None + else: + models, args, task = checkpoint_utils.load_model_ensemble_and_task( + model_path, + arg_overrides=kwargs, + ) + if "generation_args" in kwargs and kwargs["generation_args"]: + for key in kwargs["generation_args"]: + setattr(args["generation"], key, kwargs["generation_args"][key]) + + return { + "args": args, + "task": task, + "models": models, + } + + +class GeneratorHubInterface(nn.Module): + """ + PyTorch Hub interface for generating sequences from a pre-trained + translation or language model. + """ + + def __init__(self, cfg, task, models): + super().__init__() + self.cfg = cfg + self.task = task + self.models = nn.ModuleList(models) + self.src_dict = task.source_dictionary + self.tgt_dict = task.target_dictionary + + # optimize model for generation + for model in self.models: + model.prepare_for_inference_(cfg) + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) + + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) + + self.max_positions = utils.resolve_max_positions( + self.task.max_positions(), *[model.max_positions() for model in models] + ) + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + def translate( + self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs + ) -> List[str]: + return self.sample(sentences, beam, verbose, **kwargs) + + def sample( + self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs + ) -> List[str]: + if isinstance(sentences, str): + return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0] + tokenized_sentences = [self.encode(sentence) for sentence in sentences] + batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs) + return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos] + + def score( + self, sentences: List[str], replace_newline_with_eos: bool = False, **kwargs + ): + if isinstance(sentences, str): + return self.score( + [sentences], replace_newline_with_eos=replace_newline_with_eos, **kwargs + )[0] + + def encode(sentence): + if replace_newline_with_eos: + return torch.cat([self.encode(line) for line in sentence.splitlines()]) + else: + return self.encode(sentence) + + # NOTE: this doesn't support translation tasks currently + tokenized_sentences = [encode(sentence) for sentence in sentences] + return [ + hypos[0] + for hypos in self.generate( + tokenized_sentences, score_reference=True, **kwargs + ) + ] + + def generate( + self, + tokenized_sentences: List[torch.LongTensor], + beam: int = 5, + verbose: bool = False, + skip_invalid_size_inputs=False, + inference_step_args=None, + prefix_allowed_tokens_fn=None, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1: + return self.generate( + tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs + )[0] + + # build generator using current args as well as any kwargs + gen_args = copy.deepcopy(self.cfg.generation) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) + generator = self.task.build_generator( + self.models, + gen_args, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + ) + + inference_step_args = inference_step_args or {} + results = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) + translations = self.task.inference_step( + generator, self.models, batch, **inference_step_args + ) + for id, hypos in zip(batch["id"].tolist(), translations): + results.append((id, hypos)) + + # sort output to match input order + outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])] + + if verbose: + + def getarg(name, default): + return getattr(gen_args, name, getattr(self.cfg, name, default)) + + for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): + src_str_with_unk = self.string(source_tokens) + logger.info("S\t{}".format(src_str_with_unk)) + for hypo in target_hypotheses: + hypo_str = self.decode(hypo["tokens"]) + logger.info("H\t{}\t{}".format(hypo["score"], hypo_str)) + logger.info( + "P\t{}".format( + " ".join( + map( + lambda x: "{:.4f}".format(x), + hypo["positional_scores"].tolist(), + ) + ) + ) + ) + if hypo["alignment"] is not None and getarg( + "print_alignment", False + ): + logger.info( + "A\t{}".format( + " ".join( + [ + "{}-{}".format(src_idx, tgt_idx) + for src_idx, tgt_idx in hypo["alignment"] + ] + ) + ) + ) + return outputs + + def encode(self, sentence: str) -> torch.LongTensor: + sentence = self.tokenize(sentence) + sentence = self.apply_bpe(sentence) + return self.binarize(sentence) + + def decode(self, tokens: torch.LongTensor) -> str: + sentence = self.string(tokens) + sentence = self.remove_bpe(sentence) + return self.detokenize(sentence) + + def tokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.encode(sentence) + return sentence + + def detokenize(self, sentence: str) -> str: + if self.tokenizer is not None: + sentence = self.tokenizer.decode(sentence) + return sentence + + def apply_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.encode(sentence) + return sentence + + def remove_bpe(self, sentence: str) -> str: + if self.bpe is not None: + sentence = self.bpe.decode(sentence) + return sentence + + def binarize(self, sentence: str) -> torch.LongTensor: + return self.src_dict.encode_line(sentence, add_if_not_exist=False).long() + + def string(self, tokens: torch.LongTensor) -> str: + return self.tgt_dict.string(tokens) + + def _build_batches( + self, tokens: List[List[int]], skip_invalid_size_inputs: bool + ) -> Iterator[Dict[str, Any]]: + lengths = torch.LongTensor([t.numel() for t in tokens]) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.build_dataset_for_inference(tokens, lengths), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=self.max_positions, + ignore_invalid_inputs=skip_invalid_size_inputs, + disable_iterator_cache=True, + ).next_epoch_itr(shuffle=False) + return batch_iterator + + +class BPEHubInterface(object): + """PyTorch Hub interface for Byte-Pair Encoding (BPE).""" + + def __init__(self, bpe, **kwargs): + super().__init__() + args = argparse.Namespace(bpe=bpe, **kwargs) + self.bpe = encoders.build_bpe(args) + assert self.bpe is not None + + def encode(self, sentence: str) -> str: + return self.bpe.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.bpe.decode(sentence) + + +class TokenizerHubInterface(object): + """PyTorch Hub interface for tokenization.""" + + def __init__(self, tokenizer, **kwargs): + super().__init__() + args = argparse.Namespace(tokenizer=tokenizer, **kwargs) + self.tokenizer = encoders.build_tokenizer(args) + assert self.tokenizer is not None + + def encode(self, sentence: str) -> str: + return self.tokenizer.encode(sentence) + + def decode(self, sentence: str) -> str: + return self.tokenizer.decode(sentence) diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b26e6cd01cd4cbdffa23d88b354eb4a55a94189b --- /dev/null +++ b/fairseq/incremental_decoding_utils.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import uuid +from typing import Dict, Optional + +from torch import Tensor + + +class FairseqIncrementalState(object): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.init_incremental_state() + + def init_incremental_state(self): + self._incremental_state_id = str(uuid.uuid4()) + + def _get_full_incremental_state_key(self, key: str) -> str: + return "{}.{}".format(self._incremental_state_id, key) + + def get_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + ) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + full_key = self._get_full_incremental_state_key(key) + if incremental_state is None or full_key not in incremental_state: + return None + return incremental_state[full_key] + + def set_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], + ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + full_key = self._get_full_incremental_state_key(key) + incremental_state[full_key] = value + return incremental_state + + +def with_incremental_state(cls): + cls.__bases__ = (FairseqIncrementalState,) + tuple( + b for b in cls.__bases__ if b != FairseqIncrementalState + ) + return cls diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..3d32c6bf4dcaacde7f7834da0d1f58d59c8345a9 --- /dev/null +++ b/fairseq/iterative_refinement_generator.py @@ -0,0 +1,359 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import namedtuple + +import numpy as np +import torch +from fairseq import utils + + +DecoderOut = namedtuple( + "IterativeRefinementDecoderOut", + ["output_tokens", "output_scores", "attn", "step", "max_step", "history"], +) + + +class IterativeRefinementGenerator(object): + def __init__( + self, + tgt_dict, + models=None, + eos_penalty=0.0, + max_iter=10, + max_ratio=2, + beam_size=1, + decoding_format=None, + retain_dropout=False, + adaptive=True, + retain_history=False, + reranking=False, + ): + """ + Generates translations based on iterative refinement. + + Args: + tgt_dict: target dictionary + eos_penalty: if > 0.0, it penalized early-stopping in decoding + max_iter: maximum number of refinement iterations + max_ratio: generate sequences of maximum length ax, where x is the source length + decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'} + retain_dropout: retaining dropout in the inference + adaptive: decoding with early stop + """ + self.bos = tgt_dict.bos() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.eos_penalty = eos_penalty + self.max_iter = max_iter + self.max_ratio = max_ratio + self.beam_size = beam_size + self.reranking = reranking + self.decoding_format = decoding_format + self.retain_dropout = retain_dropout + self.retain_history = retain_history + self.adaptive = adaptive + self.models = models + + def generate_batched_itr( + self, + data_itr, + maxlen_a=None, + maxlen_b=None, + cuda=False, + timer=None, + prefix_size=0, + ): + """Iterate over a batched dataset and yield individual translations. + + Args: + maxlen_a/b: generate sequences of maximum length ax + b, + where x is the source sentence length. + cuda: use GPU for generation + timer: StopwatchMeter for timing generations. + """ + + for sample in data_itr: + if "net_input" not in sample: + continue + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate( + self.models, + sample, + prefix_tokens=sample["target"][:, :prefix_size] + if prefix_size > 0 + else None, + ) + if timer is not None: + timer.stop(sample["ntokens"]) + for i, id in enumerate(sample["id"]): + # remove padding + src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad) + ref = utils.strip_pad(sample["target"][i, :], self.pad) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample, prefix_tokens=None, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the IterativeRefinementGenerator is not supported" + ) + + # TODO: iterative refinement generator does not support ensemble for now. + if not self.retain_dropout: + for model in models: + model.eval() + + model, reranker = models[0], None + if self.reranking: + assert len(models) > 1, "Assuming the last checkpoint is the reranker" + assert ( + self.beam_size > 1 + ), "Reranking requires multiple translation for each example" + + reranker = models[-1] + models = models[:-1] + + if len(models) > 1 and hasattr(model, "enable_ensemble"): + assert model.allow_ensemble, "{} does not support ensembling".format( + model.__class__.__name__ + ) + model.enable_ensemble(models) + + # TODO: better encoder inputs? + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size() + + # initialize + encoder_out = model.forward_encoder([src_tokens, src_lengths]) + prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens) + + if self.beam_size > 1: + assert ( + model.allow_length_beam + ), "{} does not support decoding with length beam.".format( + model.__class__.__name__ + ) + + # regenerate data based on length-beam + length_beam_order = ( + utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1) + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, length_beam_order + ) + prev_decoder_out = model.regenerate_length_beam( + prev_decoder_out, self.beam_size + ) + bsz = bsz * self.beam_size + + sent_idxs = torch.arange(bsz) + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.retain_history: + prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens]) + + finalized = [[] for _ in range(bsz)] + + def is_a_loop(x, y, s, a): + b, l_x, l_y = x.size(0), x.size(1), y.size(1) + if l_x > l_y: + y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1) + s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1) + if a is not None: + a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1) + elif l_x < l_y: + x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1) + return (x == y).all(1), y, s, a + + def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): + cutoff = prev_out_token.ne(self.pad) + tokens = prev_out_token[cutoff] + if prev_out_score is None: + scores, score = None, None + else: + scores = prev_out_score[cutoff] + score = scores.mean() + + if prev_out_attn is None: + hypo_attn, alignment = None, None + else: + hypo_attn = prev_out_attn[cutoff] + alignment = hypo_attn.max(dim=1)[1] + return { + "steps": step, + "tokens": tokens, + "positional_scores": scores, + "score": score, + "hypo_attn": hypo_attn, + "alignment": alignment, + } + + for step in range(self.max_iter + 1): + + decoder_options = { + "eos_penalty": self.eos_penalty, + "max_ratio": self.max_ratio, + "decoding_format": self.decoding_format, + } + prev_decoder_out = prev_decoder_out._replace( + step=step, + max_step=self.max_iter + 1, + ) + + decoder_out = model.forward_decoder( + prev_decoder_out, encoder_out, **decoder_options + ) + + if self.adaptive: + # terminate if there is a loop + terminated, out_tokens, out_scores, out_attn = is_a_loop( + prev_output_tokens, + decoder_out.output_tokens, + decoder_out.output_scores, + decoder_out.attn, + ) + decoder_out = decoder_out._replace( + output_tokens=out_tokens, + output_scores=out_scores, + attn=out_attn, + ) + + else: + terminated = decoder_out.output_tokens.new_zeros( + decoder_out.output_tokens.size(0) + ).bool() + + if step == self.max_iter: # reach last iteration, terminate + terminated.fill_(1) + + # collect finalized sentences + finalized_idxs = sent_idxs[terminated.to(sent_idxs.device)] + finalized_tokens = decoder_out.output_tokens[terminated] + finalized_scores = decoder_out.output_scores[terminated] + finalized_attn = ( + None + if (decoder_out.attn is None or decoder_out.attn.size(0) == 0) + else decoder_out.attn[terminated] + ) + + if self.retain_history: + finalized_history_tokens = [h[terminated] for h in decoder_out.history] + + for i in range(finalized_idxs.size(0)): + finalized[finalized_idxs[i]] = [ + finalized_hypos( + step, + finalized_tokens[i], + finalized_scores[i], + None if finalized_attn is None else finalized_attn[i], + ) + ] + + if self.retain_history: + finalized[finalized_idxs[i]][0]["history"] = [] + for j in range(len(finalized_history_tokens)): + finalized[finalized_idxs[i]][0]["history"].append( + finalized_hypos( + step, finalized_history_tokens[j][i], None, None + ) + ) + + # check if all terminated + if terminated.sum() == terminated.size(0): + break + + # for next step + not_terminated = ~terminated + prev_decoder_out = decoder_out._replace( + output_tokens=decoder_out.output_tokens[not_terminated], + output_scores=decoder_out.output_scores[not_terminated], + attn=decoder_out.attn[not_terminated] + if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0) + else None, + history=[h[not_terminated] for h in decoder_out.history] + if decoder_out.history is not None + else None, + ) + encoder_out = model.encoder.reorder_encoder_out( + encoder_out, not_terminated.nonzero(as_tuple=False).squeeze() + ) + sent_idxs = sent_idxs[not_terminated.to(sent_idxs.device)] + prev_output_tokens = prev_decoder_out.output_tokens.clone() + + if self.beam_size > 1: + if reranker is not None: + finalized = self.rerank( + reranker, finalized, [src_tokens, src_lengths], self.beam_size + ) + + # aggregate information from length beam + finalized = [ + finalized[ + np.argmax( + [ + finalized[self.beam_size * i + j][0]["score"] + for j in range(self.beam_size) + ] + ) + + self.beam_size * i + ] + for i in range(len(finalized) // self.beam_size) + ] + + return finalized + + def rerank(self, reranker, finalized, encoder_input, beam_size): + def rebuild_batch(finalized): + finalized_tokens = [f[0]["tokens"] for f in finalized] + finalized_maxlen = max(f.size(0) for f in finalized_tokens) + final_output_tokens = ( + finalized_tokens[0] + .new_zeros(len(finalized_tokens), finalized_maxlen) + .fill_(self.pad) + ) + for i, f in enumerate(finalized_tokens): + final_output_tokens[i, : f.size(0)] = f + return final_output_tokens + + final_output_tokens = rebuild_batch(finalized) + final_output_tokens[ + :, 0 + ] = self.eos # autoregressive model assumes starting with EOS + + reranker_encoder_out = reranker.encoder(*encoder_input) + length_beam_order = ( + utils.new_arange( + final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1) + ) + .t() + .reshape(-1) + ) + reranker_encoder_out = reranker.encoder.reorder_encoder_out( + reranker_encoder_out, length_beam_order + ) + reranking_scores = reranker.get_normalized_probs( + reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out), + True, + None, + ) + reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None]) + reranking_masks = final_output_tokens[:, 1:].ne(self.pad) + reranking_scores = ( + reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1) + ) + reranking_scores = reranking_scores / reranking_masks.sum(1).type_as( + reranking_scores + ) + + for i in range(len(finalized)): + finalized[i][0]["score"] = reranking_scores[i] + + return finalized diff --git a/fairseq/logging/__init__.py b/fairseq/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/logging/__pycache__/__init__.cpython-310.pyc b/fairseq/logging/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72b62a0d4ca48508a63e73f2e344351137e068e2 Binary files /dev/null and b/fairseq/logging/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/logging/__pycache__/meters.cpython-310.pyc b/fairseq/logging/__pycache__/meters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f45ba8d13eb8476524174211b550f2333a6909c Binary files /dev/null and b/fairseq/logging/__pycache__/meters.cpython-310.pyc differ diff --git a/fairseq/logging/__pycache__/metrics.cpython-310.pyc b/fairseq/logging/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17442d53950eb476f557ffefde7fa3f6005f7983 Binary files /dev/null and b/fairseq/logging/__pycache__/metrics.cpython-310.pyc differ diff --git a/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc b/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8cc3e733df6c694370031bd4d4e3c7b6ceb5e41 Binary files /dev/null and b/fairseq/logging/__pycache__/progress_bar.cpython-310.pyc differ diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..495bd083000de9e4a05f1470228c1171c8c8bb9c --- /dev/null +++ b/fairseq/logging/meters.py @@ -0,0 +1,351 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import bisect +import time +from collections import OrderedDict +from typing import Dict, Optional + +try: + import torch + + def type_as(a, b): + if torch.is_tensor(a) and torch.is_tensor(b): + return a.to(b) + else: + return a + +except ImportError: + torch = None + + def type_as(a, b): + return a + + +try: + import numpy as np +except ImportError: + np = None + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif torch is not None and torch.is_tensor(number) and number.numel() == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.count if self.count > 0 else self.val + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class SumMeter(Meter): + """Computes and stores the sum""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.sum = 0 # sum from all updates + + def update(self, val): + if val is not None: + self.sum = type_as(self.sum, val) + val + + def state_dict(self): + return { + "sum": self.sum, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.round = state_dict.get("round", None) + + @property + def smoothed_value(self) -> float: + val = self.sum + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class ConcatTensorMeter(Meter): + """Concatenates tensors""" + + def __init__(self, dim=0): + super().__init__() + self.reset() + self.dim = dim + + def reset(self): + self.tensor = None + + def update(self, val): + if self.tensor is None: + self.tensor = val + else: + self.tensor = torch.cat([self.tensor, val], dim=self.dim) + + def state_dict(self): + return { + "tensor": self.tensor, + } + + def load_state_dict(self, state_dict): + self.tensor = state_dict["tensor"] + + @property + def smoothed_value(self) -> float: + return [] # return a dummy value + + +class TimeMeter(Meter): + """Computes the average occurrence of some event per second""" + + def __init__( + self, + init: int = 0, + n: int = 0, + round: Optional[int] = None, + ): + self.round = round + self.reset(init, n) + + def reset(self, init=0, n=0): + self.init = init + self.start = time.perf_counter() + self.n = n + self.i = 0 + + def update(self, val=1): + self.n = type_as(self.n, val) + val + self.i += 1 + + def state_dict(self): + return { + "init": self.elapsed_time, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + if "start" in state_dict: + # backwards compatibility for old state_dicts + self.reset(init=state_dict["init"]) + else: + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.n / self.elapsed_time + + @property + def elapsed_time(self): + return self.init + (time.perf_counter() - self.start) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class StopwatchMeter(Meter): + """Computes the sum/avg duration of some event in seconds""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.sum = 0 + self.n = 0 + self.start_time = None + + def start(self): + self.start_time = time.perf_counter() + + def stop(self, n=1, prehook=None): + if self.start_time is not None: + if prehook is not None: + prehook() + delta = time.perf_counter() - self.start_time + self.sum = self.sum + delta + self.n = type_as(self.n, n) + n + + def reset(self): + self.sum = 0 # cumulative time during which stopwatch was active + self.n = 0 # total n across all start/stop + self.start() + + def state_dict(self): + return { + "sum": self.sum, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.n = state_dict["n"] + self.start_time = None + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.n if self.n > 0 else self.sum + + @property + def elapsed_time(self): + if self.start_time is None: + return 0.0 + return time.perf_counter() - self.start_time + + @property + def smoothed_value(self) -> float: + val = self.avg if self.sum > 0 else self.elapsed_time + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class MetersDict(OrderedDict): + """A sorted dictionary of :class:`Meters`. + + Meters are sorted according to a priority that is given when the + meter is first added to the dictionary. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priorities = [] + + def __setitem__(self, key, value): + assert key not in self, "MetersDict doesn't support reassignment" + priority, value = value + bisect.insort(self.priorities, (priority, len(self.priorities), key)) + super().__setitem__(key, value) + for _, _, key in self.priorities: # reorder dict to match priorities + self.move_to_end(key) + + def add_meter(self, key, meter, priority): + self.__setitem__(key, (priority, meter)) + + def state_dict(self): + return [ + (pri, key, self[key].__class__.__name__, self[key].state_dict()) + for pri, _, key in self.priorities + # can't serialize DerivedMeter instances + if not isinstance(self[key], MetersDict._DerivedMeter) + ] + + def load_state_dict(self, state_dict): + self.clear() + self.priorities.clear() + for pri, key, meter_cls, meter_state in state_dict: + meter = globals()[meter_cls]() + meter.load_state_dict(meter_state) + self.add_meter(key, meter, pri) + + def get_smoothed_value(self, key: str) -> float: + """Get a single smoothed value.""" + meter = self[key] + if isinstance(meter, MetersDict._DerivedMeter): + return meter.fn(self) + else: + return meter.smoothed_value + + def get_smoothed_values(self) -> Dict[str, float]: + """Get all smoothed values.""" + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) + + def reset(self): + """Reset Meter instances.""" + for meter in self.values(): + if isinstance(meter, MetersDict._DerivedMeter): + continue + meter.reset() + + class _DerivedMeter(Meter): + """A Meter whose values are derived from other Meters.""" + + def __init__(self, fn): + self.fn = fn + + def reset(self): + pass diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..49301f27f84351b83b8c869bac86c78ec9f126e6 --- /dev/null +++ b/fairseq/logging/metrics.py @@ -0,0 +1,336 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +A standalone module for aggregating metrics. + +Metrics can be logged from anywhere using the `log_*` functions defined +in this module. The logged values will be aggregated dynamically based +on the aggregation context in which the logging occurs. See the +:func:`aggregate` context manager for more details. +""" + +import contextlib +import uuid +from collections import defaultdict +from typing import Callable, List, Optional + +from .meters import * + + +# Aggregation contexts are considered "active" when inside the scope +# created by the :func:`aggregate` context manager. +_aggregators = OrderedDict() +_active_aggregators = OrderedDict() +_active_aggregators_cnt = defaultdict(lambda: 0) + + +def reset() -> None: + """Reset all metrics aggregators.""" + _aggregators.clear() + _active_aggregators.clear() + _active_aggregators_cnt.clear() + + # The "default" aggregator observes all logged values. + _aggregators["default"] = MetersDict() + _active_aggregators["default"] = _aggregators["default"] + _active_aggregators_cnt["default"] = 1 + + +reset() + + +@contextlib.contextmanager +def aggregate(name: Optional[str] = None, new_root: bool = False): + """Context manager to aggregate metrics under a given name. + + Aggregations can be nested. If *new_root* is ``False``, then logged + metrics will be recorded along the entire stack of nested + aggregators, including a global "default" aggregator. If *new_root* + is ``True``, then this aggregator will be the root of a new + aggregation stack, thus bypassing any parent aggregators. + + Note that aggregation contexts are uniquely identified by their + *name* (e.g., train, valid). Creating a context with an existing + name will reuse the corresponding :class:`MetersDict` instance. + If no name is given, then a temporary aggregator will be created. + + Usage:: + + with metrics.aggregate("train"): + for step, batch in enumerate(epoch): + with metrics.aggregate("train_inner") as agg: + metrics.log_scalar("loss", get_loss(batch)) + if step % log_interval == 0: + print(agg.get_smoothed_value("loss")) + agg.reset() + print(metrics.get_smoothed_values("train")["loss"]) + + Args: + name (str): name of the aggregation. Defaults to a + random/temporary name if not given explicitly. + new_root (bool): make this aggregation the root of a new + aggregation stack. + """ + if name is None: + # generate a temporary name + name = str(uuid.uuid4()) + assert name not in _aggregators + agg = MetersDict() + else: + assert name != "default" + agg = _aggregators.setdefault(name, MetersDict()) + + if new_root: + backup_aggregators = _active_aggregators.copy() + _active_aggregators.clear() + backup_aggregators_cnt = _active_aggregators_cnt.copy() + _active_aggregators_cnt.clear() + + _active_aggregators[name] = agg + _active_aggregators_cnt[name] += 1 + + yield agg + + _active_aggregators_cnt[name] -= 1 + if _active_aggregators_cnt[name] == 0 and name in _active_aggregators: + del _active_aggregators[name] + + if new_root: + _active_aggregators.clear() + _active_aggregators.update(backup_aggregators) + _active_aggregators_cnt.clear() + _active_aggregators_cnt.update(backup_aggregators_cnt) + + +def get_active_aggregators() -> List[MetersDict]: + return list(_active_aggregators.values()) + + +def log_scalar( + key: str, + value: float, + weight: float = 1, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value. + + Args: + key (str): name of the field to log + value (float): value to log + weight (float): weight that this value contributes to the average. + A weight of 0 will always log the latest value. + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, AverageMeter(round=round), priority) + agg[key].update(value, weight) + + +def log_scalar_sum( + key: str, + value: float, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, SumMeter(round=round), priority) + agg[key].update(value) + + +def log_concat_tensor( + key: str, + value: torch.Tensor, + priority: int = 10, + dim: int = 0, +): + """Log a scalar value that is summed for reporting. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, ConcatTensorMeter(dim=dim), priority) + agg[key].update(value) + + +def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): + """Log a scalar value derived from other meters. + + Args: + key (str): name of the field to log + fn (Callable[[MetersDict], float]): function that takes a single + argument *meters* and returns the derived value + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, MetersDict._DerivedMeter(fn), priority) + + +def log_speed( + key: str, + value: float, + priority: int = 30, + round: Optional[int] = None, +): + """Log the rate of some quantity per second. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, TimeMeter(round=round), priority) + agg[key].reset() # reset meter on the first call + else: + agg[key].update(value) + + +def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): + """Log the duration of some event in seconds. + + The duration will be computed once :func:`log_stop_time` is called. + + Args: + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, StopwatchMeter(round=round), priority) + agg[key].start() + + +def log_stop_time(key: str, weight: float = 0.0, prehook=None): + """Log the duration of some event in seconds. + + The duration will be computed since :func:`log_start_time` was called. + Set weight > 0 to report the average time instead of the sum. + + Args: + key (str): name of the field to log + weight (float): weight that this time contributes to the average + prehook (function, no arguments): will be called before the timer + is stopped. For example, use prehook=torch.cuda.synchronize to + make sure all gpu operations are done before timer is stopped. + """ + for agg in get_active_aggregators(): + if key in agg: + agg[key].stop(weight, prehook) + + +def log_custom( + new_meter_fn: Callable[[], Meter], + key: str, + *args, + priority: int = 50, + **kwargs, +): + """Log using a custom Meter. + + Any extra *args* or *kwargs* will be passed through to the Meter's + *update* method. + + Args: + new_meter_fn (Callable[[], Meter]): function that returns a new + Meter instance + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, new_meter_fn(), priority) + agg[key].update(*args, **kwargs) + + +def reset_meter(name: str, key: str) -> None: + """Reset Meter instance aggregated under a given *name* and *key*.""" + meter = get_meter(name, key) + if meter is not None: + meter.reset() + + +def reset_meters(name: str) -> None: + """Reset Meter instances aggregated under a given *name*.""" + meters = get_meters(name) + if meters is not None: + meters.reset() + + +def get_meter(name: str, key: str) -> Meter: + """Get a single Meter instance aggregated under *name* and *key*. + + Returns: + Meter or None if no metrics have been logged under *name* and *key*. + """ + if name not in _aggregators: + return None + return _aggregators[name].get(key, None) + + +def get_meters(name: str) -> MetersDict: + """Get Meter instances aggregated under a given *name*. + + Returns: + MetersDict or None if no metrics have been logged under *name*. + """ + return _aggregators.get(name, None) + + +def get_smoothed_value(name: str, key: str) -> float: + """Get a single smoothed value. + + Raises: + KeyError: if no metrics have been logged under *name* and *key*. + """ + return _aggregators[name].get_smoothed_value(key) + + +def get_smoothed_values(name: str) -> Dict[str, float]: + """Get smoothed values aggregated under a given *name*. + + Raises: + KeyError: if no metrics have been logged under *name*. + """ + return _aggregators[name].get_smoothed_values() + + +def state_dict(): + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) + + +def load_state_dict(state_dict): + for name, agg_state in state_dict.items(): + _aggregators[name] = MetersDict() + _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..4c64b61bad6edbf4b9ff5bcc2f26952e8b1bfc9c --- /dev/null +++ b/fairseq/logging/progress_bar.py @@ -0,0 +1,582 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wrapper around various loggers and progress bars (e.g., tqdm). +""" + +import atexit +import json +import logging +import os +import sys +from collections import OrderedDict +from contextlib import contextmanager +from numbers import Number +from typing import Optional + +import torch + +from .meters import AverageMeter, StopwatchMeter, TimeMeter + +logger = logging.getLogger(__name__) + + +def progress_bar( + iterator, + log_format: Optional[str] = None, + log_interval: int = 100, + log_file: Optional[str] = None, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + aim_repo: Optional[str] = None, + aim_run_hash: Optional[str] = None, + aim_param_checkpoint_dir: Optional[str] = None, + tensorboard_logdir: Optional[str] = None, + default_log_format: str = "tqdm", + wandb_project: Optional[str] = None, + wandb_run_name: Optional[str] = None, + azureml_logging: Optional[bool] = False, +): + if log_format is None: + log_format = default_log_format + if log_file is not None: + handler = logging.FileHandler(filename=log_file) + logger.addHandler(handler) + + if log_format == "tqdm" and not sys.stderr.isatty(): + log_format = "simple" + + if log_format == "json": + bar = JsonProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "none": + bar = NoopProgressBar(iterator, epoch, prefix) + elif log_format == "simple": + bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "tqdm": + bar = TqdmProgressBar(iterator, epoch, prefix) + else: + raise ValueError("Unknown log format: {}".format(log_format)) + + if aim_repo: + bar = AimProgressBarWrapper( + bar, + aim_repo=aim_repo, + aim_run_hash=aim_run_hash, + aim_param_checkpoint_dir=aim_param_checkpoint_dir, + ) + + if tensorboard_logdir: + try: + # [FB only] custom wrapper for TensorBoard + import palaas # noqa + + from .fb_tbmf_wrapper import FbTbmfWrapper + + bar = FbTbmfWrapper(bar, log_interval) + except ImportError: + bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) + + if wandb_project: + bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name) + + if azureml_logging: + bar = AzureMLProgressBarWrapper(bar) + + return bar + + +def build_progress_bar( + args, + iterator, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + default: str = "tqdm", + no_progress_bar: str = "none", +): + """Legacy wrapper that takes an argparse.Namespace.""" + if getattr(args, "no_progress_bar", False): + default = no_progress_bar + if getattr(args, "distributed_rank", 0) == 0: + tensorboard_logdir = getattr(args, "tensorboard_logdir", None) + else: + tensorboard_logdir = None + return progress_bar( + iterator, + log_format=args.log_format, + log_interval=args.log_interval, + epoch=epoch, + prefix=prefix, + tensorboard_logdir=tensorboard_logdir, + default_log_format=default, + ) + + +def format_stat(stat): + if isinstance(stat, Number): + stat = "{:g}".format(stat) + elif isinstance(stat, AverageMeter): + stat = "{:.3f}".format(stat.avg) + elif isinstance(stat, TimeMeter): + stat = "{:g}".format(round(stat.avg)) + elif isinstance(stat, StopwatchMeter): + stat = "{:g}".format(round(stat.sum)) + elif torch.is_tensor(stat): + stat = stat.tolist() + return stat + + +class BaseProgressBar(object): + """Abstract class for progress bars.""" + + def __init__(self, iterable, epoch=None, prefix=None): + self.iterable = iterable + self.n = getattr(iterable, "n", 0) + self.epoch = epoch + self.prefix = "" + if epoch is not None: + self.prefix += "epoch {:03d}".format(epoch) + if prefix is not None: + self.prefix += (" | " if self.prefix != "" else "") + prefix + + def __len__(self): + return len(self.iterable) + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def __iter__(self): + raise NotImplementedError + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + raise NotImplementedError + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + raise NotImplementedError + + def update_config(self, config): + """Log latest configuration.""" + pass + + def _str_commas(self, stats): + return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) + + def _str_pipes(self, stats): + return " | ".join(key + " " + stats[key].strip() for key in stats.keys()) + + def _format_stats(self, stats): + postfix = OrderedDict(stats) + # Preprocess stats according to datatype + for key in postfix.keys(): + postfix[key] = str(format_stat(postfix[key])) + return postfix + + +@contextmanager +def rename_logger(logger, new_name): + old_name = logger.name + if new_name is not None: + logger.name = new_name + yield logger + logger.name = old_name + + +class JsonProgressBar(BaseProgressBar): + """Log output in JSON format.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + update = ( + self.epoch - 1 + (self.i + 1) / float(self.size) + if self.epoch is not None + else None + ) + stats = self._format_stats(stats, epoch=self.epoch, update=update) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self.stats = stats + if tag is not None: + self.stats = OrderedDict( + [(tag + "_" + k, v) for k, v in self.stats.items()] + ) + stats = self._format_stats(self.stats, epoch=self.epoch) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def _format_stats(self, stats, epoch=None, update=None): + postfix = OrderedDict() + if epoch is not None: + postfix["epoch"] = epoch + if update is not None: + postfix["update"] = round(update, 3) + # Preprocess stats according to datatype + for key in stats.keys(): + postfix[key] = format_stat(stats[key]) + return postfix + + +class NoopProgressBar(BaseProgressBar): + """No logging.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + + def __iter__(self): + for obj in self.iterable: + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + pass + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + pass + + +class SimpleProgressBar(BaseProgressBar): + """A minimal logger for non-TTY environments.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + stats = self._format_stats(stats) + postfix = self._str_commas(stats) + with rename_logger(logger, tag): + logger.info( + "{}: {:5d} / {:d} {}".format( + self.prefix, self.i + 1, self.size, postfix + ) + ) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +class TqdmProgressBar(BaseProgressBar): + """Log to tqdm.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + from tqdm import tqdm + + self.tqdm = tqdm( + iterable, + self.prefix, + leave=False, + disable=(logger.getEffectiveLevel() > logging.INFO), + ) + + def __iter__(self): + return iter(self.tqdm) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + self.tqdm.set_postfix(self._format_stats(stats), refresh=False) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +try: + import functools + + from aim import Repo as AimRepo + + @functools.lru_cache() + def get_aim_run(repo, run_hash): + from aim import Run + + return Run(run_hash=run_hash, repo=repo) + +except ImportError: + get_aim_run = None + AimRepo = None + + +class AimProgressBarWrapper(BaseProgressBar): + """Log to Aim.""" + + def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir): + self.wrapped_bar = wrapped_bar + + if get_aim_run is None: + self.run = None + logger.warning("Aim not found, please install with: pip install aim") + else: + logger.info(f"Storing logs at Aim repo: {aim_repo}") + + if not aim_run_hash: + # Find run based on save_dir parameter + query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'" + try: + runs_generator = AimRepo(aim_repo).query_runs(query) + run = next(runs_generator.iter_runs()) + aim_run_hash = run.run.hash + except Exception: + pass + + if aim_run_hash: + logger.info(f"Appending to run: {aim_run_hash}") + + self.run = get_aim_run(aim_repo, aim_run_hash) + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to Aim.""" + self._log_to_aim(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_aim(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + if self.run is not None: + for key in config: + self.run.set(key, config[key], strict=False) + self.wrapped_bar.update_config(config) + + def _log_to_aim(self, stats, tag=None, step=None): + if self.run is None: + return + + if step is None: + step = stats["num_updates"] + + if "train" in tag: + context = {"tag": tag, "subset": "train"} + elif "val" in tag: + context = {"tag": tag, "subset": "val"} + else: + context = {"tag": tag} + + for key in stats.keys() - {"num_updates"}: + self.run.track(stats[key], name=key, step=step, context=context) + + +try: + _tensorboard_writers = {} + from torch.utils.tensorboard import SummaryWriter +except ImportError: + try: + from tensorboardX import SummaryWriter + except ImportError: + SummaryWriter = None + + +def _close_writers(): + for w in _tensorboard_writers.values(): + w.close() + + +atexit.register(_close_writers) + + +class TensorboardProgressBarWrapper(BaseProgressBar): + """Log to tensorboard.""" + + def __init__(self, wrapped_bar, tensorboard_logdir): + self.wrapped_bar = wrapped_bar + self.tensorboard_logdir = tensorboard_logdir + + if SummaryWriter is None: + logger.warning( + "tensorboard not found, please install with: pip install tensorboard" + ) + + def _writer(self, key): + if SummaryWriter is None: + return None + _writers = _tensorboard_writers + if key not in _writers: + _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) + _writers[key].add_text("sys.argv", " ".join(sys.argv)) + return _writers[key] + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + # TODO add hparams to Tensorboard + self.wrapped_bar.update_config(config) + + def _log_to_tensorboard(self, stats, tag=None, step=None): + writer = self._writer(tag or "") + if writer is None: + return + if step is None: + step = stats["num_updates"] + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + writer.add_scalar(key, stats[key].val, step) + elif isinstance(stats[key], Number): + writer.add_scalar(key, stats[key], step) + elif torch.is_tensor(stats[key]) and stats[key].numel() == 1: + writer.add_scalar(key, stats[key].item(), step) + writer.flush() + + +try: + import wandb +except ImportError: + wandb = None + + +class WandBProgressBarWrapper(BaseProgressBar): + """Log to Weights & Biases.""" + + def __init__(self, wrapped_bar, wandb_project, run_name=None): + self.wrapped_bar = wrapped_bar + if wandb is None: + logger.warning("wandb not found, pip install wandb") + return + + # reinit=False to ensure if wandb.init() is called multiple times + # within one process it still references the same run + wandb.init(project=wandb_project, reinit=False, name=run_name) + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_wandb(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + if wandb is not None: + wandb.config.update(config) + self.wrapped_bar.update_config(config) + + def _log_to_wandb(self, stats, tag=None, step=None): + if wandb is None: + return + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + wandb.log({prefix + key: stats[key].val}, step=step) + elif isinstance(stats[key], Number): + wandb.log({prefix + key: stats[key]}, step=step) + + +try: + from azureml.core import Run +except ImportError: + Run = None + + +class AzureMLProgressBarWrapper(BaseProgressBar): + """Log to Azure ML""" + + def __init__(self, wrapped_bar): + self.wrapped_bar = wrapped_bar + if Run is None: + logger.warning("azureml.core not found, pip install azureml-core") + return + self.run = Run.get_context() + + def __exit__(self, *exc): + if Run is not None: + self.run.complete() + return False + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to AzureML""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + self.wrapped_bar.update_config(config) + + def _log_to_azureml(self, stats, tag=None, step=None): + if Run is None: + return + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + name = prefix + key + if isinstance(stats[key], AverageMeter): + self.run.log_row(name=name, **{"step": step, key: stats[key].val}) + elif isinstance(stats[key], Number): + self.run.log_row(name=name, **{"step": step, key: stats[key]}) diff --git a/fairseq/model_parallel/__init__.py b/fairseq/model_parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69f21684872f72ae8ee26d9ff7d2d2b6e6d526c3 --- /dev/null +++ b/fairseq/model_parallel/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import criterions, models, modules # noqa diff --git a/fairseq/model_parallel/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6a44d116fe6e0e25ec06ebb3ac641143ed191af Binary files /dev/null and b/fairseq/model_parallel/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/criterions/__init__.py b/fairseq/model_parallel/criterions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fae7bd4c2cfa7b4f64ad62dd9b9082f59f0e50d --- /dev/null +++ b/fairseq/model_parallel/criterions/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the criterions/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.model_parallel.criterions." + module) diff --git a/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..937d5f3f6f3f034da46c3c5498c6ef9479df18b4 Binary files /dev/null and b/fairseq/model_parallel/criterions/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc b/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..237fd7157f87577947460d34d6ce8958cc22a5d4 Binary files /dev/null and b/fairseq/model_parallel/criterions/__pycache__/vocab_parallel_cross_entropy.cpython-310.pyc differ diff --git a/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..5ffbaa87640973317e3cac4c396cdc11af2fa380 --- /dev/null +++ b/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from fairseq import utils +from fairseq.logging import metrics +from fairseq.criterions import FairseqCriterion, register_criterion + + +try: + from fairseq.model_parallel.megatron.mpu.cross_entropy import ( + vocab_parallel_cross_entropy, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +@register_criterion("vocab_parallel_cross_entropy") +class VocabParallelCrossEntropyCriterion(FairseqCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task) + self.sentence_avg = sentence_avg + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + target = sample["target"] + + loss = vocab_parallel_cross_entropy(net_output[0].float(), target) + loss = (loss * (target != self.padding_idx)).sum() + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": utils.item(loss.data) if reduce else loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..aedf608bce21d11d0a1e9646d9c373aae198dce6 --- /dev/null +++ b/fairseq/model_parallel/megatron_trainer.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. +""" + +from fairseq.dataclass.configs import FairseqConfig +from fairseq.distributed import utils as distributed_utils +from fairseq.trainer import Trainer + +try: + from fairseq.model_parallel.megatron.mpu import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_model_parallel_src_rank, + get_cuda_rng_tracker, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class MegatronTrainer(Trainer): + """Main class for model parallel with data parallel training.""" + + def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + super().__init__(cfg, task, model, criterion, **kwargs) + + def clip_grad_norm(self, clip_norm): + def _aggregate_model_parallel_grad_norm(total_norm): + total_norm = total_norm**2 + distributed_utils.all_reduce( + total_norm, group=distributed_utils.get_model_parallel_group() + ) + total_norm = total_norm**0.5 + return total_norm + + return self.optimizer.clip_grad_norm( + clip_norm, + aggregate_norm_fn=_aggregate_model_parallel_grad_norm, + ) + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states() + super().save_checkpoint(filename, extra_state) + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + extra_state = super().load_checkpoint( + filename, + reset_optimizer=reset_optimizer, + reset_lr_scheduler=reset_lr_scheduler, + optimizer_overrides=optimizer_overrides, + reset_meters=reset_meters, + ) + if extra_state is not None and "rng_tracker_states" in extra_state: + get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"]) + return extra_state diff --git a/fairseq/model_parallel/models/__init__.py b/fairseq/model_parallel/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3532479e52a0e1f1ba204c6f5d51c71c98ee5df0 --- /dev/null +++ b/fairseq/model_parallel/models/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.model_parallel.models." + model_name) diff --git a/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d1adbffa1bd84858ec882e097b7adad024d2d7 Binary files /dev/null and b/fairseq/model_parallel/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc b/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685f6fa65143e1fdba9f60dcec590e299c03fe71 Binary files /dev/null and b/fairseq/model_parallel/models/__pycache__/transformer.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc b/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2386f40fa19246a7049629d017b3f89ca6e8ddae Binary files /dev/null and b/fairseq/model_parallel/models/__pycache__/transformer_lm.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117827c3e9c176477f33e3a6fd7fe19a922411a2 --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2397746bee839d71190fad60c494340df200fa Binary files /dev/null and b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5003bf51f0af4249ea592ef1b09cb38a48d3c831 Binary files /dev/null and b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/layers.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..223fff50ba2e5cd862e6dd037b627f9744388d2f Binary files /dev/null and b/fairseq/model_parallel/models/pipeline_parallel_transformer/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..85dbd44b3c7f762048ff21808313d0317f8da7a4 --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -0,0 +1,600 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import options, utils +from fairseq.modules import ( + AdaptiveSoftmax, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, +) + +EncoderOut = namedtuple( + "TransformerEncoderOut", + [ + "encoder_out", # T x B x C + "encoder_padding_mask", # B x T + "encoder_embedding", # B x T x C + "encoder_states", # List[T x B x C] + ], +) + + +class TransformerEncoderEmbedding(nn.Module): + """Encoder Embedding + Positional Embedding""" + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.max_source_positions = args.max_source_positions + self.embed_tokens = embed_tokens + if isinstance(embed_tokens, nn.ModuleList): + self.padding_idx = embed_tokens[0].padding_idx + embed_dim = sum(e.embedding_dim for e in embed_tokens) + else: + self.padding_idx = embed_tokens.padding_idx + embed_dim = embed_tokens.embedding_dim + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + def forward(self, input): + # embed tokens and positions + src_tokens = input[0] + prev_output_tokens = input[2] + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(src_tokens)) + + embedded = torch.cat(x_embed_list, dim=-1) + else: + embedded = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * embedded + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding: + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerEncoderLayerNorm(nn.Module): + """ + Layer norm at the the end of all encoder layers if + args.encoder_enormalize_before = True + """ + + def __init__(self, args, embed_dim): + super().__init__() + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input): + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + if self.layer_norm: + x = self.layer_norm(x) + # keeping track of the incremental_state is not supported yet + return (x, encoder_padding_mask, prev_output_tokens) + + +class TransformerDecoderEmbedding(nn.Module): + """Decoder Embedding + Positional Embedding""" + + def __init__(self, args, embed_tokens): + super().__init__() + self.dropout = args.dropout + self.share_input_output_embed = args.share_decoder_input_output_embed + input_embed_dim = ( + sum(e.embedding_dim for e in embed_tokens) + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.embedding_dim + ) + embed_dim = args.decoder_embed_dim + self.output_embed_dim = args.decoder_output_dim + + padding_idx = ( + embed_tokens[0].padding_idx + if isinstance(embed_tokens, nn.ModuleList) + else embed_tokens.padding_idx + ) + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + def forward(self, input): + mt_task = False + if isinstance(input, tuple): + if len(input) == 3: + encoder_out = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + incremental_state = None # Hardcoding to avoid passing of None objects + mt_task = True + else: + # HACK for now, need to fix (TODO sidgoyal) + prev_output_tokens = input[0] + # discard "src_lengths" + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + else: + prev_output_tokens = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + + if isinstance(self.embed_tokens, nn.ModuleList): + x_embed_list = [] + for embed_tokens_part in self.embed_tokens: + x_embed_list.append(embed_tokens_part(prev_output_tokens)) + + x = self.embed_scale * torch.cat(x_embed_list, dim=-1) + else: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + +class TransformerDecoderOutputLayer(nn.Module): + def __init__(self, args, embed_tokens, dictionary): + super().__init__() + self.share_input_output_embed = args.share_decoder_input_output_embed + self.embed_tokens = embed_tokens + self.output_embed_dim = args.decoder_output_dim + embed_dim = args.decoder_embed_dim + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights + else None + ) + self.adaptive_softmax = None + if args.adaptive_softmax_cutoff is not None: + assert not isinstance(embed_tokens, nn.ModuleList) + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + options.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif not self.share_input_output_embed: + self.embed_tokens = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_( + self.embed_tokens, mean=0, std=self.output_embed_dim**-0.5 + ) + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward(self, input, apply_final_proj=True): + if isinstance(input, tuple): + x = input[0] + else: + x = input + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + if apply_final_proj: + x = self.output_layer(x) + return x + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + if isinstance(self.embed_tokens, nn.ModuleList): + output = None + for i, emb in enumerate(self.embed_tokens): + sidx = i * emb.embedding_dim + eidx = (i + 1) * emb.embedding_dim + if output is None: + output = F.linear(features[:, :, sidx:eidx], emb.weight) + else: + output += F.linear(features[:, :, sidx:eidx], emb.weight) + + return output + else: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_tokens) + else: + return features + + +class TransformerEncoderLayer(nn.Module): + """Encoder layer block. + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.encoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, args): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.self_attn = MultiheadAttention( + self.embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + ) + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + input[2] (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing) + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + x = input[0] + encoder_padding_mask = input[1] + prev_output_tokens = input[2] + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + return (x, encoder_padding_mask, prev_output_tokens) + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + +class TransformerDecoderLayer(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.self_attn = MultiheadAttention( + embed_dim=self.embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=True, + ) + self.dropout = args.dropout + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") + ) + self.activation_dropout = getattr(args, "activation_dropout", 0) + if self.activation_dropout == 0: + # for backwards compatibility with models that use args.relu_dropout + self.activation_dropout = getattr(args, "relu_dropout", 0) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward(self, input): + """ + Args: + input (Tuple): + input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)` + input[2] (ByteTensor/FloatTensor): encoder padding mask - + binary ByteTensor of shape `(batch, src_len)` where padding elements + are indicated by ``1``. + Returns: + output (Tuple): + output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)` + output[1] (ByteTensor/FloatTensor): encoder padding mask + output[2] (LongTensor): previous decoder outputs + """ + # Note: incremental state is not yet supported + mt_task = False + if isinstance(input, tuple): + x = input[0] + encoder_out = input[1] + encoder_padding_mask = input[2] + incremental_state = None + mt_task = True + else: + x = input + encoder_out = None + encoder_padding_mask = None + incremental_state = None + + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + # TODO: add back prev_self_attn_state, prev_attn_state, + # self_attn_padding_mask + prev_self_attn_state = None + prev_attn_state = None + self_attn_padding_mask = None + + residual = x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True) + if prev_self_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_self_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.self_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True) + + if self.encoder_attn is not None: + residual = x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True) + if prev_attn_state is not None: + if incremental_state is None: + incremental_state = {} + prev_key, prev_value = prev_attn_state + saved_state = {"prev_key": prev_key, "prev_value": prev_value} + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True) + + residual = x + x = self.maybe_layer_norm(self.final_layer_norm, x, before=True) + x = self.activation_fn(self.fc1(x)) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + x = self.maybe_layer_norm(self.final_layer_norm, x, after=True) + + if mt_task: + return (x, encoder_out, encoder_padding_mask) + return x + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def maybe_layer_norm(self, layer_norm, x, before=False, after=False): + assert before ^ after + if after ^ self.normalize_before: + return layer_norm(x) + else: + return x + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7873ac679170d2647f0491747a75f60364e248dc --- /dev/null +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -0,0 +1,779 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.models.pipeline_parallel_transformer.layers import ( + Embedding, + TransformerDecoderEmbedding, + TransformerDecoderLayer, + TransformerDecoderOutputLayer, + TransformerEncoderEmbedding, + TransformerEncoderLayer, + TransformerEncoderLayerNorm, +) +from fairseq.models import ( + BaseFairseqModel, + FairseqDecoder, + FairseqEncoder, + register_model, + register_model_architecture, +) +from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.models.transformer import ( + base_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de_big, +) +from fairseq.modules import SinusoidalPositionalEmbedding + + +logger = logging.getLogger(__name__) + + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 +TORCH_PIPE = False +RPC_INIT = False + + +def import_pipe(): + global TORCH_PIPE + global RPC_INIT + try: + from torch.distributed.pipeline.sync import Pipe # noqa + + global Pipe + from torch.distributed.pipeline.sync.utils import partition_model + + global partition_model + from torch.distributed import rpc + import tempfile + + TORCH_PIPE = True + # Initialize single process RPC agent since TORCH_PIPE requires + # RRef. RRef depends on RPC being initialized and as a result we initialize + # RPC with a single node. + tmpfile = tempfile.NamedTemporaryFile() + if not RPC_INIT: + rpc.init_rpc( + name="worker", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(tmpfile.name), + ), + ) + RPC_INIT = True + logger.info("Using torch pipe") + except ImportError: + try: + from fairscale.nn import Pipe # noqa + + logger.info("Using fairscale pipe") + except ImportError: + raise ImportError("Please install fairscale with: pip install fairscale") + + +@register_model("pipeline_parallel_transformer") +class PipelineParallelTransformerModel(BaseFairseqModel): + def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint): + import_pipe() + super().__init__() + assert isinstance(encoder, FairseqEncoder) + assert isinstance(decoder, FairseqDecoder) + encoder_module_list = ( + [encoder.embedding_layer] + + list(encoder.encoder_layers) + + [encoder.final_layer_norm] + ) + self.num_encoder_modules = len(encoder_module_list) + decoder_module_list = ( + [decoder.embedding_layer] + + list(decoder.decoder_layers) + + [decoder.decoder_output_layer] + ) + self.num_decoder_modules = len(decoder_module_list) + module_list = encoder_module_list + decoder_module_list + self.devices = devices + if TORCH_PIPE: + self.model = Pipe( + partition_model(nn.Sequential(*module_list), balance, devices), + chunks=chunks, + checkpoint=checkpoint, + ) + else: + self.model = Pipe( + nn.Sequential(*module_list), + balance=balance, + devices=devices, + chunks=chunks, + checkpoint=checkpoint, + ) + self.encoder_max_positions = self.max_positions_helper( + encoder.embedding_layer, "max_source_positions" + ) + self.decoder_max_positions = self.max_positions_helper( + decoder.embedding_layer, "max_target_positions" + ) + self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None) + # Note: To be populated during inference + self.encoder = None + self.decoder = None + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + if self.training: + input_lst = [src_tokens, src_lengths, prev_output_tokens] + input = tuple(i.to(self.devices[0], non_blocking=True) for i in input_lst) + if TORCH_PIPE: + return self.model(input).local_value() + else: + return self.model(input) + else: + assert self.encoder is not None and self.decoder is not None, ( + "encoder and decoder need to be initialized by " + + "calling the `prepare_for_inference_()` method" + ) + encoder_output_tuple = self.encoder(input) + return self.decoder(encoder_output_tuple) + + def prepare_for_inference_(self, cfg): + if self.encoder is not None and self.decoder is not None: + logger.info("Encoder and Decoder already initialized") + return + encoder_module_list = [] + decoder_module_list = [] + module_count = 0 + for partition in self.model.partitions: + for module in partition: + if module_count < self.num_encoder_modules: + encoder_module_list.append(module) + else: + decoder_module_list.append(module) + module_count += 1 + self.model = None + self.encoder = TransformerEncoder( + cfg.distributed_training, None, None, encoder_module_list + ) + self.decoder = TransformerDecoder( + cfg.distributed_training, + None, + None, + decoder_module_list=decoder_module_list, + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', + help='dropout probability after activation in FFN.') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', + help='encoder embedding dimension for FFN') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='num encoder layers') + parser.add_argument('--encoder-attention-heads', type=int, metavar='N', + help='num encoder attention heads') + parser.add_argument('--encoder-normalize-before', action='store_true', + help='apply layernorm before each encoder block') + parser.add_argument('--encoder-learned-pos', action='store_true', + help='use learned positional embeddings in the encoder') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', + help='decoder embedding dimension for FFN') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='num decoder layers') + parser.add_argument('--decoder-attention-heads', type=int, metavar='N', + help='num decoder attention heads') + parser.add_argument('--decoder-learned-pos', action='store_true', + help='use learned positional embeddings in the decoder') + parser.add_argument('--decoder-normalize-before', action='store_true', + help='apply layernorm before each decoder block') + parser.add_argument('--share-decoder-input-output-embed', action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', + help='if set, disables positional embeddings (outside self attention)') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion'), + parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', + help='sets adaptive softmax dropout for the tail projections') + parser.add_argument('--num-embedding-chunks', type=int, metavar='N', default=1, + help='Number of embedding layer chunks (enables more even distribution' + 'of optimizer states across data parallel nodes' + 'when using optimizer state sharding and' + 'a big embedding vocabulary)') + # fmt: on + + @classmethod + def build_model_base(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not hasattr(args, "max_source_positions"): + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if not hasattr(args, "max_target_positions"): + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1): + assert embed_dim % num_embed_chunks == 0, ( + f"Number of embedding chunks = {num_embed_chunks} should be " + + f"divisible by the embedding dimension = {embed_dim}" + ) + assert path is None or num_embed_chunks == 1, ( + "Loading embedding from a path with number of embedding chunks > 1" + + " is not yet supported" + ) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + # if provided, load from preloaded dictionaries + if path: + emb = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + else: + embed_chunk_dim = embed_dim // num_embed_chunks + emb = nn.ModuleList() + for i in range(num_embed_chunks): + emb.append(Embedding(num_embeddings, embed_chunk_dim, padding_idx)) + return emb + + num_embed_chunks = args.num_embedding_chunks + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + assert args.share_decoder_input_output_embed or num_embed_chunks == 1, ( + "Not sharing decoder I/O embeddings is not yet supported with number of " + + "embedding chunks > 1" + ) + encoder_embed_tokens = build_embedding( + src_dict, + args.encoder_embed_dim, + args.encoder_embed_path, + num_embed_chunks, + ) + decoder_embed_tokens = build_embedding( + tgt_dict, + args.decoder_embed_dim, + args.decoder_embed_path, + num_embed_chunks, + ) + + encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) + return (encoder, decoder) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoder(args, tgt_dict, embed_tokens) + + @classmethod + def build_model(cls, args, task): + encoder, decoder = cls.build_model_base(args, task) + return PipelineParallelTransformerModel( + encoder=encoder, + decoder=decoder, + balance=utils.eval_str_list(args.pipeline_balance, type=int), + devices=utils.eval_str_list(args.pipeline_devices, type=int), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder_max_positions, self.decoder_max_positions) + + def max_positions_helper( + self, embedding_layer, max_positions_field="max_source_positions" + ): + """Maximum input length supported by the encoder or decoder.""" + if embedding_layer.embed_positions is None: + return getattr(embedding_layer, max_positions_field) + return min( + getattr(embedding_layer, max_positions_field), + embedding_layer.embed_positions.max_positions, + ) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output, target=target) + return out.exp_() if not log_probs else out + + # A Pipe() module returns a tuple of tensors as the output. + # In this case, the tuple has one element - the output tensor of logits + logits = net_output if isinstance(net_output, torch.Tensor) else net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=False) + else: + return utils.softmax(logits, dim=-1, onnx_trace=False) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder_max_positions + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + self.upgrade_state_dict(state_dict) + is_regular_transformer = not any("model.partitions" in k for k in state_dict) + if is_regular_transformer: + state_dict = self.convert_to_pipeline_parallel_state_dict(state_dict) + return super().load_state_dict(state_dict, strict) + + def convert_to_pipeline_parallel_state_dict(self, state_dict): + new_state_dict = self.state_dict() + encoder_layer_idx = 0 + decoder_layer_idx = 0 + encoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + decoder_key_suffixes = [ + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "self_attn_layer_norm.weight", + "self_attn_layer_norm.bias", + "encoder_attn.k_proj.weight", + "encoder_attn.k_proj.bias", + "encoder_attn.v_proj.weight", + "encoder_attn.v_proj.bias", + "encoder_attn.q_proj.weight", + "encoder_attn.q_proj.bias", + "encoder_attn.out_proj.weight", + "encoder_attn.out_proj.bias", + "encoder_attn_layer_norm.weight", + "encoder_attn_layer_norm.bias", + "fc1.weight", + "fc1.bias", + "fc2.weight", + "fc2.bias", + "final_layer_norm.weight", + "final_layer_norm.bias", + ] + for pid, partition in enumerate(self.model.partitions): + logger.info(f"Begin Partition {pid}") + for mid, module in enumerate(partition): + # fmt: off + if isinstance(module, TransformerEncoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['encoder.embed_tokens.weight'] + if isinstance(module, TransformerEncoderLayer): + for suffix in encoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'encoder.layers.{encoder_layer_idx}.{suffix}'] + encoder_layer_idx += 1 + if isinstance(module, TransformerDecoderLayer): + for suffix in decoder_key_suffixes: + new_state_dict[f'model.partitions.{pid}.{mid}.{suffix}'] = state_dict[f'decoder.layers.{decoder_layer_idx}.{suffix}'] + decoder_layer_idx += 1 + if isinstance(module, TransformerEncoderLayerNorm): + if 'encoder.layer_norm.weight' in state_dict: + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.weight'] = state_dict['encoder.layer_norm.weight'] + new_state_dict[f'model.partitions.{pid}.{mid}.layer_norm.bias'] = state_dict['encoder.layer_norm.bias'] + if isinstance(module, TransformerDecoderEmbedding): + new_state_dict[f'model.partitions.{pid}.{mid}.embed_tokens.weight'] = state_dict['decoder.embed_tokens.weight'] + if isinstance(module, TransformerDecoderOutputLayer): + new_state_dict[f'model.partitions.{pid}.{mid}.output_projection.weight'] = state_dict['decoder.output_projection.weight'] + # fmt: on + return new_state_dict + + +class TransformerEncoder(FairseqEncoder): + """ + Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + import_pipe() + self.use_pipeline = encoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) + self.encoder_layers = nn.Sequential( + *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if isinstance(embed_tokens, nn.ModuleList): + emb_dim = sum(e.embedding_dim for e in embed_tokens) + else: + emb_dim = embed_tokens.embedding_dim + self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim) + else: + encoder_balance = utils.eval_str_list( + args.pipeline_encoder_balance, type=int + ) + encoder_devices = utils.eval_str_list( + args.pipeline_encoder_devices, type=int + ) + assert sum(encoder_balance) == len(encoder_module_list), ( + f"Sum of encoder_balance={encoder_balance} is not equal " + + f"to num_encoder_modules={len(encoder_module_list)}" + ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model( + nn.Sequential(*encoder_module_list), + encoder_balance, + encoder_devices, + ), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*encoder_module_list), + balance=encoder_balance, + devices=encoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward(self, src_tokens, src_lengths): + """ + Args: + input_tuple( + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + ) + + Returns: + output_tuple( + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - prev_output_tokens + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + ) + """ + dummy_prev_output_tokens = torch.zeros( + 1, dtype=src_tokens.dtype, device=src_tokens.device + ) + input_tuple = (src_tokens, src_lengths, dummy_prev_output_tokens) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + if TORCH_PIPE: + encoder_out = self.model(input_tuple).local_value() + else: + encoder_out = self.model(input_tuple) + else: + encoder_embed_output_tuple = self.embedding_layer(input_tuple) + encoder_layers_output = self.encoder_layers(encoder_embed_output_tuple) + encoder_out = self.final_layer_norm(encoder_layers_output) + # first element is the encoder output + # second element is the encoder padding mask + # the remaining elements of EncoderOut are not computed by + # the PipelineParallelTransformer + return EncoderOut(encoder_out[0], encoder_out[1], None, None, None, None) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if encoder_out.encoder_out is not None: + encoder_out = encoder_out._replace( + encoder_out=encoder_out.encoder_out.index_select(1, new_order) + ) + if encoder_out.encoder_padding_mask is not None: + encoder_out = encoder_out._replace( + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_embedding is not None: + encoder_out = encoder_out._replace( + encoder_embedding=encoder_out.encoder_embedding.index_select( + 0, new_order + ) + ) + if encoder_out.encoder_states is not None: + for idx, state in enumerate(encoder_out.encoder_states): + encoder_out.encoder_states[idx] = state.index_select(1, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_source_positions + return min( + self.embedding_layer.max_source_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + +class TransformerDecoder(FairseqDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + decoder_module_list=None, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + import_pipe() + self.use_pipeline = decoder_module_list is not None + if not self.use_pipeline: + self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) + self.decoder_layers = nn.Sequential( + *[ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) + self.decoder_output_layer = TransformerDecoderOutputLayer( + args, embed_tokens, dictionary + ) + else: + decoder_balance = utils.eval_str_list( + args.pipeline_decoder_balance, type=int + ) + decoder_devices = utils.eval_str_list( + args.pipeline_decoder_devices, type=int + ) + assert sum(decoder_balance) == len(decoder_module_list), ( + f"Sum of decoder_balance={decoder_balance} is not equal " + + f"to num_decoder_modules={len(decoder_module_list)}" + ) + if TORCH_PIPE: + self.model = Pipe( + module=partition_model( + nn.Sequential(*decoder_module_list), + decoder_balance, + decoder_devices, + ), + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + else: + self.model = Pipe( + module=nn.Sequential(*decoder_module_list), + balance=decoder_balance, + devices=decoder_devices, + chunks=args.pipeline_chunks, + checkpoint=args.pipeline_checkpoint, + ) + + def forward( + self, + prev_output_tokens, + encoder_out=None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + input_tuple = ( + encoder_out.encoder_out, + encoder_out.encoder_padding_mask, + prev_output_tokens, + ) + if self.use_pipeline: + input_tuple = tuple(i.to(self.model.devices[0]) for i in input_tuple) + if TORCH_PIPE: + return (self.model(input_tuple).local_value(),) + else: + return (self.model(input_tuple),) + else: + embed_layer_output = self.embedding_layer(input_tuple) + state = self.decoder_layers(embed_layer_output) + return (self.decoder_output_layer(state),) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embedding_layer.embed_positions is None: + return self.embedding_layer.max_target_positions + return min( + self.embedding_layer.max_target_positions, + self.embedding_layer.embed_positions.max_positions, + ) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + for i in range(len(self.layers)): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_iwslt_de_en_pipeline_parallel" +) +def transformer_iwslt_de_en_dist(args): + transformer_iwslt_de_en(args) + + +@register_model_architecture( + "pipeline_parallel_transformer", "transformer_wmt_en_de_big_pipeline_parallel" +) +def transformer_wmt_en_de_big_dist(args): + transformer_wmt_en_de_big(args) diff --git a/fairseq/model_parallel/models/roberta/__init__.py b/fairseq/model_parallel/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..117827c3e9c176477f33e3a6fd7fe19a922411a2 --- /dev/null +++ b/fairseq/model_parallel/models/roberta/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa diff --git a/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..520a3f6035653e33187f8b9baf7b5fb0b4cb8b77 Binary files /dev/null and b/fairseq/model_parallel/models/roberta/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc b/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e05876c914d6dd3c17bba9f9be968bc86297769a Binary files /dev/null and b/fairseq/model_parallel/models/roberta/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/model_parallel/models/roberta/model.py b/fairseq/model_parallel/models/roberta/model.py new file mode 100644 index 0000000000000000000000000000000000000000..77a80ef72057219110b34678a38705549910edd3 --- /dev/null +++ b/fairseq/model_parallel/models/roberta/model.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +RoBERTa: A Robustly Optimized BERT Pretraining Approach. +""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.model_parallel.models.transformer import ModelParallelTransformerEncoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.roberta import ( + roberta_base_architecture, + roberta_prenorm_architecture, + RobertaEncoder, + RobertaModel, +) +from fairseq.modules import LayerNorm + + +try: + from fairseq.model_parallel.megatron.mpu import ( + copy_to_model_parallel_region, + gather_from_model_parallel_region, + ColumnParallelLinear, + VocabParallelEmbedding, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + +logger = logging.getLogger(__name__) + + +@register_model("model_parallel_roberta") +class ModelParallelRobertaModel(RobertaModel): + def __init__(self, args, encoder): + super().__init__(args, encoder) + + self.classification_heads = nn.ModuleDict() + + @staticmethod + def add_args(parser): + RobertaModel.add_args(parser) + parser.add_argument( + "--no-final-layer-norm", + action="store_true", + help=( + "don't add final layernorm (only applicable when " + "--encoder-normalize-before=True" + ), + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_architecture(args) + + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if not hasattr(args, "max_positions"): + args.max_positions = args.tokens_per_sample + + if getattr(args, "untie_weights_roberta", False): + raise NotImplementedError( + "--untie-weights-roberta is not supported in model parallel mode" + ) + + encoder = ModelParallelRobertaEncoder(args, task.source_dictionary) + return cls(args, encoder) + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + **kwargs + ): + if classification_head_name is not None: + features_only = True + + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) + + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + return x, extra + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = ModelParallelRobertaClassificationHead( + self.args.encoder_embed_dim, + inner_dim or self.args.encoder_embed_dim, + num_classes, + self.args.pooler_activation_fn, + self.args.pooler_dropout, + ) + + +class ModelParallelRobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = ColumnParallelLinear(embed_dim, embed_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the unmasked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + + x = copy_to_model_parallel_region(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + x = gather_from_model_parallel_region(x).contiguous() + x = x + self.bias + return x + + +class ModelParallelRobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout + ): + super().__init__() + self.dense = ColumnParallelLinear(input_dim, inner_dim, gather_output=True) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class ModelParallelRobertaEncoder(RobertaEncoder): + """RoBERTa encoder.""" + + def __init__(self, args, dictionary): + super().__init__(args, dictionary) + assert not self.args.untie_weights_roberta + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + return ModelParallelTransformerEncoder(args, dictionary, embed_tokens) + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return ModelParallelRobertaLMHead(embed_dim, output_dim, activation_fn, weight) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta") +def base_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", False) + # model parallel RoBERTa defaults to "Pre-LN" formulation + roberta_prenorm_architecture(args) + + +# earlier versions of model parallel RoBERTa removed the final layer norm +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_v1") +def model_parallel_roberta_v1_architecture(args): + args.no_final_layer_norm = getattr(args, "no_final_layer_norm", True) + base_architecture(args) + + +@register_model_architecture( + "model_parallel_roberta", "model_parallel_roberta_postnorm" +) +def model_parallel_roberta_postnorm_architecture(args): + # the original BERT/RoBERTa uses the "Post-LN" formulation + roberta_base_architecture(args) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_base") +def model_parallel_roberta_base_architecture(args): + base_architecture(args) + + +@register_model_architecture("model_parallel_roberta", "model_parallel_roberta_large") +def model_parallel_roberta_large_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + base_architecture(args) diff --git a/fairseq/model_parallel/models/transformer.py b/fairseq/model_parallel/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf3b2e8baf01389a34056cc68cbf6ad1d4475707 --- /dev/null +++ b/fairseq/model_parallel/models/transformer.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch.nn as nn + +from fairseq.model_parallel.modules import ( + ModelParallelTransformerDecoderLayer, + ModelParallelTransformerEncoderLayer, +) +from fairseq.models import register_model +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) + +try: + from fairseq.model_parallel.megatron.mpu import ( + VocabParallelEmbedding, + copy_to_model_parallel_region, + gather_from_model_parallel_region, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +logger = logging.getLogger(__name__) + + +@register_model("model_parallel_transformer") +class ModelParallelTransformerModel(TransformerModel): + """ + Model parallel Transformer model. + """ + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + dictionary.pad_to_multiple_(args.model_parallel_size * 8) + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=num_embeddings**-0.5) + nn.init.constant_(tensor[1], 0) + + emb = VocabParallelEmbedding( + num_embeddings, embed_dim, padding_idx, init_method=_vocab_init + ) + # if provided, load from preloaded dictionaries + if path: + raise NotImplementedError( + "Loading of embedding from path is not supported for model parallel" + ) + return emb + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return ModelParallelTransformerEncoder(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return ModelParallelTransformerDecoder( + args, + tgt_dict, + embed_tokens, + no_encoder_attn=getattr(args, "no_cross_attention", False), + ) + + +class ModelParallelTransformerEncoder(TransformerEncoder): + """ + Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerEncoderLayer`. + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + + if args.no_final_layer_norm: + self.layer_norm = None + + def build_encoder_layer(self, args): + return ModelParallelTransformerEncoderLayer(args) + + +class ModelParallelTransformerDecoder(TransformerDecoder): + """ + Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`ModelParallelTransformerDecoderLayer`. + """ + + def build_decoder_layer(self, args, no_encoder_attn=False): + return ModelParallelTransformerDecoderLayer(args, no_encoder_attn) + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + if not self.share_input_output_embed: + raise NotImplementedError( + "Model parallel training currently requires --share-decoder-input-output-embed" + ) + + features = copy_to_model_parallel_region(features) + + # project back to size of vocabulary + x = self.output_projection(features) + + if getattr(self.args, "criterion") != "vocab_parallel_cross_entropy": + x = gather_from_model_parallel_region(x).contiguous() + return x diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..03e4dbe26393eedfb71da94d6675b08cbdb8626d --- /dev/null +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer_lm import TransformerLanguageModel + +try: + from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("model_parallel_transformer_lm") +class ModelParallelTransformerLanguageModel(TransformerLanguageModel): + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + + # make sure all arguments are present in older models + base_lm_architecture(args) + + task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8) + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + if args.character_embeddings: + raise NotImplementedError( + "Character embeddings is not supported for model parallel" + ) + elif args.adaptive_input: + raise NotImplementedError( + "Adaptive input is not supported for model parallel" + ) + else: + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + decoder = ModelParallelTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + no_encoder_attn=True, + ) + return cls(decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + def _vocab_init(tensor, **kwargs): + nn.init.normal_(tensor, mean=0, std=embed_dim**-0.5) + nn.init.constant_(tensor[1], 0) + + embed_tokens = VocabParallelEmbedding( + len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init + ) + return embed_tokens + + +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, "no_tie_adaptive_proj"): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = getattr(args, "character_embeddings", False) + args.character_filters = getattr( + args, + "character_filters", + "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + ) + args.character_embedding_dim = getattr(args, "character_embedding_dim", 4) + args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0) + args.add_bos_token = getattr(args, "add_bos_token", False) + + +@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron") +def transformer_lm_megatron(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture( + "model_parallel_transformer_lm", "transformer_lm_megatron_11b" +) +def transformer_lm_megatron_11b(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6) + args.decoder_layers = getattr(args, "decoder_layers", 72) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) diff --git a/fairseq/model_parallel/modules/__init__.py b/fairseq/model_parallel/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11603217a188f420ea849ae0fde19979736ba208 --- /dev/null +++ b/fairseq/model_parallel/modules/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .multihead_attention import ModelParallelMultiheadAttention +from .transformer_layer import ( + ModelParallelTransformerEncoderLayer, + ModelParallelTransformerDecoderLayer, +) + +__all__ = [ + "ModelParallelMultiheadAttention", + "ModelParallelTransformerEncoderLayer", + "ModelParallelTransformerDecoderLayer", +] diff --git a/fairseq/model_parallel/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/model_parallel/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..011a23f166a7aba6d37dc9948130bfc148eef7a2 Binary files /dev/null and b/fairseq/model_parallel/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/model_parallel/modules/__pycache__/multihead_attention.cpython-310.pyc b/fairseq/model_parallel/modules/__pycache__/multihead_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..359324cf9c3f042249ab1092ec329f5de62f2100 Binary files /dev/null and b/fairseq/model_parallel/modules/__pycache__/multihead_attention.cpython-310.pyc differ diff --git a/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc b/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fb3ca9eb00901a3e50a283c0e05d003f8150582 Binary files /dev/null and b/fairseq/model_parallel/modules/__pycache__/transformer_layer.cpython-310.pyc differ diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..bbea4509508227045a75da4f36de003356901b1c --- /dev/null +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -0,0 +1,349 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout + +try: + from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + get_cuda_rng_tracker, + get_model_parallel_world_size, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +@with_incremental_state +class ModelParallelMultiheadAttention(nn.Module): + """Model parallel Multi-headed attention. + This performs the Multi-headed attention over multiple gpus. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + self_attention=False, + encoder_decoder_attention=False, + ): + super().__init__() + if not has_megatron_submodule: + raise ImportError( + "\n\nPlease install the megatron submodule:" + "\n\n git submodule update --init " + "fairseq/model_parallel/megatron" + ) + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.model_parallel_size = get_model_parallel_world_size() + + self.num_heads_partition = num_heads // self.model_parallel_size + assert ( + self.num_heads_partition * self.model_parallel_size == num_heads + ), "Number of heads must be divisible by model parallel size" + + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert ( + not self.self_attention or self.qkv_same_dim + ), "Self-attention requires query, key and value to be of the same size" + + self.k_proj = ColumnParallelLinear( + self.kdim, embed_dim, bias=bias, gather_output=False + ) + self.v_proj = ColumnParallelLinear( + self.vdim, embed_dim, bias=bias, gather_output=False + ) + self.q_proj = ColumnParallelLinear( + embed_dim, embed_dim, bias=bias, gather_output=False + ) + self.out_proj = RowParallelLinear( + embed_dim, embed_dim, bias=bias, input_is_parallel=True + ) + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + **unused_kwargs, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + """ + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + is_tpu = query.device.type == "xla" + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads_partition, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view( + bsz * self.num_heads_partition, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = ( + ModelParallelMultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + ) + + saved_state["prev_key"] = k.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) + saved_state["prev_value"] = v.view( + bsz, self.num_heads_partition, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + src_len, + ] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view( + bsz, self.num_heads_partition, tgt_len, src_len + ) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view( + bsz * self.num_heads_partition, tgt_len, src_len + ) + + attn_weights_float = utils.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + + with get_cuda_rng_tracker().fork(): + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [ + bsz * self.num_heads_partition, + tgt_len, + self.head_dim, + ] + embed_dim_partition = embed_dim // self.model_parallel_size + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition) + attn = self.out_proj(attn) + # return attn_weights None to keep the return type same as single gpu multihead attention + # This will be deprecated. + attn_weights: Optional[Tensor] = None + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + + filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1)) + if prev_key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + elif key_padding_mask is not None: + filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1)) + if key_padding_mask.is_cuda: + filler = filler.cuda() + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def reorder_incremental_state( + self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + if input_buffer[k] is not None: + input_buffer[k] = input_buffer[k].index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) diff --git a/fairseq/model_parallel/modules/transformer_layer.py b/fairseq/model_parallel/modules/transformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..7ab53c6e5f12f15562717effb86ab8cb8d6b4fa3 --- /dev/null +++ b/fairseq/model_parallel/modules/transformer_layer.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.model_parallel.modules import ModelParallelMultiheadAttention +from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer + + +try: + from fairseq.model_parallel.megatron.mpu import ( + ColumnParallelLinear, + RowParallelLinear, + ) + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False + + +class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer): + """Encoder layer block over multiple gpus. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim, + args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + ) + + +class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer): + """Decoder layer block. + + See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details. + """ + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return ColumnParallelLinear(input_dim, output_dim, gather_output=False) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + if q_noise > 0: + raise NotImplementedError + return RowParallelLinear(input_dim, output_dim, input_is_parallel=True) + + def build_self_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.decoder_attention_heads, + dropout=args.attention_dropout, + self_attention=not getattr(args, "cross_self_attention", False), + ) + + def build_encoder_attention(self, embed_dim, args, **unused_kwargs): + return ModelParallelMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + ) diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11cf6ee530d4e3ad9533ec120445fe17238f904b --- /dev/null +++ b/fairseq/models/__init__.py @@ -0,0 +1,236 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from contextlib import ExitStack + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore +from omegaconf import open_dict, OmegaConf + +from .composite_encoder import CompositeEncoder +from .distributed_fairseq_model import DistributedFairseqModel +from .fairseq_decoder import FairseqDecoder +from .fairseq_encoder import FairseqEncoder +from .fairseq_incremental_decoder import FairseqIncrementalDecoder +from .fairseq_model import ( + BaseFairseqModel, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqLanguageModel, + FairseqModel, + FairseqMultiModel, +) + + +MODEL_REGISTRY = {} +MODEL_DATACLASS_REGISTRY = {} +ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} +ARCH_MODEL_INV_REGISTRY = {} +ARCH_CONFIG_REGISTRY = {} + + +__all__ = [ + "BaseFairseqModel", + "CompositeEncoder", + "DistributedFairseqModel", + "FairseqDecoder", + "FairseqEncoder", + "FairseqEncoderDecoderModel", + "FairseqEncoderModel", + "FairseqIncrementalDecoder", + "FairseqLanguageModel", + "FairseqModel", + "FairseqMultiModel", +] + + +def build_model(cfg: FairseqDataclass, task, from_checkpoint=False): + + model = None + model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) + + if not model_type and len(cfg) == 1: + # this is hit if config object is nested in directory that is named after model type + + model_type = next(iter(cfg)) + if model_type in MODEL_DATACLASS_REGISTRY: + cfg = cfg[model_type] + else: + raise Exception( + "Could not infer model type from directory. Please add _name field to indicate model type. " + "Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type + ) + + if model_type in ARCH_MODEL_REGISTRY: + # case 1: legacy models + model = ARCH_MODEL_REGISTRY[model_type] + elif model_type in MODEL_DATACLASS_REGISTRY: + # case 2: config-driven models + model = MODEL_REGISTRY[model_type] + + if model_type in MODEL_DATACLASS_REGISTRY: + # set defaults from dataclass. note that arch name and model name can be the same + dc = MODEL_DATACLASS_REGISTRY[model_type] + + if isinstance(cfg, argparse.Namespace): + cfg = dc.from_namespace(cfg) + else: + cfg = merge_with_parent(dc(), cfg, from_checkpoint) + else: + if model_type in ARCH_CONFIG_REGISTRY: + with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack(): + # this calls the different "arch" functions (like base_architecture()) that you indicate + # if you specify --arch on the command line. this is only applicable to the old argparse based models + # hydra models should expose different architectures via different config files + # it will modify the cfg object and default parameters according to the arch + ARCH_CONFIG_REGISTRY[model_type](cfg) + + assert model is not None, ( + f"Could not infer model type from {cfg}. " + "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys()) + + f" Requested model type: {model_type}" + ) + + return model.build_model(cfg, task) + + +def register_model(name, dataclass=None): + """ + New model types can be added to fairseq with the :func:`register_model` + function decorator. + + For example:: + + @register_model('lstm') + class LSTM(FairseqEncoderDecoderModel): + (...) + + .. note:: All models must implement the :class:`BaseFairseqModel` interface. + Typically you will extend :class:`FairseqEncoderDecoderModel` for + sequence-to-sequence tasks or :class:`FairseqLanguageModel` for + language modeling tasks. + + Args: + name (str): the name of the model + """ + + def register_model_cls(cls): + if name in MODEL_REGISTRY: + return MODEL_REGISTRY[name] + + if not issubclass(cls, BaseFairseqModel): + raise ValueError( + "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) + ) + MODEL_REGISTRY[name] = cls + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="model", node=node, provider="fairseq") + + @register_model_architecture(name, name) + def noop(_): + pass + + return cls + + return register_model_cls + + +def register_model_architecture(model_name, arch_name): + """ + New model architectures can be added to fairseq with the + :func:`register_model_architecture` function decorator. After registration, + model architectures can be selected with the ``--arch`` command-line + argument. + + For example:: + + @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) + (...) + + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. + + Args: + model_name (str): the name of the Model (Model must already be + registered) + arch_name (str): the name of the model architecture (``--arch``) + """ + + def register_model_arch_fn(fn): + if model_name not in MODEL_REGISTRY: + raise ValueError( + "Cannot register model architecture for unknown model type ({})".format( + model_name + ) + ) + if arch_name in ARCH_MODEL_REGISTRY: + raise ValueError( + "Cannot register duplicate model architecture ({})".format(arch_name) + ) + if not callable(fn): + raise ValueError( + "Model architecture must be callable ({})".format(arch_name) + ) + ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] + ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name + ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) + ARCH_CONFIG_REGISTRY[arch_name] = fn + return fn + + return register_model_arch_fn + + +def import_models(models_dir, namespace): + for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument( + "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] + ) + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +import_models(models_dir, "fairseq.models") diff --git a/fairseq/models/__pycache__/__init__.cpython-310.pyc b/fairseq/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd5534283277844dbcbbd8a4f7bb404503f8b917 Binary files /dev/null and b/fairseq/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc b/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac591273428363130a3643e61dbce2d0424c9ec Binary files /dev/null and b/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc b/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f91f5390c7a393862a7f2092859576604fdc239b Binary files /dev/null and b/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc b/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b26cc55d51ac27eb95e4d79057096bad38da1e5c Binary files /dev/null and b/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc b/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80cacbe625d9130caaebf103b40d6ea1e06118a9 Binary files /dev/null and b/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc b/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22809b3a3065e810ca130b9028e32c3d71abfffd Binary files /dev/null and b/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc b/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc89cbbc86c50ae19d45dfe0a1b55223cd22f12 Binary files /dev/null and b/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fconv.cpython-310.pyc b/fairseq/models/__pycache__/fconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9a57bae991c57b110680b62f8869c2828926491 Binary files /dev/null and b/fairseq/models/__pycache__/fconv.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc b/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1cc02d540f33a67d1b5417d3eb9dc6fcbbc8a0 Binary files /dev/null and b/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc b/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32ffed2ae197d1520e9d79d12a4115719120a47d Binary files /dev/null and b/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/lightconv.cpython-310.pyc b/fairseq/models/__pycache__/lightconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..144da838aa6e37cfaa08f80b71adfadea2b3c80f Binary files /dev/null and b/fairseq/models/__pycache__/lightconv.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc b/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81864b630625ced40ac6fe96f680cf2747ccab9d Binary files /dev/null and b/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/lstm.cpython-310.pyc b/fairseq/models/__pycache__/lstm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be5e26c36a44cc5b7d6c1727eee476b9748319ef Binary files /dev/null and b/fairseq/models/__pycache__/lstm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc b/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cc190f4c081c9352f21d6d91c5a655abe80d823 Binary files /dev/null and b/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/masked_lm.cpython-310.pyc b/fairseq/models/__pycache__/masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74ecf6d7cdf438a1b83116a72b0739813473c6a8 Binary files /dev/null and b/fairseq/models/__pycache__/masked_lm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/model_utils.cpython-310.pyc b/fairseq/models/__pycache__/model_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2320fa35b5666ed148402a3ce7ec7a62c308a158 Binary files /dev/null and b/fairseq/models/__pycache__/model_utils.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc b/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f922a86f3acab32f5d4dfae7fe99d227bb220926 Binary files /dev/null and b/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/transformer_align.cpython-310.pyc b/fairseq/models/__pycache__/transformer_align.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c71a38409f16609c837a37e4627b13a505aa697d Binary files /dev/null and b/fairseq/models/__pycache__/transformer_align.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc b/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cadd4114cf8d2bd1938b2950b2569f266dfdf22 Binary files /dev/null and b/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc b/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00c50461933d01f1394954d95f65d7dc48e4d90c Binary files /dev/null and b/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc differ diff --git a/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc b/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0616971029b20c2ecf2ac3a3171a47d4e5b6e55 Binary files /dev/null and b/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc differ diff --git a/fairseq/models/bart/__init__.py b/fairseq/models/bart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a701923f7e5a2a8aa9b75e5580ddea22907f53ee --- /dev/null +++ b/fairseq/models/bart/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hub_interface import * # noqa +from .model import * # noqa diff --git a/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc b/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..106409937b32cfe0a4f4814b1a9559a201bfb5d3 Binary files /dev/null and b/fairseq/models/bart/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..611f047d51522038c004665592b4fceb91aa4bea Binary files /dev/null and b/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/bart/__pycache__/model.cpython-310.pyc b/fairseq/models/bart/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..405bcbe2b6487c8738cdb6c36bf3157db45419f9 Binary files /dev/null and b/fairseq/models/bart/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6b647c9642147bd1bedf56af1b7180a1d39fec98 --- /dev/null +++ b/fairseq/models/bart/hub_interface.py @@ -0,0 +1,211 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +from typing import Dict, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.data import encoders +from fairseq.hub_utils import GeneratorHubInterface +from omegaconf import open_dict + + +logger = logging.getLogger(__name__) + + +class BARTHubInterface(GeneratorHubInterface): + """A simple PyTorch Hub interface to BART. + + Usage: https://github.com/pytorch/fairseq/tree/main/examples/bart + """ + + def __init__(self, cfg, task, model): + super().__init__(cfg, task, [model]) + self.model = self.models[0] + + def encode( + self, sentence: str, *addl_sentences, no_separator=True + ) -> torch.LongTensor: + """ + BPE-encode a sentence (or multiple sentences). + + Every sequence begins with a beginning-of-sentence (``) symbol. + Every sentence ends with an end-of-sentence (``). + + Example (single sentence): ` a b c ` + Example (sentence pair): ` d e f 1 2 3 ` + + The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE + requires leading spaces. For example:: + + >>> bart.encode('Hello world').tolist() + [0, 31414, 232, 2] + >>> bart.encode(' world').tolist() + [0, 232, 2] + >>> bart.encode('world').tolist() + [0, 8331, 2] + """ + tokens = self.bpe.encode(sentence) + if len(tokens.split(" ")) > min(self.max_positions) - 2: + tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - 2]) + bpe_sentence = " " + tokens + " " + for s in addl_sentences: + bpe_sentence += " " if not no_separator else "" + bpe_sentence += " " + self.bpe.encode(s) + " " + tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False) + return tokens.long() + + def decode(self, tokens: torch.LongTensor): + assert tokens.dim() == 1 + tokens = tokens.cpu().numpy() + if tokens[0] == self.task.source_dictionary.bos(): + tokens = tokens[1:] # remove + eos_mask = tokens == self.task.source_dictionary.eos() + doc_mask = eos_mask[1:] & eos_mask[:-1] + sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) + sentences = [ + self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences + ] + if len(sentences) == 1: + return sentences[0] + return sentences + + def _build_sample(self, src_tokens: List[torch.LongTensor]): + # assert torch.is_tensor(src_tokens) + dataset = self.task.build_dataset_for_inference( + src_tokens, + [x.numel() for x in src_tokens], + ) + sample = dataset.collater(dataset) + sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample) + return sample + + def generate( + self, + tokenized_sentences: List[torch.LongTensor], + *args, + inference_step_args=None, + skip_invalid_size_inputs=False, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + inference_step_args = inference_step_args or {} + if "prefix_tokens" in inference_step_args: + raise NotImplementedError("prefix generation not implemented for BART") + res = [] + for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): + src_tokens = batch["net_input"]["src_tokens"] + inference_step_args["prefix_tokens"] = src_tokens.new_full( + (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() + ).to(device=self.device) + results = super().generate( + src_tokens, + *args, + inference_step_args=inference_step_args, + skip_invalid_size_inputs=skip_invalid_size_inputs, + **kwargs + ) + for id, hypos in zip(batch["id"].tolist(), results): + res.append((id, hypos)) + res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])] + return res + + def extract_features( + self, tokens: torch.LongTensor, return_all_hiddens: bool = False + ) -> torch.Tensor: + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + if tokens.size(-1) > min(self.model.max_positions()): + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) + tokens.to(device=self.device), + prev_output_tokens = tokens.clone() + + prev_output_tokens[:, 0] = tokens.gather( + 1, + (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1), + ).squeeze() + + prev_output_tokens[:, 1:] = tokens[:, :-1] + features, extra = self.model( + src_tokens=tokens, + src_lengths=None, + prev_output_tokens=prev_output_tokens, + features_only=True, + return_all_hiddens=return_all_hiddens, + ) + if return_all_hiddens: + # convert from T x B x C -> B x T x C + inner_states = extra["inner_states"] + return [inner_state.transpose(0, 1) for inner_state in inner_states] + else: + return features # just the last layer's features + + def register_classification_head( + self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs + ): + self.model.register_classification_head( + name, num_classes=num_classes, embedding_size=embedding_size, **kwargs + ) + + def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False): + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + features = self.extract_features(tokens.to(device=self.device)) + sentence_representation = features[ + tokens.eq(self.task.source_dictionary.eos()), : + ].view(features.size(0), -1, features.size(-1))[:, -1, :] + + logits = self.model.classification_heads[head](sentence_representation) + if return_logits: + return logits + return F.log_softmax(logits, dim=-1) + + def fill_mask( + self, + masked_inputs: List[str], + topk: int = 5, + match_source_len: bool = True, + **generate_kwargs + ): + masked_token = "" + batch_tokens = [] + for masked_input in masked_inputs: + assert ( + masked_token in masked_input + ), "please add one {} token for the input".format(masked_token) + + text_spans = masked_input.split(masked_token) + text_spans_bpe = ( + (" {0} ".format(masked_token)) + .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans]) + .strip() + ) + tokens = self.task.source_dictionary.encode_line( + " " + text_spans_bpe + " ", + append_eos=False, + add_if_not_exist=False, + ).long() + batch_tokens.append(tokens) + + # ensure beam size is at least as big as topk + generate_kwargs["beam"] = max( + topk, + generate_kwargs.get("beam", -1), + ) + generate_kwargs["match_source_len"] = match_source_len + batch_hypos = self.generate(batch_tokens, **generate_kwargs) + + return [ + [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]] + for hypos in batch_hypos + ] diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py new file mode 100644 index 0000000000000000000000000000000000000000..e3670c0a2c716eea87f8692ac396a0da7efd365f --- /dev/null +++ b/fairseq/models/bart/model.py @@ -0,0 +1,394 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +BART: Denoising Sequence-to-Sequence Pre-training for +Natural Language Generation, Translation, and Comprehension +""" +import logging +from typing import Optional + +import torch +import torch.nn as nn + +from fairseq import utils +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import TransformerModel +from fairseq.modules.transformer_sentence_encoder import init_bert_params + +from .hub_interface import BARTHubInterface + +logger = logging.getLogger(__name__) + + +@register_model("bart") +class BARTModel(TransformerModel): + __jit_unused_properties__ = ["supported_targets"] + + @classmethod + def hub_models(cls): + return { + "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz", + "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz", + "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz", + "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz", + "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz", + } + + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + + # We follow BERT's random weight initialization + self.apply(init_bert_params) + + self.classification_heads = nn.ModuleDict() + if hasattr(self.encoder, "dictionary"): + self.eos: int = self.encoder.dictionary.eos() + + @staticmethod + def add_args(parser): + super(BARTModel, BARTModel).add_args(parser) + parser.add_argument( + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use for pooler layer", + ) + parser.add_argument( + "--spectral-norm-classification-head", + action="store_true", + help="Apply spectral normalization on the classification head", + ) + + @property + def supported_targets(self): + return {"self"} + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + features_only: bool = False, + classification_head_name: Optional[str] = None, + token_embeddings: Optional[torch.Tensor] = None, + return_all_hiddens: bool = True, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + if classification_head_name is not None: + features_only = True + + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + token_embeddings=token_embeddings, + return_all_hiddens=return_all_hiddens, + ) + x, extra = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + eos: int = self.eos + if classification_head_name is not None: + sentence_representation = x[src_tokens.eq(eos), :].view( + x.size(0), -1, x.size(-1) + )[:, -1, :] + for k, head in self.classification_heads.items(): + # for torch script only supports iteration + if k == classification_head_name: + x = head(sentence_representation) + break + return x, extra + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="gpt2", + sample_break_mode="eos", + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + sample_break_mode=sample_break_mode, + **kwargs, + ) + return BARTHubInterface(x["args"], x["task"], x["models"][0]) + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + logger.info("Registering classification head: {0}".format(name)) + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = BARTClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + do_spectral_norm=getattr( + self.args, "spectral_norm_classification_head", False + ), + ) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + + prefix = name + "." if name != "" else "" + current_head_names = ( + [] + if not hasattr(self, "classification_heads") + else self.classification_heads.keys() + ) + + # Handle new classification heads present in the state dict. + keys_to_delete = [] + for k in state_dict.keys(): + if not k.startswith(prefix + "classification_heads."): + continue + + head_name = k[len(prefix + "classification_heads.") :].split(".")[0] + num_classes = state_dict[ + prefix + "classification_heads." + head_name + ".out_proj.weight" + ].size(0) + inner_dim = state_dict[ + prefix + "classification_heads." + head_name + ".dense.weight" + ].size(0) + + if getattr(self.args, "load_checkpoint_heads", False): + if head_name not in current_head_names: + self.register_classification_head(head_name, num_classes, inner_dim) + else: + if head_name not in current_head_names: + logger.warning( + "deleting classification head ({}) from checkpoint " + "not present in current model: {}".format(head_name, k) + ) + keys_to_delete.append(k) + elif ( + num_classes + != self.classification_heads[head_name].out_proj.out_features + or inner_dim + != self.classification_heads[head_name].dense.out_features + ): + logger.warning( + "deleting classification head ({}) from checkpoint " + "with different dimensions than current model: {}".format( + head_name, k + ) + ) + keys_to_delete.append(k) + for k in keys_to_delete: + del state_dict[k] + + def truncate_emb(key): + if key in state_dict: + state_dict[key] = state_dict[key][:-1, :] + + # When finetuning on translation task, remove last row of + # embedding matrix that corresponds to mask_idx token. + loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0) + if ( + loaded_dict_size == len(self.encoder.dictionary) + 1 + and "" not in self.encoder.dictionary + ): + truncate_emb("encoder.embed_tokens.weight") + truncate_emb("decoder.embed_tokens.weight") + truncate_emb("encoder.output_projection.weight") + truncate_emb("decoder.output_projection.weight") + + # When continued pretraining on new set of languages for mbart, + # add extra lang embeddings at the end of embed_tokens. + # Note: newly added languages are assumed to have been added at the end. + if self.args.task == "multilingual_denoising" and loaded_dict_size < len( + self.encoder.dictionary + ): + logger.info( + "Adding extra language embeddings not found in pretrained model for " + "continued pretraining of MBART on new set of languages." + ) + loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][ + -1, : + ] + + num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size + embed_dim = state_dict["encoder.embed_tokens.weight"].size(1) + + new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim) + nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim**-0.5) + new_lang_embed_to_add = new_lang_embed_to_add.to( + dtype=state_dict["encoder.embed_tokens.weight"].dtype, + ) + + state_dict["encoder.embed_tokens.weight"] = torch.cat( + [ + state_dict["encoder.embed_tokens.weight"][ + : loaded_dict_size - 1, : + ], + new_lang_embed_to_add, + loaded_mask_token_embedding.unsqueeze(0), + ] + ) + state_dict["decoder.embed_tokens.weight"] = torch.cat( + [ + state_dict["decoder.embed_tokens.weight"][ + : loaded_dict_size - 1, : + ], + new_lang_embed_to_add, + loaded_mask_token_embedding.unsqueeze(0), + ] + ) + + # Copy any newly-added classification heads into the state dict + # with their current weights. + if hasattr(self, "classification_heads"): + cur_state = self.classification_heads.state_dict() + for k, v in cur_state.items(): + if prefix + "classification_heads." + k not in state_dict: + logger.info("Overwriting " + prefix + "classification_heads." + k) + state_dict[prefix + "classification_heads." + k] = v + + def set_beam_size(self, beam): + """Set beam size for efficient beamable enc-dec attention.""" + beamable = False + for layer in self.decoder.layers: + if layer.encoder_attn is not None: + if hasattr(layer.encoder_attn, "set_beam_size"): + layer.encoder_attn.set_beam_size(beam) + beamable = True + if beamable: + self.encoder.reorder_encoder_out = self.encoder._reorder_encoder_out + + +class BARTClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + do_spectral_norm=False, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + if do_spectral_norm: + self.out_proj = torch.nn.utils.spectral_norm(self.out_proj) + + def forward(self, features, **kwargs): + x = features + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@register_model_architecture("bart", "bart_large") +def bart_large_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.max_source_positions = getattr(args, "max_source_positions", 1024) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", True + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", True) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + + +@register_model_architecture("bart", "bart_base") +def bart_base_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12) + bart_large_architecture(args) + + +@register_model_architecture("bart", "mbart_large") +def mbart_large_architecture(args): + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + bart_large_architecture(args) + + +@register_model_architecture("bart", "mbart_base") +def mbart_base_architecture(args): + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + bart_base_architecture(args) + + +@register_model_architecture("bart", "mbart_base_wmt20") +def mbart_base_wmt20_architecture(args): + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + mbart_base_architecture(args) diff --git a/fairseq/models/composite_encoder.py b/fairseq/models/composite_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4e20fe3a833a2d87876cbec294ad2bebfba7f591 --- /dev/null +++ b/fairseq/models/composite_encoder.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .fairseq_encoder import FairseqEncoder + + +class CompositeEncoder(FairseqEncoder): + """ + A wrapper around a dictionary of :class:`FairseqEncoder` objects. + + We run forward on each encoder and return a dictionary of outputs. The first + encoder's dictionary is used for initialization. + + Args: + encoders (dict): a dictionary of :class:`FairseqEncoder` objects. + """ + + def __init__(self, encoders): + super().__init__(next(iter(encoders.values())).dictionary) + self.encoders = encoders + for key in self.encoders: + self.add_module(key, self.encoders[key]) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + + Returns: + dict: + the outputs from each Encoder + """ + encoder_out = {} + for key in self.encoders: + encoder_out[key] = self.encoders[key](src_tokens, src_lengths) + return encoder_out + + def reorder_encoder_out(self, encoder_out, new_order): + """Reorder encoder output according to new_order.""" + for key in self.encoders: + encoder_out[key] = self.encoders[key].reorder_encoder_out( + encoder_out[key], new_order + ) + return encoder_out + + def max_positions(self): + return min(self.encoders[key].max_positions() for key in self.encoders) + + def upgrade_state_dict(self, state_dict): + for key in self.encoders: + self.encoders[key].upgrade_state_dict(state_dict) + return state_dict diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fd76bcd4bfdba4dce83fb5a6ef01b15f8de1fe67 --- /dev/null +++ b/fairseq/models/distributed_fairseq_model.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import signal +import threading + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel + +from fairseq.distributed import ( + DistributedTimeoutWrapper, + LegacyDistributedDataParallel, + ModuleProxyWrapper, + TPUDistributedDataParallel, +) + +logger = logging.getLogger(__name__) + + +_SLOWMO_DDP_DISABLED = False +try: + from fairscale.experimental.nn.data_parallel import ( + SlowMoBaseAlgorithm, + SlowMoDistributedDataParallel, + ) +except ImportError: + _SLOWMO_DDP_DISABLED = True + + +def DistributedFairseqModel(args, model, process_group, device): + """ + Wrap a *model* to support distributed data parallel training. + + This is similar to the built-in DistributedDataParallel, but allows + additional configuration of the DistributedDataParallel class to + use, and also provides easier access to the wrapped model by + forwarding requests for missing attributes to the wrapped model. + + Args: + args (argparse.Namespace): fairseq args + model (BaseFairseqModel): model to wrap + process_group: the c10d process group to be used for distributed data + parallel all-reduction. + device: device to move model to + """ + assert isinstance(model, nn.Module) + if args.tpu: + wrapped_model = TPUDistributedDataParallel( + module=model.to(device), + process_group=process_group, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend in {"c10d", "pytorch_ddp"}: + wrapped_model = DistributedDataParallel( + module=model.to(device), + device_ids=[args.device_id], + output_device=args.device_id, + broadcast_buffers=args.broadcast_buffers, + bucket_cap_mb=args.bucket_cap_mb, + process_group=process_group, + find_unused_parameters=args.find_unused_parameters, + gradient_as_bucket_view=args.gradient_as_bucket_view, + ) + if args.ddp_comm_hook == "fp16": + logger.info("enable fp16 communication hook in DDP") + try: + from torch.distributed.algorithms.ddp_comm_hooks import ( + DDPCommHookType, + register_ddp_comm_hook, + ) + except: + logger.error( + "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version" + ) + raise + + register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: + wrapped_model = LegacyDistributedDataParallel( + module=model.to(device), + buffer_size=2**28, + process_group=process_group, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "slowmo": + if _SLOWMO_DDP_DISABLED: + raise ImportError( + "Cannot find SlowMoDistributedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + + # The values of slowmo_momentum below were obtained by tuning on the + # En-De 16 dataset by training the transformer_wmt_en_de_large model + if args.slowmo_momentum is None: + if args.distributed_world_size <= 16: + args.slowmo_momentum = 0.0 + elif args.distributed_world_size <= 32: + args.slowmo_momentum = 0.2 + elif args.distributed_world_size <= 64: + args.slowmo_momentum = 0.5 + else: + args.slowmo_momentum = 0.6 + slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()] + + wrapped_model = SlowMoDistributedDataParallel( + module=model.to(device), + broadcast_buffers=args.broadcast_buffers, + nprocs_per_node=args.nprocs_per_node, + slowmo_momentum=args.slowmo_momentum, + slowmo_base_algorithm=slowmo_base_algorithm, + localsgd_frequency=args.localsgd_frequency, + ) + # forward missing getattr and state_dict/load_state_dict to orig model + wrapped_model = ModuleProxyWrapper(wrapped_model) + elif args.ddp_backend == "fully_sharded": + try: + from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP + except ImportError: + raise ImportError( + "Cannot find FullyShardedDataParallel. " + "Please install fairscale with: pip install fairscale" + ) + assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP" + wrapped_model = model + if args.memory_efficient_fp16: + wrapped_model = wrapped_model.half() + if not args.cpu_offload: + wrapped_model = wrapped_model.to(device=device) + else: + raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + + # kill hung distributed jobs after a timeout + if getattr(args, "heartbeat_timeout", -1) > 0: + wrapped_model = DistributedTimeoutWrapper( + wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1) + ) + + return wrapped_model diff --git a/fairseq/models/ema/__init__.py b/fairseq/models/ema/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..503ceaa609b092e48bd32a0031f4e2ffb875483f --- /dev/null +++ b/fairseq/models/ema/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +from .ema import EMA + + +def build_ema(model, cfg, device): + return EMA(model, cfg, device) + + +# automatically import any Python files in the models/ema/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.models.ema." + file_name) diff --git a/fairseq/models/ema/__pycache__/__init__.cpython-310.pyc b/fairseq/models/ema/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca654771405d1d89705fb804b9f9bfd41a724d3 Binary files /dev/null and b/fairseq/models/ema/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/ema/__pycache__/ema.cpython-310.pyc b/fairseq/models/ema/__pycache__/ema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb4b7364fd70cba9a469e7db41cef660cabbb97 Binary files /dev/null and b/fairseq/models/ema/__pycache__/ema.cpython-310.pyc differ diff --git a/fairseq/models/ema/ema.py b/fairseq/models/ema/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..472d5d5f17d4e2f1c06b4d4ae93526297e9abd33 --- /dev/null +++ b/fairseq/models/ema/ema.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 + +""" +This module has the EMA class used to store a copy of the exponentially decayed +model params. + +Typical usage of EMA class involves initializing an object using an existing +model (random or from a seed model) and setting the config like ema_decay, +ema_start_update which determine how the EMA model is updated. After every +update of the model i.e. at the end of the train_step, the EMA should be updated +by passing the new model to the EMA.step function. The EMA model state dict +can be stored in the extra state under the key of "ema" and dumped +into a checkpoint and loaded. The EMA object can be passed to tasks +by setting task.uses_ema property. +EMA is a smoothed/ensemble model which might have better performance +when used for inference or further fine-tuning. EMA class has a +reverse function to load the EMA params into a model and use it +like a regular model. + +This implementation is used for trainer-level ema tracking. For EMA tracking +inside the model, please use fairseq/modules/ema_module.py instead. +""" + +import copy +import logging + +import torch + +from fairseq import checkpoint_utils + + +class EMA(object): + """Exponential Moving Average of Fairseq Models + EMA keeps a copy of the exponentially decayed model params. + The set of params should include both gradient-descent and + non-gradient descent params, such as batch mean/var and buffers. + This is a modified implementation of + the open source code in https://github.com/zhawe01/fairseq-gec.git, + and internal source code in + fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py. + + Similar to TF EMA. + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage. + EMA provides a averaged and smoothed set of model weights, and has been shown to + improve vision models. EMA class does all necessary functions to update, reload, + or init EMA methods. + + EMA object is initialized from an arbitrary model. By default, it is stored in + the same device (unless device specified at initialization) and with the + same precision as the model (unless ema_fp32 is True). ema_fp32 is recommended. + This stores the EMA parameters in fp32 only for the EMA update step, and + is used at the default precision otherwise. + EMA is usually enabled using EMAConfig with store_ema=True. Some important + parameters to configure EMA are + 1) ema_decay - The decay of EMA + 2) ema_update_freq - EMA is updated every this many model updates. + 3) ema_start_update - Start EMA update after this many model updates [default 0] + + Key methods: + 1) step - One update of EMA using new model + 2) restore - Update EMA from a state dict + 3) reverse - Load EMA into a model + 4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is + called from step. + 5) build_fp32_params - Used to initialize or update the fp32 copy of EMA params. + Note this is enabled only when ema_fp32=True + """ + + def __init__(self, model, config, device=None, skip_keys=None): + """ + @param model model to initialize the EMA with + @param config EMAConfig object with configuration like + ema_decay, ema_update_freq, ema_fp32 + @param device If provided, copy EMA to this device (e.g. gpu). + Otherwise EMA is in the same device as the model. + """ + + self.decay = config.ema_decay + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + self.config = config + self.skip_keys = skip_keys or set() + self.fp32_params = {} + + if self.config.ema_seed_model is not None: + state = checkpoint_utils.load_ema_from_checkpoint( + self.config.ema_seed_model + ) + self.model.load_state_dict(state["model"], strict=True) + + if device is not None: + logging.info(f"Copying EMA model to device {device}") + self.model = self.model.to(device=device) + + if self.config.ema_fp32: + self.build_fp32_params() + + self.update_freq_counter = 0 + + def get_model(self): + return self.model + + def build_fp32_params(self, state_dict=None): + """ + Store a copy of the EMA params in fp32. + If state dict is passed, the EMA params is copied from + the provided state dict. Otherwise, it is copied from the + current EMA model parameters. + """ + if not self.config.ema_fp32: + raise RuntimeError( + "build_fp32_params should not be called if ema_fp32=False. " + "Use ema_fp32=True if this is really intended." + ) + + if state_dict is None: + state_dict = self.model.state_dict() + + def _to_float(t): + return t.float() if torch.is_floating_point(t) else t + + for param_key in state_dict: + if param_key in self.fp32_params: + self.fp32_params[param_key].copy_(state_dict[param_key]) + else: + self.fp32_params[param_key] = _to_float(state_dict[param_key]) + + def restore(self, state_dict, build_fp32_params=False): + """Load data from a model spec into EMA model""" + self.model.load_state_dict(state_dict, strict=False) + if build_fp32_params: + self.build_fp32_params(state_dict) + + def _set_decay(self, decay): + self.decay = decay + + def get_decay(self): + return self.decay + + def _step_internal(self, new_model, updates=None): + """One update of the EMA model based on new model weights""" + decay = self.decay + + ema_state_dict = {} + ema_params = ( + self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + ) + for key, param in new_model.state_dict().items(): + if isinstance(param, dict): + continue + try: + ema_param = ema_params[key] + except KeyError: + ema_param = ( + param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ) + + if param.shape != ema_param.shape: + raise ValueError( + "incompatible tensor shapes between model param and ema param" + + "{} vs. {}".format(param.shape, ema_param.shape) + ) + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + if key in self.skip_keys: + ema_param = param.to(dtype=ema_param.dtype).clone() + else: + ema_param.mul_(decay) + ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay) + ema_state_dict[key] = ema_param + self.restore(ema_state_dict, build_fp32_params=False) + + def step(self, new_model, updates=None): + """ + One update of EMA which is done every self.config.ema_update_freq + updates of the model. + + @param updates The current number of model updates done. + Decay is set of 0 if model updates < ema_start_update, which means + the model will be simply copied over to the EMA. + When model updates >= ema_start_updates, then EMA is updated with + a decay of self.config.ema_decay. + """ + if updates is not None: + self._set_decay( + 0 if updates < self.config.ema_start_update else self.config.ema_decay + ) + if self.config.ema_update_freq > 1: + self.update_freq_counter += 1 + if self.update_freq_counter >= self.config.ema_update_freq: + self._step_internal(new_model, updates) + self.update_freq_counter = 0 + else: + self._step_internal(new_model, updates) + + def reverse(self, model): + """ + Load the model parameters from EMA model. + Useful for inference or fine-tuning from the EMA model. + """ + d = self.model.state_dict() + if "_ema" in d: + del d["_ema"] + + model.load_state_dict(d, strict=False) + return model diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..13b73d639e9179d3de34081d5cc5ec73cb9a1f75 --- /dev/null +++ b/fairseq/models/fairseq_decoder.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch.nn as nn +from fairseq import utils +from torch import Tensor + + +class FairseqDecoder(nn.Module): + """Base class for decoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + self.onnx_trace = False + self.adaptive_softmax = None + + def forward(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Args: + prev_output_tokens (LongTensor): shifted output tokens of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (dict, optional): output from the encoder, used for + encoder-side attention + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + x, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + x = self.output_layer(x) + return x, extra + + def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs): + """ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def output_layer(self, features, **kwargs): + """ + Project features to the default output size, e.g., vocabulary size. + + Args: + features (Tensor): features returned by *extract_features*. + """ + raise NotImplementedError + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None: + if sample is not None: + assert "target" in sample + target = sample["target"] + else: + target = None + out = self.adaptive_softmax.get_log_prob(net_output[0], target=target) + return out.exp_() if not log_probs else out + + logits = net_output[0] + if log_probs: + return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + else: + return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace) + + def max_positions(self): + """Maximum input length supported by the decoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" + return state_dict + + def prepare_for_onnx_export_(self): + self.onnx_trace = True diff --git a/fairseq/models/fairseq_encoder.py b/fairseq/models/fairseq_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..08cbde15a46e9b6d58e11c2f6052e7cf2d0cc8b2 --- /dev/null +++ b/fairseq/models/fairseq_encoder.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, NamedTuple, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +EncoderOut = NamedTuple( + "EncoderOut", + [ + ("encoder_out", Tensor), # T x B x C + ("encoder_padding_mask", Optional[Tensor]), # B x T + ("encoder_embedding", Optional[Tensor]), # B x T x C + ("encoder_states", Optional[List[Tensor]]), # List[T x B x C] + ("src_tokens", Optional[Tensor]), # B x T + ("src_lengths", Optional[Tensor]), # B x 1 + ], +) + + +class FairseqEncoder(nn.Module): + """Base class for encoders.""" + + def __init__(self, dictionary): + super().__init__() + self.dictionary = dictionary + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + """ + raise NotImplementedError + + def forward_torchscript(self, net_input: Dict[str, Tensor]): + """A TorchScript-compatible version of forward. + + Encoders which use additional arguments may want to override + this method for TorchScript compatibility. + """ + if torch.jit.is_scripting(): + return self.forward( + src_tokens=net_input["src_tokens"], + src_lengths=net_input["src_lengths"], + ) + else: + return self.forward_non_torchscript(net_input) + + @torch.jit.unused + def forward_non_torchscript(self, net_input: Dict[str, Tensor]): + encoder_input = { + k: v for k, v in net_input.items() if k != "prev_output_tokens" + } + return self.forward(**encoder_input) + + def reorder_encoder_out(self, encoder_out, new_order): + """ + Reorder encoder output according to `new_order`. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + `encoder_out` rearranged according to `new_order` + """ + raise NotImplementedError + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return 1e6 # an arbitrary large number + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code.""" + return state_dict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..cc72a0f8f3da238a8ce846240e5008d91ce1bc1a --- /dev/null +++ b/fairseq/models/fairseq_incremental_decoder.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Optional + +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.models import FairseqDecoder +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +@with_incremental_state +class FairseqIncrementalDecoder(FairseqDecoder): + """Base class for incremental decoders. + + Incremental decoding is a special mode at inference time where the Model + only receives a single timestep of input corresponding to the previous + output token (for teacher forcing) and must produce the next output + *incrementally*. Thus the model must cache any long-term state that is + needed about the sequence, e.g., hidden states, convolutional states, etc. + + Compared to the standard :class:`FairseqDecoder` interface, the incremental + decoder interface allows :func:`forward` functions to take an extra keyword + argument (*incremental_state*) that can be used to cache state across + time-steps. + + The :class:`FairseqIncrementalDecoder` interface also defines the + :func:`reorder_incremental_state` method, which is used during beam search + to select and reorder the incremental state based on the selection of beams. + + To learn more about how incremental decoding works, refer to `this blog + `_. + """ + + def __init__(self, dictionary): + super().__init__(dictionary) + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): + """ + Args: + prev_output_tokens (LongTensor): shifted output tokens of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (dict, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict, optional): dictionary used for storing + state during :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): + """ + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + raise NotImplementedError + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + pass + + def reorder_incremental_state_scripting( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. + + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. + """ + for module in self.modules(): + if hasattr(module, "reorder_incremental_state"): + result = module.reorder_incremental_state(incremental_state, new_order) + if result is not None: + incremental_state = result + + def set_beam_size(self, beam_size): + """Sets the beam size in the decoder and all children.""" + if getattr(self, "_beam_size", -1) != beam_size: + seen = set() + + def apply_set_beam_size(module): + if ( + module != self + and hasattr(module, "set_beam_size") + and module not in seen + ): + seen.add(module) + module.set_beam_size(beam_size) + + self.apply(apply_set_beam_size) + self._beam_size = beam_size diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py new file mode 100644 index 0000000000000000000000000000000000000000..65ead9dcf2d0169ac90601de0fa17f0a8443ff46 --- /dev/null +++ b/fairseq/models/fairseq_model.py @@ -0,0 +1,579 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base classes for various fairseq models. +""" + +import logging +from argparse import Namespace +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.data import Dictionary +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) +from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig +from torch import Tensor + + +logger = logging.getLogger(__name__) + + +def check_type(module, expected_type): + if hasattr(module, "unwrapped_module"): + assert isinstance( + module.unwrapped_module, expected_type + ), f"{type(module.unwrapped_module)} != {expected_type}" + else: + assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" + + +class BaseFairseqModel(nn.Module): + """Base class for fairseq models.""" + + def __init__(self): + super().__init__() + self._is_generation_fast = False + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass(parser, dc(), delete_default=True) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + raise NotImplementedError("Model must implement the build_model method") + + def get_targets(self, sample, net_output): + """Get targets from either the sample or the net's output.""" + return sample["target"] + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + if hasattr(self, "decoder"): + return self.decoder.get_normalized_probs(net_output, log_probs, sample) + elif torch.is_tensor(net_output): + # syntactic sugar for simple models which don't have a decoder + # (e.g., the classification tutorial) + logits = net_output.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def extract_features(self, *args, **kwargs): + """Similar to *forward* but only return features.""" + return self(*args, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return None + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + def upgrade_state_dict(self, state_dict): + """Upgrade old state dicts to work with newer code.""" + self.upgrade_state_dict_named(state_dict, "") + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade old state dicts to work with newer code. + + Args: + state_dict (dict): state dictionary to upgrade, in place + name (str): the state dict key corresponding to the current module + """ + assert state_dict is not None + + def do_upgrade(m, prefix): + if len(prefix) > 0: + prefix += "." + + for n, c in m.named_children(): + name = prefix + n + if hasattr(c, "upgrade_state_dict_named"): + c.upgrade_state_dict_named(state_dict, name) + elif hasattr(c, "upgrade_state_dict"): + c.upgrade_state_dict(state_dict) + do_upgrade(c, name) + + do_upgrade(self, name) + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + for m in self.modules(): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + def set_epoch(self, epoch): + for m in self.modules(): + if hasattr(m, "set_epoch") and m != self: + m.set_epoch(epoch) + + def prepare_for_inference_(self, cfg: DictConfig): + """Prepare model for inference.""" + kwargs = {} + kwargs["beamable_mm_beam_size"] = ( + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) + ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules + self.make_generation_fast_(**kwargs) + + def make_generation_fast_(self, **kwargs): + """ + Legacy entry point to optimize model for faster generation. + Prefer prepare_for_inference_. + """ + if self._is_generation_fast: + return # only apply once + self._is_generation_fast = True + + # remove weight norm from all modules in the network + def apply_remove_weight_norm(module): + try: + nn.utils.remove_weight_norm(module) + except (AttributeError, ValueError): # this module didn't have weight norm + return + + self.apply(apply_remove_weight_norm) + + def apply_make_generation_fast_(module, prefix): + if len(prefix) > 0: + prefix += "." + + base_func = BaseFairseqModel.make_generation_fast_ + for n, m in module.named_modules(): + if ( + m != self + and hasattr(m, "make_generation_fast_") + # don't call this implementation again, e.g., if + # children modules also inherit from BaseFairseqModel + and m.make_generation_fast_.__func__ is not base_func + ): + name = prefix + n + m.make_generation_fast_(name=name, **kwargs) + + apply_make_generation_fast_(self, "") + + def train(mode=True): + if mode: + raise RuntimeError("cannot train after make_generation_fast") + + # this model should no longer be used for training + self.eval() + self.train = train + + def prepare_for_onnx_export_(self, **kwargs): + """Make model exportable via ONNX trace.""" + seen = set() + + def apply_prepare_for_onnx_export_(module): + if ( + module != self + and hasattr(module, "prepare_for_onnx_export_") + and module not in seen + ): + seen.add(module) + module.prepare_for_onnx_export_(**kwargs) + + self.apply(apply_prepare_for_onnx_export_) + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + **kwargs, + ): + """ + Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model + file. Downloads and caches the pre-trained model file if needed. + + The base implementation returns a + :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to + generate translations or sample from language models. The underlying + :class:`~fairseq.models.FairseqModel` can be accessed via the + *generator.models* attribute. + + Other models may override this to implement custom hub interfaces. + + Args: + model_name_or_path (str): either the name of a pre-trained model to + load or a path/URL to a pre-trained model state dict + checkpoint_file (str, optional): colon-separated list of checkpoint + files in the model archive to ensemble (default: 'model.pt') + data_name_or_path (str, optional): point args.data to the archive + at the given path/URL. Can start with '.' or './' to reuse the + model archive path. + """ + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + **kwargs, + ) + logger.info(x["args"]) + return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"]) + + @classmethod + def hub_models(cls): + return {} + + +class FairseqEncoderDecoderModel(BaseFairseqModel): + """Base class for encoder-decoder models. + + Args: + encoder (FairseqEncoder): the encoder + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, encoder, decoder): + super().__init__() + + self.encoder = encoder + self.decoder = decoder + + check_type(self.encoder, FairseqEncoder) + check_type(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Run the forward pass for an encoder-decoder model. + + First feed a batch of source tokens through the encoder. Then, feed the + encoder output and previous decoder outputs (i.e., teacher forcing) to + the decoder to produce the next outputs:: + + encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder(prev_output_tokens, encoder_out) + + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + features = self.decoder.extract_features( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return features + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return (self.encoder.max_positions(), self.decoder.max_positions()) + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + +class FairseqModel(FairseqEncoderDecoderModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + utils.deprecation_warning( + "FairseqModel is deprecated, please use FairseqEncoderDecoderModel " + "or BaseFairseqModel instead", + stacklevel=4, + ) + + +class FairseqMultiModel(BaseFairseqModel): + """Base class for combining multiple encoder-decoder models.""" + + def __init__(self, encoders, decoders): + super().__init__() + assert encoders.keys() == decoders.keys() + self.keys = list(encoders.keys()) + for key in self.keys: + check_type(encoders[key], FairseqEncoder) + check_type(decoders[key], FairseqDecoder) + + self.models = nn.ModuleDict( + { + key: FairseqEncoderDecoderModel(encoders[key], decoders[key]) + for key in self.keys + } + ) + + @staticmethod + def build_shared_embeddings( + dicts: Dict[str, Dictionary], + langs: List[str], + embed_dim: int, + build_embedding: callable, + pretrained_embed_path: Optional[str] = None, + ): + """ + Helper function to build shared embeddings for a set of languages after + checking that all dicts corresponding to those languages are equivalent. + + Args: + dicts: Dict of lang_id to its corresponding Dictionary + langs: languages that we want to share embeddings for + embed_dim: embedding dimension + build_embedding: callable function to actually build the embedding + pretrained_embed_path: Optional path to load pretrained embeddings + """ + shared_dict = dicts[langs[0]] + if any(dicts[lang] != shared_dict for lang in langs): + raise ValueError( + "--share-*-embeddings requires a joined dictionary: " + "--share-encoder-embeddings requires a joined source " + "dictionary, --share-decoder-embeddings requires a joined " + "target dictionary, and --share-all-embeddings requires a " + "joint source + target dictionary." + ) + return build_embedding(shared_dict, embed_dim, pretrained_embed_path) + + def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs): + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return { + key: ( + self.models[key].encoder.max_positions(), + self.models[key].decoder.max_positions(), + ) + for key in self.keys + } + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return min(model.decoder.max_positions() for model in self.models.values()) + + @property + def encoder(self): + return self.models[self.keys[0]].encoder + + @property + def decoder(self): + return self.models[self.keys[0]].decoder + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`nn.Module`. Compared with that method + this additionally "upgrades" *state_dicts* from old checkpoints. + """ + + if model_cfg is None and args is not None: + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) + model_cfg = convert_namespace_to_omegaconf(args).model + + self.upgrade_state_dict(state_dict) + + from fairseq.checkpoint_utils import prune_state_dict + + new_state_dict = prune_state_dict(state_dict, model_cfg) + return super().load_state_dict(new_state_dict, strict) + + +class FairseqLanguageModel(BaseFairseqModel): + """Base class for decoder-only models. + + Args: + decoder (FairseqDecoder): the decoder + """ + + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + check_type(self.decoder, FairseqDecoder) + + def forward(self, src_tokens, **kwargs): + """ + Run the forward pass for a decoder-only model. + + Feeds a batch of tokens through the decoder to predict the next tokens. + + Args: + src_tokens (LongTensor): tokens on which to condition the decoder, + of shape `(batch, tgt_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + tuple: + - the decoder's output of shape `(batch, seq_len, vocab)` + - a dictionary with any model-specific outputs + """ + return self.decoder(src_tokens, **kwargs) + + def forward_decoder(self, prev_output_tokens, **kwargs): + return self.decoder(prev_output_tokens, **kwargs) + + def extract_features(self, src_tokens, **kwargs): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, seq_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + return self.decoder.extract_features(src_tokens, **kwargs) + + def output_layer(self, features, **kwargs): + """Project features to the default output size (typically vocabulary size).""" + return self.decoder.output_layer(features, **kwargs) + + def max_positions(self): + """Maximum length supported by the model.""" + return self.decoder.max_positions() + + def max_decoder_positions(self): + """Maximum length supported by the decoder.""" + return self.decoder.max_positions() + + @property + def supported_targets(self): + return {"future"} + + +class FairseqEncoderModel(BaseFairseqModel): + """Base class for encoder-only models. + + Args: + encoder (FairseqEncoder): the encoder + """ + + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + check_type(self.encoder, FairseqEncoder) + + def forward(self, src_tokens, src_lengths, **kwargs): + """ + Run the forward pass for a encoder-only model. + + Feeds a batch of tokens through the encoder to generate features. + + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + + Returns: + the encoder's output, typically of shape `(batch, src_len, features)` + """ + return self.encoder(src_tokens, src_lengths, **kwargs) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + encoder_out = net_output["encoder_out"] + if torch.is_tensor(encoder_out): + logits = encoder_out.float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + raise NotImplementedError + + def max_positions(self): + """Maximum length supported by the model.""" + return self.encoder.max_positions() diff --git a/fairseq/models/fconv.py b/fairseq/models/fconv.py new file mode 100644 index 0000000000000000000000000000000000000000..c99a2151014d816ec9aff6f4b27d71224dd7b4cf --- /dev/null +++ b/fairseq/models/fconv.py @@ -0,0 +1,756 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + AdaptiveSoftmax, + BeamableMM, + FairseqDropout, + GradMultiply, + LearnedPositionalEmbedding, + LinearizedConvolution, +) + + +@register_model("fconv") +class FConvModel(FairseqEncoderDecoderModel): + """ + A fully convolutional model, i.e. a convolutional encoder and a + convolutional decoder, as described in `"Convolutional Sequence to Sequence + Learning" (Gehring et al., 2017) `_. + + Args: + encoder (FConvEncoder): the encoder + decoder (FConvDecoder): the decoder + + The Convolutional model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.fconv_parser + :prog: + """ + + @classmethod + def hub_models(cls): + def moses_subword(path): + return { + "path": path, + "tokenizer": "moses", + "bpe": "subword_nmt", + } + + return { + "conv.wmt14.en-fr": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2" + ), + "conv.wmt14.en-de": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2" + ), + "conv.wmt17.en-de": moses_subword( + "https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2" + ), + } + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.encoder.num_attention_layers = sum( + layer is not None for layer in decoder.attention + ) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-layers', type=str, metavar='EXPR', + help='encoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-layers', type=str, metavar='EXPR', + help='decoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='EXPR', + help='decoder attention [True, ...]') + parser.add_argument('--share-input-output-embed', action='store_true', + help='share input and output embeddings (requires' + ' --decoder-out-embed-dim and --decoder-embed-dim' + ' to be equal)') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + encoder_embed_dict = None + if args.encoder_embed_path: + encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path) + utils.print_embed_overlap(encoder_embed_dict, task.source_dictionary) + + decoder_embed_dict = None + if args.decoder_embed_path: + decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path) + utils.print_embed_overlap(decoder_embed_dict, task.target_dictionary) + + encoder = FConvEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + embed_dict=encoder_embed_dict, + convolutions=eval(args.encoder_layers), + dropout=args.dropout, + max_positions=args.max_source_positions, + ) + decoder = FConvDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + embed_dict=decoder_embed_dict, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_out_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.max_target_positions, + share_embed=args.share_input_output_embed, + ) + return FConvModel(encoder, decoder) + + +class FConvEncoder(FairseqEncoder): + """ + Convolutional encoder consisting of `len(convolutions)` layers. + + Args: + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_dim (int, optional): embedding dimension + embed_dict (str, optional): filename from which to load pre-trained + embeddings + max_positions (int, optional): maximum supported input sequence length + convolutions (list, optional): the convolutional layer structure. Each + list item `i` corresponds to convolutional layer `i`. Layers are + given as ``(out_channels, kernel_width, [residual])``. Residual + connections are added between layers when ``residual=1`` (which is + the default behavior). + dropout (float, optional): dropout to be applied before each conv layer + """ + + def __init__( + self, + dictionary, + embed_dim=512, + embed_dict=None, + max_positions=1024, + convolutions=((512, 3),) * 20, + dropout=0.1, + ): + super().__init__(dictionary) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.num_attention_layers = None + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + if embed_dict: + self.embed_tokens = utils.load_embedding( + embed_dict, self.dictionary, self.embed_tokens + ) + + self.embed_positions = PositionalEmbedding( + max_positions, + embed_dim, + self.padding_idx, + ) + + convolutions = extend_conv_spec(convolutions) + in_channels = convolutions[0][0] + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.residuals = [] + + layer_in_channels = [in_channels] + for _, (out_channels, kernel_size, residual) in enumerate(convolutions): + if residual == 0: + residual_dim = out_channels + else: + residual_dim = layer_in_channels[-residual] + self.projections.append( + Linear(residual_dim, out_channels) + if residual_dim != out_channels + else None + ) + if kernel_size % 2 == 1: + padding = kernel_size // 2 + else: + padding = 0 + self.convolutions.append( + ConvTBC( + in_channels, + out_channels * 2, + kernel_size, + dropout=dropout, + padding=padding, + ) + ) + self.residuals.append(residual) + in_channels = out_channels + layer_in_channels.append(out_channels) + self.fc2 = Linear(in_channels, embed_dim) + + def forward(self, src_tokens, src_lengths): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of shape + `(batch)` + + Returns: + dict: + - **encoder_out** (tuple): a tuple with two elements, where the + first element is the last encoder layer's output and the + second element is the same quantity summed with the input + embedding (used for attention). The shape of both tensors is + `(batch, src_len, embed_dim)`. + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + """ + # embed tokens and positions + x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) + x = self.dropout_module(x) + input_embedding = x + + # project to size of convolution + x = self.fc1(x) + + # used to mask padding in input + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() # -> T x B + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + residuals = [x] + # temporal convolutions + for proj, conv, res_layer in zip( + self.projections, self.convolutions, self.residuals + ): + if res_layer > 0: + residual = residuals[-res_layer] + residual = residual if proj is None else proj(residual) + else: + residual = None + + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + x = self.dropout_module(x) + if conv.kernel_size[0] % 2 == 1: + # padding is implicit in the conv + x = conv(x) + else: + padding_l = (conv.kernel_size[0] - 1) // 2 + padding_r = conv.kernel_size[0] // 2 + x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r)) + x = conv(x) + x = F.glu(x, dim=2) + + if residual is not None: + x = (x + residual) * math.sqrt(0.5) + residuals.append(x) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # project back to size of embedding + x = self.fc2(x) + + if encoder_padding_mask is not None: + encoder_padding_mask = encoder_padding_mask.t() # -> B x T + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + # scale gradients (this only affects backward, not forward) + x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers)) + + # add output to input embedding for attention + y = (x + input_embedding) * math.sqrt(0.5) + + return { + "encoder_out": (x, y), + "encoder_padding_mask": encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = ( + encoder_out["encoder_out"][0].index_select(0, new_order), + encoder_out["encoder_out"][1].index_select(0, new_order), + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.embed_positions.max_positions + + +class AttentionLayer(nn.Module): + def __init__(self, conv_channels, embed_dim, bmm=None): + super().__init__() + # projects from output of convolution to embedding dimension + self.in_projection = Linear(conv_channels, embed_dim) + # projects from embedding dimension to convolution size + self.out_projection = Linear(embed_dim, conv_channels) + + self.bmm = bmm if bmm is not None else torch.bmm + + def forward(self, x, target_embedding, encoder_out, encoder_padding_mask): + residual = x + + # attention + x = (self.in_projection(x) + target_embedding) * math.sqrt(0.5) + x = self.bmm(x, encoder_out[0]) + + # don't attend over padding + if encoder_padding_mask is not None: + x = ( + x.float() + .masked_fill(encoder_padding_mask.unsqueeze(1), float("-inf")) + .type_as(x) + ) # FP16 support: cast to float and back + + # softmax over last dim + sz = x.size() + x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1) + x = x.view(sz) + attn_scores = x + + x = self.bmm(x, encoder_out[1]) + + # scale attention output (respecting potentially different lengths) + s = encoder_out[1].size(1) + if encoder_padding_mask is None: + x = x * (s * math.sqrt(1.0 / s)) + else: + s = s - encoder_padding_mask.type_as(x).sum( + dim=1, keepdim=True + ) # exclude padding + s = s.unsqueeze(-1) + x = x * (s * s.rsqrt()) + + # project back + x = (self.out_projection(x) + residual) * math.sqrt(0.5) + return x, attn_scores + + def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs): + """Replace torch.bmm with BeamableMM.""" + if beamable_mm_beam_size is not None: + del self.bmm + self.add_module("bmm", BeamableMM(beamable_mm_beam_size)) + + +class FConvDecoder(FairseqIncrementalDecoder): + """Convolutional decoder""" + + def __init__( + self, + dictionary, + embed_dim=512, + embed_dict=None, + out_embed_dim=256, + max_positions=1024, + convolutions=((512, 3),) * 20, + attention=True, + dropout=0.1, + share_embed=False, + positional_embeddings=True, + adaptive_softmax_cutoff=None, + adaptive_softmax_dropout=0.0, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([2])) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.need_attn = True + + convolutions = extend_conv_spec(convolutions) + in_channels = convolutions[0][0] + if isinstance(attention, bool): + # expand True into [True, True, ...] and do the same with False + attention = [attention] * len(convolutions) + if not isinstance(attention, list) or len(attention) != len(convolutions): + raise ValueError( + "Attention is expected to be a list of booleans of " + "length equal to the number of layers." + ) + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + if embed_dict: + self.embed_tokens = utils.load_embedding( + embed_dict, self.dictionary, self.embed_tokens + ) + + self.embed_positions = ( + PositionalEmbedding( + max_positions, + embed_dim, + padding_idx, + ) + if positional_embeddings + else None + ) + + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.attention = nn.ModuleList() + self.residuals = [] + + layer_in_channels = [in_channels] + for i, (out_channels, kernel_size, residual) in enumerate(convolutions): + if residual == 0: + residual_dim = out_channels + else: + residual_dim = layer_in_channels[-residual] + self.projections.append( + Linear(residual_dim, out_channels) + if residual_dim != out_channels + else None + ) + self.convolutions.append( + LinearizedConv1d( + in_channels, + out_channels * 2, + kernel_size, + padding=(kernel_size - 1), + dropout=dropout, + ) + ) + self.attention.append( + AttentionLayer(out_channels, embed_dim) if attention[i] else None + ) + self.residuals.append(residual) + in_channels = out_channels + layer_in_channels.append(out_channels) + + self.adaptive_softmax = None + self.fc2 = self.fc3 = None + + if adaptive_softmax_cutoff is not None: + assert not share_embed + self.adaptive_softmax = AdaptiveSoftmax( + num_embeddings, + in_channels, + adaptive_softmax_cutoff, + dropout=adaptive_softmax_dropout, + ) + else: + self.fc2 = Linear(in_channels, out_embed_dim) + if share_embed: + assert out_embed_dim == embed_dim, ( + "Shared embed weights implies same dimensions " + " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim) + ) + self.fc3 = nn.Linear(out_embed_dim, num_embeddings) + self.fc3.weight = self.embed_tokens.weight + else: + self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + if encoder_out is not None: + encoder_padding_mask = encoder_out["encoder_padding_mask"] + encoder_out = encoder_out["encoder_out"] + + # split and transpose encoder outputs + encoder_a, encoder_b = self._split_encoder_out( + encoder_out, incremental_state + ) + + if self.embed_positions is not None: + pos_embed = self.embed_positions(prev_output_tokens, incremental_state) + else: + pos_embed = 0 + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + x = self._embed_tokens(prev_output_tokens, incremental_state) + + # embed tokens and combine with positional embeddings + x += pos_embed + x = self.dropout_module(x) + target_embedding = x + + # project to size of convolution + x = self.fc1(x) + + # B x T x C -> T x B x C + x = self._transpose_if_training(x, incremental_state) + + # temporal convolutions + avg_attn_scores = None + num_attn_layers = len(self.attention) + residuals = [x] + for proj, conv, attention, res_layer in zip( + self.projections, self.convolutions, self.attention, self.residuals + ): + if res_layer > 0: + residual = residuals[-res_layer] + residual = residual if proj is None else proj(residual) + else: + residual = None + + x = self.dropout_module(x) + x = conv(x, incremental_state) + x = F.glu(x, dim=2) + + # attention + if attention is not None: + x = self._transpose_if_training(x, incremental_state) + + x, attn_scores = attention( + x, target_embedding, (encoder_a, encoder_b), encoder_padding_mask + ) + + if not self.training and self.need_attn: + attn_scores = attn_scores / num_attn_layers + if avg_attn_scores is None: + avg_attn_scores = attn_scores + else: + avg_attn_scores.add_(attn_scores) + + x = self._transpose_if_training(x, incremental_state) + + # residual + if residual is not None: + x = (x + residual) * math.sqrt(0.5) + residuals.append(x) + + # T x B x C -> B x T x C + x = self._transpose_if_training(x, incremental_state) + + # project back to size of vocabulary if not using adaptive softmax + if self.fc2 is not None and self.fc3 is not None: + x = self.fc2(x) + x = self.dropout_module(x) + x = self.fc3(x) + + return x, avg_attn_scores + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + encoder_out = utils.get_incremental_state( + self, incremental_state, "encoder_out" + ) + if encoder_out is not None: + encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out) + utils.set_incremental_state( + self, incremental_state, "encoder_out", encoder_out + ) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return ( + self.embed_positions.max_positions + if self.embed_positions is not None + else float("inf") + ) + + def upgrade_state_dict(self, state_dict): + if utils.item(state_dict.get("decoder.version", torch.Tensor([1]))[0]) < 2: + # old models use incorrect weight norm dimension + for i, conv in enumerate(self.convolutions): + # reconfigure weight norm + nn.utils.remove_weight_norm(conv) + self.convolutions[i] = nn.utils.weight_norm(conv, dim=0) + state_dict["decoder.version"] = torch.Tensor([1]) + return state_dict + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + def _embed_tokens(self, tokens, incremental_state): + if incremental_state is not None: + # keep only the last token for incremental forward pass + tokens = tokens[:, -1:] + return self.embed_tokens(tokens) + + def _split_encoder_out(self, encoder_out, incremental_state): + """Split and transpose encoder outputs. + + This is cached when doing incremental inference. + """ + cached_result = utils.get_incremental_state( + self, incremental_state, "encoder_out" + ) + if cached_result is not None: + return cached_result + + # transpose only once to speed up attention layers + encoder_a, encoder_b = encoder_out + encoder_a = encoder_a.transpose(1, 2).contiguous() + result = (encoder_a, encoder_b) + + if incremental_state is not None: + utils.set_incremental_state(self, incremental_state, "encoder_out", result) + return result + + def _transpose_if_training(self, x, incremental_state): + if incremental_state is None: + x = x.transpose(0, 1) + return x + + +def extend_conv_spec(convolutions): + """ + Extends convolutional spec that is a list of tuples of 2 or 3 parameters + (kernel size, dim size and optionally how many layers behind to look for residual) + to default the residual propagation param if it is not specified + """ + extended = [] + for spec in convolutions: + if len(spec) == 3: + extended.append(spec) + elif len(spec) == 2: + extended.append(spec + (1,)) + else: + raise Exception( + "invalid number of parameters in convolution spec " + + str(spec) + + ". expected 2 or 3" + ) + return tuple(extended) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, 0, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): + m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) + nn.init.normal_(m.weight, 0, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, dropout=0.0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features) + nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features)) + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m) + + +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer optimized for decoding""" + m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + nn.init.normal_(m.weight, mean=0, std=std) + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m, dim=2) + + +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer""" + from fairseq.modules import ConvTBC + + m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + nn.init.normal_(m.weight, mean=0, std=std) + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m, dim=2) + + +@register_model_architecture("fconv", "fconv") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 20") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 20") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_attention = getattr(args, "decoder_attention", "True") + args.share_input_output_embed = getattr(args, "share_input_output_embed", False) + + +@register_model_architecture("fconv", "fconv_iwslt_de_en") +def fconv_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", "[(256, 3)] * 4") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_layers = getattr(args, "decoder_layers", "[(256, 3)] * 3") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + base_architecture(args) + + +@register_model_architecture("fconv", "fconv_wmt_en_ro") +def fconv_wmt_en_ro(args): + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + base_architecture(args) + + +@register_model_architecture("fconv", "fconv_wmt_en_de") +def fconv_wmt_en_de(args): + convs = "[(512, 3)] * 9" # first 9 layers have 512 units + convs += " + [(1024, 3)] * 4" # next 4 layers have 1024 units + convs += " + [(2048, 1)] * 2" # final 2 layers use 1x1 convolutions + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", convs) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_layers = getattr(args, "decoder_layers", convs) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + base_architecture(args) + + +@register_model_architecture("fconv", "fconv_wmt_en_fr") +def fconv_wmt_en_fr(args): + convs = "[(512, 3)] * 6" # first 6 layers have 512 units + convs += " + [(768, 3)] * 4" # next 4 layers have 768 units + convs += " + [(1024, 3)] * 3" # next 3 layers have 1024 units + convs += " + [(2048, 1)] * 1" # next 1 layer uses 1x1 convolutions + convs += " + [(4096, 1)] * 1" # final 1 layer uses 1x1 convolutions + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_layers = getattr(args, "encoder_layers", convs) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 768) + args.decoder_layers = getattr(args, "decoder_layers", convs) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + base_architecture(args) diff --git a/fairseq/models/fconv_lm.py b/fairseq/models/fconv_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..4b243d6669cb57880353b45a01843ec22010fb5f --- /dev/null +++ b/fairseq/models/fconv_lm.py @@ -0,0 +1,136 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.fconv import FConvDecoder +from fairseq.utils import safe_hasattr + + +@register_model("fconv_lm") +class FConvLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-layers", + type=str, + metavar="EXPR", + help="decoder layers [(dim, kernel_size), ...]", + ) + parser.add_argument( + "--decoder-out-embed-dim", + type=int, + metavar="N", + help="decoder output embedding dimension", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion", + ) + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + parser.add_argument( + "--decoder-attention", + type=str, + metavar="EXPR", + help="decoder attention [True, ...]", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_lm_architecture(args) + + if safe_hasattr(args, "max_target_positions") and not safe_hasattr( + args, "tokens_per_sample" + ): + args.tokens_per_sample = args.max_target_positions + + decoder = FConvDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.tokens_per_sample, + share_embed=False, + positional_embeddings=False, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + adaptive_softmax_dropout=args.adaptive_softmax_dropout, + ) + return FConvLanguageModel(decoder) + + +@register_model_architecture("fconv_lm", "fconv_lm") +def base_lm_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_layers = getattr(args, "decoder_layers", "[(1268, 4)] * 13") + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + + +@register_model_architecture("fconv_lm", "fconv_lm_dauphin_wikitext103") +def fconv_lm_dauphin_wikitext103(args): + layers = "[(850, 6)] * 3" + layers += " + [(850, 1)] * 1" + layers += " + [(850, 5)] * 4" + layers += " + [(850, 1)] * 1" + layers += " + [(850, 4)] * 3" + layers += " + [(1024, 4)] * 1" + layers += " + [(2048, 4)] * 1" + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 280) + args.decoder_layers = getattr(args, "decoder_layers", layers) + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,20000,200000" + ) + base_lm_architecture(args) + + +@register_model_architecture("fconv_lm", "fconv_lm_dauphin_gbw") +def fconv_lm_dauphin_gbw(args): + layers = "[(512, 5)]" + layers += " + [(128, 1, 0), (128, 5, 0), (512, 1, 3)] * 3" + layers += " + [(512, 1, 0), (512, 5, 0), (1024, 1, 3)] * 3" + layers += " + [(1024, 1, 0), (1024, 5, 0), (2048, 1, 3)] * 6" + layers += " + [(1024, 1, 0), (1024, 5, 0), (4096, 1, 3)]" + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_layers = getattr(args, "decoder_layers", layers) + args.decoder_attention = getattr(args, "decoder_attention", "False") + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + base_lm_architecture(args) diff --git a/fairseq/models/fconv_self_att.py b/fairseq/models/fconv_self_att.py new file mode 100644 index 0000000000000000000000000000000000000000..8357ef7847ed25a62345e219c41906156828c233 --- /dev/null +++ b/fairseq/models/fconv_self_att.py @@ -0,0 +1,674 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import checkpoint_utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.models import ( + CompositeEncoder, + FairseqDecoder, + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + DownsampledMultiHeadAttention, + FairseqDropout, + GradMultiply, + LayerNorm, + LearnedPositionalEmbedding, + LinearizedConvolution, +) + + +logger = logging.getLogger(__name__) + + +@register_model("fconv_self_att") +class FConvModelSelfAtt(FairseqEncoderDecoderModel): + @classmethod + def hub_models(cls): + return { + "conv.stories.pretrained": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + "checkpoint_file": "pretrained_checkpoint.pt", + "tokenizer": "nltk", + }, + "conv.stories": { + "path": "https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz", + "checkpoint_file": "fusion_checkpoint.pt", + "tokenizer": "nltk", + "pretrained": "True", + "pretrained_checkpoint": "./pretrained_checkpoint.pt", + }, + # Test set containing dictionaries + "data.stories": "https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2", + } + + def __init__(self, encoder, decoder, pretrained_encoder=None): + super().__init__(encoder, decoder) + self.encoder.num_attention_layers = sum( + layer is not None for layer in decoder.attention + ) + self.pretrained_encoder = pretrained_encoder + if self.pretrained_encoder is None: + encoders = {"encoder": encoder} + else: + encoders = {"encoder": encoder, "pretrained": self.pretrained_encoder} + # for fusion model, CompositeEncoder contains both pretrained and training encoders + # these are forwarded and then combined in the decoder + self.encoder = CompositeEncoder(encoders) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-layers', type=str, metavar='EXPR', + help='encoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-layers', type=str, metavar='EXPR', + help='decoder layers [(dim, kernel_size), ...]') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='EXPR', + help='decoder attention [True, ...]') + parser.add_argument('--self-attention', type=str, metavar='EXPR', + help='decoder self-attention layers, ex: [True] + [False]*5') + parser.add_argument('--multihead-attention-nheads', type=int, + help='Number of heads to use in attention') + parser.add_argument('--multihead-self-attention-nheads', type=int, + help='Number of heads to use in self-attention') + parser.add_argument('--encoder-attention', type=str, metavar='EXPR', + help='encoder attention [True, ...]') + parser.add_argument('--encoder-attention-nheads', type=int, + help='Number of heads to use in encoder attention') + parser.add_argument('--project-input', type=str, metavar='EXPR', + help='Use projections in self-attention [True, ...]') + parser.add_argument('--gated-attention', type=str, metavar='EXPR', + help='Use GLU layers in self-attention projections [True, ...]') + parser.add_argument('--downsample', type=str, metavar='EXPR', + help='Use downsampling in self-attention [True, ...]') + parser.add_argument('--pretrained-checkpoint', metavar='DIR', + help='path to load checkpoint from pretrained model') + parser.add_argument('--pretrained', type=str, metavar='EXPR', + help='use pretrained model when training [True, ...]') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + trained_encoder, trained_decoder = None, None + pretrained = eval(args.pretrained) + if pretrained: + logger.info("loading pretrained model") + if not os.path.exists(args.pretrained_checkpoint): + new_pretrained_checkpoint = os.path.join( + args.data, args.pretrained_checkpoint + ) + if os.path.exists(new_pretrained_checkpoint): + args.pretrained_checkpoint = new_pretrained_checkpoint + trained_model = checkpoint_utils.load_model_ensemble( + filenames=[args.pretrained_checkpoint], + task=task, + )[0][0] + trained_decoder = list(trained_model.children())[1] + trained_encoder = list(trained_model.children())[0] + + # freeze pretrained model + for param in trained_decoder.parameters(): + param.requires_grad = False + for param in trained_encoder.parameters(): + param.requires_grad = False + + encoder = FConvEncoder( + task.source_dictionary, + embed_dim=args.encoder_embed_dim, + convolutions=eval(args.encoder_layers), + dropout=args.dropout, + max_positions=args.max_source_positions, + attention=eval(args.encoder_attention), + attention_nheads=args.encoder_attention_nheads, + ) + + decoder = FConvDecoder( + task.target_dictionary, + embed_dim=args.decoder_embed_dim, + convolutions=eval(args.decoder_layers), + out_embed_dim=args.decoder_out_embed_dim, + attention=eval(args.decoder_attention), + dropout=args.dropout, + max_positions=args.max_target_positions, + selfattention=eval(args.self_attention), + attention_nheads=args.multihead_attention_nheads, + selfattention_nheads=args.multihead_self_attention_nheads, + project_input=eval(args.project_input), + gated_attention=eval(args.gated_attention), + downsample=eval(args.downsample), + pretrained=pretrained, + trained_decoder=trained_decoder, + ) + model = FConvModelSelfAtt(encoder, decoder, trained_encoder) + + return model + + @property + def pretrained(self): + return self.pretrained_encoder is not None + + +class FConvEncoder(FairseqEncoder): + """Convolutional encoder""" + + def __init__( + self, + dictionary, + embed_dim=512, + max_positions=1024, + convolutions=((512, 3),) * 20, + dropout=0.1, + attention=False, + attention_nheads=1, + ): + super().__init__(dictionary) + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.num_attention_layers = None + + num_embeddings = len(dictionary) + self.padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + self.embed_positions = PositionalEmbedding( + max_positions, + embed_dim, + self.padding_idx, + ) + + def expand_bool_array(val): + if isinstance(val, bool): + # expand True into [True, True, ...] and do the same with False + return [val] * len(convolutions) + return val + + attention = expand_bool_array(attention) + + in_channels = convolutions[0][0] + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.attention = nn.ModuleList() + self.attproj = nn.ModuleList() + for i, (out_channels, kernel_size) in enumerate(convolutions): + self.projections.append( + Linear(in_channels, out_channels) + if in_channels != out_channels + else None + ) + self.convolutions.append( + ConvTBC(in_channels, out_channels * 2, kernel_size, dropout=dropout) + ) + + self.attention.append( + SelfAttention(out_channels, embed_dim, attention_nheads) + if attention[i] + else None + ) + in_channels = out_channels + + self.fc2 = Linear(in_channels, embed_dim) + + def forward(self, src_tokens, src_lengths): + # embed tokens and positions + x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) + x = self.dropout_module(x) + input_embedding = x.transpose(0, 1) + + # project to size of convolution + x = self.fc1(x) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() # -> T x B + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # temporal convolutions + for proj, conv, attention in zip( + self.projections, self.convolutions, self.attention + ): + residual = x if proj is None else proj(x) + + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + x = self.dropout_module(x) + padding_l = (conv.kernel_size[0] - 1) // 2 + padding_r = conv.kernel_size[0] // 2 + x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r)) + x = conv(x) + x = F.glu(x, dim=2) + if attention is not None: + x = attention(x) + x = (x + residual) * math.sqrt(0.5) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + # project back to size of embedding + x = self.fc2(x) + + if encoder_padding_mask is not None: + encoder_padding_mask = encoder_padding_mask.t() # -> B x T + x = x.masked_fill(encoder_padding_mask.unsqueeze(-1), 0) + + # scale gradients (this only affects backward, not forward) + x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers)) + + # add output to input embedding for attention + y = (x + input_embedding.transpose(0, 1)) * math.sqrt(0.5) + + return { + "encoder_out": (x, y), + "encoder_padding_mask": encoder_padding_mask, # B x T + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = tuple( + eo.index_select(0, new_order) for eo in encoder_out["encoder_out"] + ) + + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + + if "pretrained" in encoder_out: + encoder_out["pretrained"]["encoder_out"] = tuple( + eo.index_select(0, new_order) + for eo in encoder_out["pretrained"]["encoder_out"] + ) + + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.embed_positions.max_positions + + +@with_incremental_state +class FConvDecoder(FairseqDecoder): + """Convolutional decoder""" + + def __init__( + self, + dictionary, + embed_dim=512, + out_embed_dim=256, + max_positions=1024, + convolutions=((512, 3),) * 8, + attention=True, + dropout=0.1, + selfattention=False, + attention_nheads=1, + selfattention_nheads=1, + project_input=False, + gated_attention=False, + downsample=False, + pretrained=False, + trained_decoder=None, + ): + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([2])) + self.pretrained = pretrained + self.pretrained_decoder = trained_decoder + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.need_attn = True + in_channels = convolutions[0][0] + + def expand_bool_array(val): + if isinstance(val, bool): + # expand True into [True, True, ...] and do the same with False + return [val] * len(convolutions) + return val + + attention = expand_bool_array(attention) + selfattention = expand_bool_array(selfattention) + + if not isinstance(attention, list) or len(attention) != len(convolutions): + raise ValueError( + "Attention is expected to be a list of booleans of " + "length equal to the number of layers." + ) + + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + + self.embed_positions = PositionalEmbedding( + max_positions, + embed_dim, + padding_idx, + ) + + self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) + self.projections = nn.ModuleList() + self.convolutions = nn.ModuleList() + self.attention = nn.ModuleList() + self.selfattention = nn.ModuleList() + self.attproj = nn.ModuleList() + for i, (out_channels, kernel_size) in enumerate(convolutions): + self.projections.append( + Linear(in_channels, out_channels) + if in_channels != out_channels + else None + ) + self.convolutions.append( + LinearizedConv1d( + in_channels, + out_channels * 2, + kernel_size, + padding=(kernel_size - 1), + dropout=dropout, + ) + ) + + self.attention.append( + DownsampledMultiHeadAttention( + out_channels, + embed_dim, + attention_nheads, + project_input=project_input, + gated=False, + downsample=False, + ) + if attention[i] + else None + ) + + self.attproj.append( + Linear(out_channels, embed_dim, dropout=dropout) + if attention[i] + else None + ) + self.selfattention.append( + SelfAttention( + out_channels, + embed_dim, + selfattention_nheads, + project_input=project_input, + gated=gated_attention, + downsample=downsample, + ) + if selfattention[i] + else None + ) + in_channels = out_channels + + self.fc2 = Linear(in_channels, out_embed_dim) + self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) + + # model fusion + if self.pretrained: + # independent gates are learned from the concatenated input + self.gate1 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) + self.gate2 = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim), nn.Sigmoid() + ) + # pretrained and trained models are joined + self.joining = nn.Sequential( + Linear(out_embed_dim * 2, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), + nn.GLU(), + Linear(out_embed_dim, out_embed_dim * 2), + LayerNorm(out_embed_dim * 2), + nn.GLU(), + Linear(out_embed_dim, out_embed_dim), + LayerNorm(out_embed_dim), + ) + # pretrained model contains an output layer that is nhid -> vocab size + # but the models are combined in their hidden state + # the hook stores the output of the pretrained model forward + self.pretrained_outputs = {} + + def save_output(): + def hook(a, b, output): + self.pretrained_outputs["out"] = output + + return hook + + self.pretrained_decoder.fc2.register_forward_hook(save_output()) + + def forward(self, prev_output_tokens, encoder_out): + trained_encoder_out = encoder_out["pretrained"] if self.pretrained else None + encoder_out = encoder_out["encoder"]["encoder_out"] + + encoder_a, encoder_b = self._split_encoder_out(encoder_out) + + # embed positions + positions = self.embed_positions(prev_output_tokens) + + # embed tokens and positions + x = self.embed_tokens(prev_output_tokens) + positions + x = self.dropout_module(x) + target_embedding = x.transpose(0, 1) + + # project to size of convolution + x = self.fc1(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # temporal convolutions + avg_attn_scores = None + for proj, conv, attention, selfattention, attproj in zip( + self.projections, + self.convolutions, + self.attention, + self.selfattention, + self.attproj, + ): + residual = x if proj is None else proj(x) + + x = self.dropout_module(x) + x = conv(x) + x = F.glu(x, dim=2) + + # attention + if attention is not None: + r = x + x, attn_scores = attention( + attproj(x) + target_embedding, encoder_a, encoder_b + ) + x = x + r + if not self.training and self.need_attn: + if avg_attn_scores is None: + avg_attn_scores = attn_scores + else: + avg_attn_scores.add_(attn_scores) + + if selfattention is not None: + x = selfattention(x) + + x = (x + residual) * math.sqrt(0.5) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # project back to size of vocabulary + x = self.fc2(x) + x = self.dropout_module(x) + if not self.pretrained: + x = self.fc3(x) + + # fusion gating + if self.pretrained: + trained_x, _ = self.pretrained_decoder.forward( + prev_output_tokens, trained_encoder_out + ) + y = torch.cat([x, self.pretrained_outputs["out"]], dim=-1) + gate1 = self.gate1(y) + gate2 = self.gate2(y) + gated_x1 = gate1 * x + gated_x2 = gate2 * self.pretrained_outputs["out"] + fusion = torch.cat([gated_x1, gated_x2], dim=-1) + fusion = self.joining(fusion) + fusion_output = self.fc3(fusion) + return fusion_output, avg_attn_scores + else: + return x, avg_attn_scores + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return self.embed_positions.max_positions + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + def _split_encoder_out(self, encoder_out): + """Split and transpose encoder outputs.""" + # transpose only once to speed up attention layers + encoder_a, encoder_b = encoder_out + encoder_a = encoder_a.transpose(0, 1).contiguous() + encoder_b = encoder_b.transpose(0, 1).contiguous() + result = (encoder_a, encoder_b) + return result + + +class SelfAttention(nn.Module): + def __init__( + self, + out_channels, + embed_dim, + num_heads, + project_input=False, + gated=False, + downsample=False, + ): + super().__init__() + self.attention = DownsampledMultiHeadAttention( + out_channels, + embed_dim, + num_heads, + dropout=0, + bias=True, + project_input=project_input, + gated=gated, + downsample=downsample, + ) + self.in_proj_q = Linear(out_channels, embed_dim) + self.in_proj_k = Linear(out_channels, embed_dim) + self.in_proj_v = Linear(out_channels, embed_dim) + self.ln = LayerNorm(out_channels) + + def forward(self, x): + residual = x + query = self.in_proj_q(x) + key = self.in_proj_k(x) + value = self.in_proj_v(x) + x, _ = self.attention( + query, key, value, mask_future_timesteps=True, use_scalar_bias=True + ) + return self.ln(x + residual) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + m.weight.data.normal_(0, 0.1) + return m + + +def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx): + m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) + m.weight.data.normal_(0, 0.1) + return m + + +def Linear(in_features, out_features, dropout=0.0): + """Weight-normalized Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features) + m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) + m.bias.data.zero_() + return m + + +def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer optimized for decoding""" + m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + m.weight.data.normal_(mean=0, std=std) + m.bias.data.zero_() + return m + + +def ConvTBC(in_channels, out_channels, kernel_size, dropout=0.0, **kwargs): + """Weight-normalized Conv1d layer""" + from fairseq.modules import ConvTBC + + m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs) + std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels)) + m.weight.data.normal_(mean=0, std=std) + m.bias.data.zero_() + return m + + +@register_model_architecture("fconv_self_att", "fconv_self_att") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_layers = getattr(args, "encoder_layers", "[(512, 3)] * 3") + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_layers = getattr(args, "decoder_layers", "[(512, 3)] * 8") + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_attention = getattr(args, "decoder_attention", "True") + args.self_attention = getattr(args, "self_attention", "False") + args.encoder_attention = getattr(args, "encoder_attention", "False") + args.multihead_attention_nheads = getattr(args, "multihead_attention_nheads", 1) + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 1 + ) + args.encoder_attention_nheads = getattr(args, "encoder_attention_nheads", 1) + args.project_input = getattr(args, "project_input", "False") + args.gated_attention = getattr(args, "gated_attention", "False") + args.downsample = getattr(args, "downsample", "False") + args.pretrained_checkpoint = getattr(args, "pretrained_checkpoint", "") + args.pretrained = getattr(args, "pretrained", "False") + + +@register_model_architecture("fconv_self_att", "fconv_self_att_wp") +def fconv_self_att_wp(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr( + args, "encoder_layers", "[(128, 3)] * 2 + [(512,3)] * 1" + ) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_layers = getattr( + args, "decoder_layers", "[(512, 4)] * 4 + [(768, 4)] * 2 + [(1024, 4)] * 1" + ) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.self_attention = getattr(args, "self_attention", "True") + args.multihead_self_attention_nheads = getattr( + args, "multihead_self_attention_nheads", 4 + ) + args.project_input = getattr(args, "project_input", "True") + args.gated_attention = getattr(args, "gated_attention", "True") + args.downsample = getattr(args, "downsample", "True") + base_architecture(args) diff --git a/fairseq/models/hubert/__init__.py b/fairseq/models/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a1b0eabbdbcaf12b15bb96b329ab1e276256f79a --- /dev/null +++ b/fairseq/models/hubert/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hubert import * # noqa +from .hubert_asr import * # noqa diff --git a/fairseq/models/hubert/__pycache__/__init__.cpython-310.pyc b/fairseq/models/hubert/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa9fb468aa96cc5c5fec7fbc91fbf77dab8d3e17 Binary files /dev/null and b/fairseq/models/hubert/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/hubert/__pycache__/hubert.cpython-310.pyc b/fairseq/models/hubert/__pycache__/hubert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1fd2206c1ac531c40175f432e2f1bd8fadc7512 Binary files /dev/null and b/fairseq/models/hubert/__pycache__/hubert.cpython-310.pyc differ diff --git a/fairseq/models/hubert/__pycache__/hubert_asr.cpython-310.pyc b/fairseq/models/hubert/__pycache__/hubert_asr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d4388202b8e023238444b123ac1e53a00ed0fd9 Binary files /dev/null and b/fairseq/models/hubert/__pycache__/hubert_asr.cpython-310.pyc differ diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3b777efdb66622514eb97dbde3642ea1574fb5 --- /dev/null +++ b/fairseq/models/hubert/hubert.py @@ -0,0 +1,576 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from omegaconf import II + +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + EXTRACTOR_MODE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + LAYER_TYPE_CHOICES, + ConvFeatureExtractionModel, + TransformerEncoder, +) +from fairseq.modules import GradMultiply, LayerNorm +from fairseq.tasks.hubert_pretraining import ( + HubertPretrainingConfig, + HubertPretrainingTask, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class HubertConfig(FairseqDataclass): + label_rate: float = II("task.label_rate") + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + conv_pos_batch_norm: bool = field( + default=False, + metadata={ + "help": "use batch norm instead of weight norm in conv_pos (for bf16 models)" + }, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + +@register_model("hubert", dataclass=HubertConfig) +class HubertModel(BaseFairseqModel): + def __init__( + self, + cfg: HubertConfig, + task_cfg: HubertPretrainingConfig, + dictionaries: List[Dictionary], + ) -> None: + super().__init__() + logger.info(f"HubertModel Config: {cfg}") + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + logger.info("cannot find dictionary. assume will be used for fine-tuning") + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: HubertConfig, task: HubertPretrainingTask): + """Build a new model instance.""" + + model = HubertModel(cfg, task.cfg, task.dictionaries) + return model + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + features = self.forward_features(source) + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1, + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + def compute_pred(proj_x, target, label_embs): + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + proj_x_m = self.final_proj(x[masked_indices]) + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list)) + ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) + ] + else: + logit_u_list = [None for _ in target_list] + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None diff --git a/fairseq/models/hubert/hubert_asr.py b/fairseq/models/hubert/hubert_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..11c85ce7d11f8529d6ff7636eb467ae56bc11772 --- /dev/null +++ b/fairseq/models/hubert/hubert_asr.py @@ -0,0 +1,675 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import logging +import math +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any, Optional +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import II, MISSING, open_dict + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, +) +from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer +from fairseq.tasks import FairseqTask + +logger = logging.getLogger(__name__) + + +@dataclass +class HubertAsrConfig(FairseqDataclass): + w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"}) + no_pretrained_weights: bool = field( + default=False, + metadata={"help": "if true, does not load pretrained weights"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, + ) + dropout: float = field( + default=0.0, + metadata={"help": "dropout probability inside hubert model"}, + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " "inside hubert model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " "inside hubert model" + }, + ) + encoder_embed_dim: Optional[int] = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask " + "(normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + freeze_finetune_updates: int = field( + default=0, + metadata={"help": "dont finetune hubert for this many updates"}, + ) + feature_grad_mult: float = field( + default=0.0, + metadata={"help": "reset feature grad mult in hubert to this"}, + ) + layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a layer in hubert"}, + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + + # this holds the loaded hubert args + w2v_args: Any = None + + +@dataclass +class HubertCtcConfig(HubertAsrConfig): + pass + + +@register_model("hubert_ctc", dataclass=HubertCtcConfig) +class HubertCtc(BaseFairseqModel): + def __init__(self, cfg: HubertCtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: HubertCtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = HubertEncoder(cfg, task) + return cls(cfg, w2v_encoder) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][..., 0] = 0 + logits[padding][..., 1:] = float("-inf") + + return logits + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + +@dataclass +class HubertSeq2SeqConfig(HubertAsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + autoregressive: bool = II("task.autoregressive") + seq2seq_path: str = field( + default="", + metadata={"help": "reset_dict"}, + ) + reset_dict: bool = field( + default=False, + metadata={"help": "reset_dict"}, + ) + + +@register_model("hubert_seq2seq", dataclass=HubertSeq2SeqConfig) +class HubertSeq2SeqModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask): + """Build a new model instance.""" + + assert ( + cfg.autoregressive + ), "Please set task.autoregressive=true for seq2seq asr models" + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + + encoder = cls.build_encoder(cfg, task) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) + + model = HubertSeq2SeqModel(encoder, decoder) + + if cfg["seq2seq_path"]: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.seq2seq_path) + state = state["model"] + if cfg["reset_dict"]: + del state["decoder.embed_out"] + del state["decoder.embed_tokens.weight"] + model.load_state_dict(state, strict=False) + return model + + @classmethod + def build_encoder(cls, cfg: HubertAsrConfig, task): + return HubertEncoder(cfg, task) + + @classmethod + def build_decoder(cls, cfg: HubertSeq2SeqConfig, tgt_dict, embed_tokens): + return TransformerDecoder(cfg, tgt_dict, embed_tokens) + + def forward(self, **kwargs): + encoder_out = self.encoder(**kwargs) + decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) + return decoder_out + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): + if model_cfg.reset_dict: + logger.warn("Overriding loading strict state dict!") + del state_dict["decoder.embed_out"] + del state_dict["decoder.embed_tokens.weight"] + return super().load_state_dict(state_dict, False, model_cfg, args) + return super().load_state_dict(state_dict, strict, model_cfg, args) + + +class HubertEncoder(FairseqEncoder): + def __init__(self, cfg: HubertAsrConfig, task): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + pretrain_task = tasks.setup_task(w2v_args.task) + if state is not None and "task_state" in state: + # This will load the stored "dictionaries" object + pretrain_task.load_state_dict(state["task_state"]) + else: + pretrain_task.load_state_dict(task.state_dict()) + + model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True) + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + model.load_state_dict(state["model"], strict=False) + + model.remove_pretraining_modules() + + super().__init__(pretrain_task.source_dictionary) + + d = w2v_args.model.encoder_embed_dim + + self.w2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + if task.target_dictionary is not None and not cfg.autoregressive: + self.proj = Linear(d, len(task.target_dictionary)) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.w2v_model.extract_features(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select( + 0, new_order + ) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg: HubertSeq2SeqConfig, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): + super().__init__(dictionary) + + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim + + self.layerdrop = cfg.decoder_layerdrop + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_target_positions, + embed_dim, + self.padding_idx, + learned=cfg.decoder_learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) + + if transformer_cfg.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + if type(prev_output_tokens) == list: + max_len = max((len(x) for x in prev_output_tokens)) + tmp = torch.zeros( + [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device + ) + for (i, p) in enumerate(prev_output_tokens): + tmp[i, : len(p)] = p + prev_output_tokens = tmp + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + self_attn_padding_mask = None + if prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["padding_mask"] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + self_attn_padding_mask=self_attn_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/models/huggingface/__init__.py b/fairseq/models/huggingface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7911c2c8edf516855023a285b18935e5389ec02 --- /dev/null +++ b/fairseq/models/huggingface/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + + +# automatically import any Python files in the models/huggingface/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + model_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("fairseq.models.huggingface." + model_name) diff --git a/fairseq/models/huggingface/__pycache__/__init__.cpython-310.pyc b/fairseq/models/huggingface/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1997f5702523e82a74e994fca81c319e9be76428 Binary files /dev/null and b/fairseq/models/huggingface/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/huggingface/__pycache__/hf_gpt2.cpython-310.pyc b/fairseq/models/huggingface/__pycache__/hf_gpt2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbf7a5ef9aaec5c298678ede6d25f9cea812cce6 Binary files /dev/null and b/fairseq/models/huggingface/__pycache__/hf_gpt2.cpython-310.pyc differ diff --git a/fairseq/models/huggingface/hf_gpt2.py b/fairseq/models/huggingface/hf_gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8eb78198f5808557092f814e92f1c9d72933ec --- /dev/null +++ b/fairseq/models/huggingface/hf_gpt2.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys +from typing import Dict, List, Optional + +import torch +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, + register_model_architecture, +) + + +logger = logging.getLogger(__name__) + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@register_model("hf_gpt2") +class HuggingFaceGPT2LanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--embed-dim', type=int, metavar='N', + help='embedding dimension') + parser.add_argument('--num-attention-heads', type=int, metavar='N', + help='num attention heads') + parser.add_argument('--num-layers', type=int, metavar='N', + help='num layers') + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability for all fully connected layers ' + 'in the embeddings, encoder, and pooler') + parser.add_argument('--attention-dropout', type=float, metavar='D', + help='dropout probability for attention weights') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + default_architecture(args) + return cls(HuggingFaceGPT2Decoder(args, task)) + + +class HuggingFaceGPT2Decoder(FairseqIncrementalDecoder): + def __init__(self, args, task): + try: + from transformers import GPT2Config, GPT2LMHeadModel + except ImportError: + raise ImportError( + "\n\nPlease install huggingface/transformers with:" + "\n\n pip install transformers" + ) + + super().__init__(task.target_dictionary) + + config = GPT2Config( + vocab_size=len(task.target_dictionary), + n_positions=args.max_target_positions + 1, + n_ctx=args.max_target_positions, + n_embd=args.embed_dim, + n_layer=args.num_layers, + n_head=args.num_attention_heads, + resid_pdrop=args.dropout, + embd_pdrop=args.dropout, + attn_pdrop=args.attention_dropout, + layer_norm_epsilon=1e-6, + ) + self.model = GPT2LMHeadModel(config) + + # set zero embedding for padding symbol + self.pad_idx = task.target_dictionary.pad() + self.model.transformer.wte.weight.data[self.pad_idx].zero_() + self.model.transformer.wpe.weight.data[0].zero_() + + def forward( + self, + prev_output_tokens, + src_lengths=None, + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + features = self.extract_features(prev_output_tokens, incremental_state) + lm_logits = self.model.lm_head(features) + return (lm_logits,) + + def extract_features( + self, + prev_output_tokens, + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + ): + if incremental_state: + past = self.get_incremental_state("past") + else: + past = None + + # don't attend to padding symbols + attention_mask = prev_output_tokens.ne(self.pad_idx).int() + + # set position ids to exclude padding symbols + position_ids = attention_mask * ( + torch.arange(1, 1 + prev_output_tokens.size(1)) + .to(prev_output_tokens) + .repeat(prev_output_tokens.size(0), 1) + ) + + outputs = self.model.transformer( + input_ids=prev_output_tokens, + past=past, + attention_mask=attention_mask, + position_ids=position_ids, + ) + last_hidden_states = outputs[0] + + if incremental_state: + self.set_incremental_state(incremental_state, "past", outputs[1]) + + return last_hidden_states + + def max_positions(self): + return self.model.config.n_positions - 1 + + +@register_model_architecture("hf_gpt2", "hf_gpt2") +def default_architecture(args): + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + args.embed_dim = getattr(args, "embed_dim", 768) + args.num_attention_heads = getattr(args, "num_attention_heads", 12) + args.num_layers = getattr(args, "num_layers", 12) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + + +@register_model_architecture("hf_gpt2", "hf_gpt2_medium") +def hf_gpt2_medium(args): + args.embed_dim = getattr(args, "embed_dim", 1024) + args.num_attention_heads = getattr(args, "num_attention_heads", 16) + args.num_layers = getattr(args, "num_layers", 24) + default_architecture(args) + + +@register_model_architecture("hf_gpt2", "hf_gpt2_large") +def hf_gpt2_large(args): + args.embed_dim = getattr(args, "embed_dim", 1280) + args.num_attention_heads = getattr(args, "num_attention_heads", 20) + args.num_layers = getattr(args, "num_layers", 36) + default_architecture(args) + + +@register_model_architecture("hf_gpt2", "hf_gpt2_xl") +def hf_gpt2_xl(args): + args.embed_dim = getattr(args, "embed_dim", 1600) + args.num_attention_heads = getattr(args, "num_attention_heads", 25) + args.num_layers = getattr(args, "num_layers", 48) + default_architecture(args) diff --git a/fairseq/models/lightconv.py b/fairseq/models/lightconv.py new file mode 100644 index 0000000000000000000000000000000000000000..7950280e303129f9e529248cd2fe8fbc7a28c10a --- /dev/null +++ b/fairseq/models/lightconv.py @@ -0,0 +1,1119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + AdaptiveSoftmax, + DynamicConv_scripatable as DynamicConv, + FairseqDropout, + LayerNorm, + LightweightConv, + MultiheadAttention, + PositionalEmbedding, +) +from fairseq.utils import safe_hasattr +from torch import Tensor + + +@register_model("lightconv") +class LightConvModel(FairseqEncoderDecoderModel): + """ + LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019) + `_. + To use LightConv please set ``--encoder-conv-type lightweight --decoder-conv-type lightweight`` + To use DynamicConv please set ``--encoder-conv-type dynamic --decoder-conv-type dynamic`` + + Args: + encoder (LightConvEncoder): the encoder + decoder (LightConvDecoder): the decoder + + The LightConv model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.lightconv_parser + :prog: + """ + + @classmethod + def hub_models(cls): + # fmt: off + + def moses_subword(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'subword_nmt', + } + + return { + 'lightconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz'), + 'dynamicconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz'), + 'lightconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz'), + 'dynamicconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz'), + 'lightconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'), + 'dynamicconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'), + 'lightconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'), + 'dynamicconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'), + 'lightconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz'), + 'dynamicconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz'), + 'lightconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz'), + 'dynamicconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz'), + } + # fmt: on + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after ReLU in FFN", + ) + parser.add_argument( + "--input-dropout", + type=float, + metavar="D", + help="dropout probability of the inputs", + ) + parser.add_argument( + "--encoder-embed-path", + type=str, + metavar="STR", + help="path to pre-trained encoder embedding", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-conv-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--encoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the encoder", + ) + parser.add_argument( + "--decoder-embed-path", + type=str, + metavar="STR", + help="path to pre-trained decoder embedding", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-conv-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--share-all-embeddings", + action="store_true", + help="share encoder, decoder and output embeddings" + " (requires shared dictionary and embed dim)", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion", + ), + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + + """LightConv and DynamicConv arguments""" + parser.add_argument( + "--encoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31,31]")', + ) + parser.add_argument( + "--decoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")', + ) + parser.add_argument( + "--encoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--decoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--encoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument( + "--decoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool) + parser.add_argument( + "--weight-dropout", + type=float, + metavar="D", + help="dropout probability for conv weights", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if not safe_hasattr(args, "max_source_positions"): + args.max_source_positions = 1024 + if not safe_hasattr(args, "max_target_positions"): + args.max_target_positions = 1024 + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise RuntimeError( + "--share-all-embeddings requires a joined dictionary" + ) + if args.encoder_embed_dim != args.decoder_embed_dim: + raise RuntimeError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise RuntimeError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = build_embedding( + src_dict, args.encoder_embed_dim, args.encoder_embed_path + ) + decoder_embed_tokens = build_embedding( + tgt_dict, args.decoder_embed_dim, args.decoder_embed_path + ) + + encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens) + decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens) + return LightConvModel(encoder, decoder) + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + prev_output_tokens: Tensor, + ): + """ + (The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs.) + + Run the forward pass for an encoder-decoder model. + + First feed a batch of source tokens through the encoder. Then, feed the + encoder output and previous decoder outputs (i.e., teacher forcing) to + the decoder to produce the next outputs:: + + encoder_out = self.encoder(src_tokens, src_lengths) + return self.decoder(prev_output_tokens, encoder_out) + + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (LongTensor): source sentence lengths of shape `(batch)` + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + encoder_out = self.encoder(src_tokens, src_lengths) + decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out) + return decoder_out + + +class LightConvEncoder(FairseqEncoder): + """ + LightConv encoder consisting of *args.encoder_layers* layers. Each layer + is a :class:`LightConvEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, args, dictionary, embed_tokens): + super().__init__(dictionary) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = args.max_source_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) + self.embed_positions = ( + PositionalEmbedding( + args.max_source_positions, + embed_dim, + self.padding_idx, + learned=args.encoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + LightConvEncoderLayer( + args, kernel_size=args.encoder_kernel_size_list[i] + ) + for i in range(args.encoder_layers) + ] + ) + self.register_buffer("version", torch.Tensor([2])) + self.normalize = args.encoder_normalize_before + if self.normalize: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, src_tokens: Tensor, src_lengths: Optional[Tensor] = None + ) -> Dict[str, List[Tensor]]: + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + """ + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(src_tokens) + if self.embed_positions is not None: + x += self.embed_positions(src_tokens) + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) # B x T + if not encoder_padding_mask.any(): + encoder_mask = None + else: + encoder_mask = encoder_padding_mask + + # encoder layers + for layer in self.layers: + x = layer(x, encoder_mask) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + output_dict: Dict[str, List[Tensor]] = {} + if src_lengths is not None: + output_dict["src_lengths"] = [src_lengths] + output_dict["encoder_out"] = [x] # T x B x C + if encoder_mask is not None: + output_dict["encoder_padding_mask"] = [encoder_mask] # B x T + + return output_dict + + @torch.jit.export + def reorder_encoder_out( + self, encoder_out: Dict[str, List[Tensor]], new_order: Tensor + ): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if len(encoder_out["encoder_out"]) == 0: + encoder = [] + else: + encoder = [encoder_out["encoder_out"][0].index_select(1, new_order)] + output_dict = {"encoder_out": encoder} + + if ("encoder_padding_mask" not in encoder_out) or ( + len(encoder_out["encoder_padding_mask"]) == 0 + ): + encoder_padding_mask = [] + else: + encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + output_dict["encoder_padding_mask"] = encoder_padding_mask + return output_dict + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions) + + +class LightConvDecoder(FairseqIncrementalDecoder): + """ + LightConv decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`LightConvDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + """ + + def __init__( + self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True + ): + super().__init__(dictionary) + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.share_input_output_embed = args.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + output_embed_dim = args.decoder_output_dim + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + args.max_target_positions, + embed_dim, + padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + LightConvDecoderLayer( + args, + no_encoder_attn, + kernel_size=args.decoder_kernel_size_list[i], + dictionary=dictionary, + ) + for i in range(args.decoder_layers) + ] + ) + + self.adaptive_softmax = None + self.output_projection = None + + self.project_out_dim = ( + Linear(embed_dim, output_embed_dim, bias=False) + if embed_dim != output_embed_dim and not args.tie_adaptive_weights + else None + ) + + if args.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + output_embed_dim, + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int), + dropout=args.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, + factor=args.adaptive_softmax_factor, + tie_proj=args.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + + else: + self.output_projection = nn.Linear( + output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=output_embed_dim**-0.5 + ) + self.register_buffer("version", torch.Tensor([2])) + self.normalize = args.decoder_normalize_before and final_norm + if self.normalize: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, + prev_output_tokens: Tensor, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + src_lengths: Optional[Any] = None, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the last decoder layer's output of shape `(batch, tgt_len, + vocab)` + - the last decoder layer's attention weights of shape `(batch, + tgt_len, src_len)` + """ + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, + incremental_state=incremental_state, + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens.contiguous()) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states: List[Optional[Tensor]] = [x] + + # decoder layers + attn: Optional[Tensor] = None + for layer in self.layers: + encoder: Optional[Tensor] = None + encoder_padding_mask: Optional[Tensor] = None + if encoder_out is not None: + if len(encoder_out["encoder_out"]) > 0: + encoder = encoder_out["encoder_out"][0] + if ( + "encoder_padding_mask" in encoder_out + and len(encoder_out["encoder_padding_mask"]) > 0 + ): + encoder_padding_mask = encoder_out["encoder_padding_mask"][0] + x, attn = layer( + x, + encoder, + encoder_padding_mask, + incremental_state, + ) + inner_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + if self.adaptive_softmax is None: + # project back to size of vocabulary + x = self.output_projection(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + if self._future_mask.size(0) < dim: + self._future_mask = torch.triu( + utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + +class LightConvEncoderLayer(nn.Module): + """Encoder layer block. + + Args: + args (argparse.Namespace): parsed command-line arguments + kernel_size: kernel size of the convolution + """ + + def __init__(self, args, kernel_size=0): + super().__init__() + self.embed_dim = args.encoder_embed_dim + self.conv_dim = args.encoder_conv_dim + padding_l = ( + kernel_size // 2 + if kernel_size % 2 == 1 + else ((kernel_size - 1) // 2, kernel_size // 2) + ) + + if args.encoder_glu: + self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim) + self.act = nn.GLU() + else: + self.linear1 = Linear(self.embed_dim, self.conv_dim) + self.act = None + if args.encoder_conv_type == "lightweight": + self.conv = LightweightConv( + self.conv_dim, + kernel_size, + padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + elif args.encoder_conv_type == "dynamic": + self.conv = DynamicConv( + self.conv_dim, + kernel_size, + padding_l=padding_l, + weight_softmax=args.weight_softmax, + num_heads=args.encoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + else: + raise NotImplementedError + self.linear2 = Linear(self.conv_dim, self.embed_dim) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.relu_dropout_module = FairseqDropout( + args.relu_dropout, module_name=self.__class__.__name__ + ) + self.input_dropout_module = FairseqDropout( + args.input_dropout, module_name=self.__class__.__name__ + ) + self.normalize_before = args.encoder_normalize_before + self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim) + self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim) + self.layer_norm1 = LayerNorm(self.embed_dim) + self.layer_norm2 = LayerNorm(self.embed_dim) + + def forward(self, x, encoder_padding_mask: Optional[Tensor] = None) -> Tensor: + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. + + Returns: + encoded output of shape `(batch, src_len, embed_dim)` + """ + residual = x + normalize = self.maybe_layer_norm(before=True) + if normalize: + x = self.layer_norm1(x) + x = self.input_dropout_module(x) + x = self.linear1(x) + if self.act is not None: + x = self.act(x) + if encoder_padding_mask is not None: + x = x.masked_fill(encoder_padding_mask.transpose(0, 1).unsqueeze(2), 0) + x = self.conv(x) + x = self.linear2(x) + x = self.dropout_module(x) + x = residual + x + normalize = self.maybe_layer_norm(after=True) + if normalize: + x = self.layer_norm1(x) + + residual = x + normalize = self.maybe_layer_norm(before=True) + if normalize: + x = self.layer_norm2(x) + x = F.relu(self.fc1(x)) + x = self.relu_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + normalize = self.maybe_layer_norm(after=True) + if normalize: + x = self.layer_norm2(x) + return x + + def maybe_layer_norm(self, before: bool = False, after: bool = False): + assert before ^ after, "Incorrect arguments" + return after ^ self.normalize_before + + def extra_repr(self): + return ( + "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format( + self.dropout_module.p, + self.relu_dropout_module.p, + self.input_dropout_module.p, + self.normalize_before, + ) + ) + + +class LightConvDecoderLayer(nn.Module): + """Decoder layer block. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs. + Default: ``False`` + kernel_size: kernel size of the convolution + """ + + def __init__(self, args, no_encoder_attn=False, kernel_size=0, dictionary=None): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.conv_dim = args.decoder_conv_dim + if args.decoder_glu: + self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim) + self.act = nn.GLU() + else: + self.linear1 = Linear(self.embed_dim, self.conv_dim) + self.act = None + if args.decoder_conv_type == "lightweight": + self.conv = LightweightConv( + self.conv_dim, + kernel_size, + padding_l=kernel_size - 1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + elif args.decoder_conv_type == "dynamic": + self.conv = DynamicConv( + self.conv_dim, + kernel_size, + padding_l=kernel_size - 1, + weight_softmax=args.weight_softmax, + num_heads=args.decoder_attention_heads, + weight_dropout=args.weight_dropout, + ) + else: + raise NotImplementedError + self.linear2 = Linear(self.conv_dim, self.embed_dim) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.relu_dropout_module = FairseqDropout( + args.relu_dropout, module_name=self.__class__.__name__ + ) + self.input_dropout_module = FairseqDropout( + args.input_dropout, module_name=self.__class__.__name__ + ) + self.normalize_before = args.decoder_normalize_before + + self.conv_layer_norm = LayerNorm(self.embed_dim) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = MultiheadAttention( + self.embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + encoder_decoder_attention=True, + dictionary=dictionary, + ) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + + self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim) + self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim) + + self.final_layer_norm = LayerNorm(self.embed_dim) + self.need_attn = True + + def forward( + self, + x: Tensor, + encoder_out: Optional[Tensor], + encoder_padding_mask: Optional[Tensor], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + prev_conv_state: Optional[Tensor] = None, + prev_attn_state: Optional[Tuple[Tensor, Tensor]] = None, + conv_mask: Optional[Tensor] = None, + conv_padding_mask: Optional[Tensor] = None, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, src_len)` where padding elements are indicated by ``1``. + + Returns: + encoded output of shape `(batch, src_len, embed_dim)` + """ + residual = x + normalize = self.maybe_layer_norm(before=True) + if normalize: + x = self.conv_layer_norm(x) + if prev_conv_state is not None: + self.conv._set_input_buffer(incremental_state, prev_conv_state) + x = self.input_dropout_module(x) + x = self.linear1(x) + if self.act is not None: + x = self.act(x) + x = self.conv(x, incremental_state=incremental_state) + x = self.linear2(x) + x = self.dropout_module(x) + x = residual + x + normalize = self.maybe_layer_norm(after=True) + if normalize: + x = self.conv_layer_norm(x) + + attn: Optional[Tensor] = None + if self.encoder_attn is not None: + residual = x + normalize = self.maybe_layer_norm(before=True) + if normalize: + x = self.encoder_attn_layer_norm(x) + + if prev_attn_state is not None: + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_attn_state[0], + "prev_value": prev_attn_state[1], + } + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=(not self.training and self.need_attn), + ) + x = self.dropout_module(x) + x = residual + x + normalize = self.maybe_layer_norm(after=True) + if normalize: + x = self.encoder_attn_layer_norm(x) + + residual = x + normalize = self.maybe_layer_norm(before=True) + if normalize: + x = self.final_layer_norm(x) + x = F.relu(self.fc1(x)) + x = self.relu_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + normalize = self.maybe_layer_norm(after=True) + if normalize: + x = self.final_layer_norm(x) + return x, attn + + def maybe_layer_norm(self, before: bool = False, after: bool = False): + assert before ^ after, "Incorrect usage" + return after ^ self.normalize_before + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn + + def extra_repr(self): + return ( + "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}".format( + self.dropout_module.p, + self.relu_dropout_module.p, + self.input_dropout_module.p, + self.normalize_before, + ) + ) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@register_model_architecture("lightconv", "lightconv") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 7) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.relu_dropout = getattr(args, "relu_dropout", 0.0) + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.encoder_conv_dim = getattr(args, "encoder_conv_dim", args.encoder_embed_dim) + args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim) + + args.encoder_kernel_size_list = getattr( + args, "encoder_kernel_size_list", [3, 7, 15, 31, 31, 31, 31] + ) + args.decoder_kernel_size_list = getattr( + args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31] + ) + if len(args.encoder_kernel_size_list) == 1: + args.encoder_kernel_size_list = ( + args.encoder_kernel_size_list * args.encoder_layers + ) + if len(args.decoder_kernel_size_list) == 1: + args.decoder_kernel_size_list = ( + args.decoder_kernel_size_list * args.decoder_layers + ) + assert ( + len(args.encoder_kernel_size_list) == args.encoder_layers + ), "encoder_kernel_size_list doesn't match encoder_layers" + assert ( + len(args.decoder_kernel_size_list) == args.decoder_layers + ), "decoder_kernel_size_list doesn't match decoder_layers" + args.encoder_glu = getattr(args, "encoder_glu", True) + args.decoder_glu = getattr(args, "decoder_glu", True) + args.input_dropout = getattr(args, "input_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout) + + +@register_model_architecture("lightconv", "lightconv_iwslt_de_en") +def lightconv_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 7) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", 0.1) + args.encoder_glu = getattr(args, "encoder_glu", False) + args.decoder_glu = getattr(args, "decoder_glu", False) + args.input_dropout = getattr(args, "input_dropout", 0.0) + base_architecture(args) + + +@register_model_architecture("lightconv", "lightconv_wmt_en_de") +def lightconv_wmt_en_de(args): + base_architecture(args) + + +@register_model_architecture("lightconv", "lightconv_wmt_en_de_big") +def lightconv_wmt_en_de_big(args): + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + base_architecture(args) + + +@register_model_architecture("lightconv", "lightconv_wmt_en_fr_big") +def lightconv_wmt_en_fr_big(args): + args.dropout = getattr(args, "dropout", 0.1) + lightconv_wmt_en_de_big(args) + + +@register_model_architecture("lightconv", "lightconv_wmt_zh_en_big") +def lightconv_wmt_zh_en_big(args): + args.dropout = getattr(args, "dropout", 0.2) + args.attention_dropout = getattr(args, "attention_dropout", 0.2) + args.weight_dropout = getattr(args, "weight_dropout", 0.2) + lightconv_wmt_en_de_big(args) diff --git a/fairseq/models/lightconv_lm.py b/fairseq/models/lightconv_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1d9efc4e42a5ecc1b83338055f18ade5a83ea666 --- /dev/null +++ b/fairseq/models/lightconv_lm.py @@ -0,0 +1,306 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.lightconv import Embedding, LightConvDecoder +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder + + +@register_model("lightconv_lm") +class LightConvLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--dropout", + default=0.1, + type=float, + metavar="D", + help="dropout probability", + ) + parser.add_argument( + "--attention-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--relu-dropout", + default=0.0, + type=float, + metavar="D", + help="dropout probability after ReLU in FFN", + ) + parser.add_argument( + "--input-dropout", + type=float, + metavar="D", + help="dropout probability of the inputs", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension", + ) + parser.add_argument( + "--decoder-input-dim", type=int, metavar="N", help="decoder input dimension" + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads or LightConv/DynamicConv heads", + ) + parser.add_argument( + "--decoder-normalize-before", + default=False, + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--adaptive-softmax-cutoff", + metavar="EXPR", + help="comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion", + ) + parser.add_argument( + "--adaptive-softmax-dropout", + type=float, + metavar="D", + help="sets adaptive softmax dropout for the tail projections", + ) + parser.add_argument( + "--adaptive-softmax-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--no-token-positional-embeddings", + default=False, + action="store_true", + help="if set, disables positional embeddings (outside self attention)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + default=False, + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--character-embeddings", + default=False, + action="store_true", + help="if set, uses character embedding convolutions to produce token embeddings", + ) + parser.add_argument( + "--character-filters", + type=str, + metavar="LIST", + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + help="size of character embeddings", + ) + parser.add_argument( + "--character-embedding-dim", + type=int, + metavar="N", + default=4, + help="size of character embeddings", + ) + parser.add_argument( + "--char-embedder-highway-layers", + type=int, + metavar="N", + default=2, + help="number of highway layers for character token embeddder", + ) + parser.add_argument( + "--adaptive-input", + default=False, + action="store_true", + help="if set, uses adaptive input", + ) + parser.add_argument( + "--adaptive-input-factor", + type=float, + metavar="N", + help="adaptive input factor", + ) + parser.add_argument( + "--adaptive-input-cutoff", + metavar="EXPR", + help="comma separated list of adaptive input cutoff points.", + ) + parser.add_argument( + "--tie-adaptive-weights", + action="store_true", + help="if set, ties the weights of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--tie-adaptive-proj", + action="store_true", + help="if set, ties the projection weights of adaptive softmax and adaptive input", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the decoder", + ) + + """LightConv and DynamicConv arguments""" + parser.add_argument( + "--decoder-kernel-size-list", + type=lambda x: utils.eval_str_list(x, int), + help='list of kernel size (default: "[3,7,15,31,31,31]")', + ) + parser.add_argument( + "--decoder-glu", type=utils.eval_bool, help="glu after in proj" + ) + parser.add_argument( + "--decoder-conv-type", + default="dynamic", + type=str, + choices=["dynamic", "lightweight"], + help="type of convolution", + ) + parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool) + parser.add_argument( + "--weight-dropout", + type=float, + metavar="D", + help="dropout probability for conv weights", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_lm_architecture(args) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = args.tokens_per_sample + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = args.tokens_per_sample + + if args.character_embeddings: + embed_tokens = CharacterTokenEmbedder( + task.dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) + elif args.adaptive_input: + embed_tokens = AdaptiveInput( + len(task.dictionary), + task.dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, + utils.eval_str_list(args.adaptive_input_cutoff, type=int), + ) + else: + embed_tokens = Embedding( + len(task.dictionary), args.decoder_input_dim, task.dictionary.pad() + ) + + if args.tie_adaptive_weights: + assert args.adaptive_input + assert args.adaptive_input_factor == args.adaptive_softmax_factor + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) + assert args.decoder_input_dim == args.decoder_output_dim + + decoder = LightConvDecoder( + args, + task.output_dictionary, + embed_tokens, + no_encoder_attn=True, + final_norm=False, + ) + return LightConvLanguageModel(decoder) + + +@register_model_architecture("lightconv_lm", "lightconv_lm") +def base_lm_architecture(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + + args.character_embeddings = getattr(args, "character_embeddings", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_conv_dim = getattr(args, "decoder_conv_dim", args.decoder_embed_dim) + + # The model training is not stable without this + args.decoder_normalize_before = True + + args.adaptive_input = getattr(args, "adaptive_input", False) + args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None) + + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False) + + args.decoder_kernel_size_list = getattr( + args, "decoder_kernel_size_list", [3, 7, 15, 31, 31, 31] + ) + if len(args.decoder_kernel_size_list) == 1: + args.decoder_kernel_size_list = ( + args.decoder_kernel_size_list * args.decoder_layers + ) + assert ( + len(args.decoder_kernel_size_list) == args.decoder_layers + ), "decoder_kernel_size_list doesn't match decoder_layers" + args.decoder_glu = getattr(args, "decoder_glu", True) + args.input_dropout = getattr(args, "input_dropout", 0.1) + args.weight_dropout = getattr(args, "weight_dropout", args.attention_dropout) + + +@register_model_architecture("lightconv_lm", "lightconv_lm_gbw") +def lightconv_lm_gbw(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_lm_architecture(args) diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..8a29156270f05f72500b9142bfb5e613a4d7a19e --- /dev/null +++ b/fairseq/models/lstm.py @@ -0,0 +1,755 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import AdaptiveSoftmax, FairseqDropout +from torch import Tensor + + +DEFAULT_MAX_SOURCE_POSITIONS = 1e5 +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + +@register_model("lstm") +class LSTMModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--encoder-embed-dim', type=int, metavar='N', + help='encoder embedding dimension') + parser.add_argument('--encoder-embed-path', type=str, metavar='STR', + help='path to pre-trained encoder embedding') + parser.add_argument('--encoder-freeze-embed', action='store_true', + help='freeze encoder embeddings') + parser.add_argument('--encoder-hidden-size', type=int, metavar='N', + help='encoder hidden size') + parser.add_argument('--encoder-layers', type=int, metavar='N', + help='number of encoder layers') + parser.add_argument('--encoder-bidirectional', action='store_true', + help='make all layers of encoder bidirectional') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-freeze-embed', action='store_true', + help='freeze decoder embeddings') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='BOOL', + help='decoder attention') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') + parser.add_argument('--share-all-embeddings', default=False, action='store_true', + help='share encoder, decoder and output embeddings' + ' (requires shared dictionary and embed dim)') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--encoder-dropout-in', type=float, metavar='D', + help='dropout probability for encoder input embedding') + parser.add_argument('--encoder-dropout-out', type=float, metavar='D', + help='dropout probability for encoder output') + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure that all args are properly defaulted (in case there are any new ones) + base_architecture(args) + + if args.encoder_layers != args.decoder_layers: + raise ValueError("--encoder-layers must match --decoder-layers") + + max_source_positions = getattr( + args, "max_source_positions", DEFAULT_MAX_SOURCE_POSITIONS + ) + max_target_positions = getattr( + args, "max_target_positions", DEFAULT_MAX_TARGET_POSITIONS + ) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + if args.encoder_embed_path: + pretrained_encoder_embed = load_pretrained_embedding_from_file( + args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim + ) + else: + num_embeddings = len(task.source_dictionary) + pretrained_encoder_embed = Embedding( + num_embeddings, args.encoder_embed_dim, task.source_dictionary.pad() + ) + + if args.share_all_embeddings: + # double check all parameters combinations are valid + if task.source_dictionary != task.target_dictionary: + raise ValueError("--share-all-embeddings requires a joint dictionary") + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embed not compatible with --decoder-embed-path" + ) + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to " + "match --decoder-embed-dim" + ) + pretrained_decoder_embed = pretrained_encoder_embed + args.share_decoder_input_output_embed = True + else: + # separate decoder input embeddings + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, + task.target_dictionary, + args.decoder_embed_dim, + ) + # one last double check of parameter combinations + if args.share_decoder_input_output_embed and ( + args.decoder_embed_dim != args.decoder_out_embed_dim + ): + raise ValueError( + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" + ) + + if args.encoder_freeze_embed: + pretrained_encoder_embed.weight.requires_grad = False + if args.decoder_freeze_embed: + pretrained_decoder_embed.weight.requires_grad = False + + encoder = LSTMEncoder( + dictionary=task.source_dictionary, + embed_dim=args.encoder_embed_dim, + hidden_size=args.encoder_hidden_size, + num_layers=args.encoder_layers, + dropout_in=args.encoder_dropout_in, + dropout_out=args.encoder_dropout_out, + bidirectional=args.encoder_bidirectional, + pretrained_embed=pretrained_encoder_embed, + max_source_positions=max_source_positions, + ) + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + attention=utils.eval_bool(args.decoder_attention), + encoder_output_units=encoder.output_units, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_decoder_input_output_embed, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + max_target_positions=max_target_positions, + residuals=False, + ) + return cls(encoder, decoder) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + ) + return decoder_out + + +class LSTMEncoder(FairseqEncoder): + """LSTM encoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + bidirectional=False, + left_pad=True, + pretrained_embed=None, + padding_idx=None, + max_source_positions=DEFAULT_MAX_SOURCE_POSITIONS, + ): + super().__init__(dictionary) + self.num_layers = num_layers + self.dropout_in_module = FairseqDropout( + dropout_in * 1.0, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out * 1.0, module_name=self.__class__.__name__ + ) + self.bidirectional = bidirectional + self.hidden_size = hidden_size + self.max_source_positions = max_source_positions + + num_embeddings = len(dictionary) + self.padding_idx = padding_idx if padding_idx is not None else dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.lstm = LSTM( + input_size=embed_dim, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=self.dropout_out_module.p if num_layers > 1 else 0.0, + bidirectional=bidirectional, + ) + self.left_pad = left_pad + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + enforce_sorted: bool = True, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of + shape `(batch, src_len)` + src_lengths (LongTensor): lengths of each source sentence of + shape `(batch)` + enforce_sorted (bool, optional): if True, `src_tokens` is + expected to contain sequences sorted by length in a + decreasing order. If False, this condition is not + required. Default: True. + """ + if self.left_pad: + # nn.utils.rnn.pack_padded_sequence requires right-padding; + # convert left-padding to right-padding + src_tokens = utils.convert_padding_direction( + src_tokens, + torch.zeros_like(src_tokens).fill_(self.padding_idx), + left_to_right=True, + ) + + bsz, seqlen = src_tokens.size() + + # embed tokens + x = self.embed_tokens(src_tokens) + x = self.dropout_in_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # pack embedded source tokens into a PackedSequence + packed_x = nn.utils.rnn.pack_padded_sequence( + x, src_lengths.cpu(), enforce_sorted=enforce_sorted + ) + + # apply LSTM + if self.bidirectional: + state_size = 2 * self.num_layers, bsz, self.hidden_size + else: + state_size = self.num_layers, bsz, self.hidden_size + h0 = x.new_zeros(*state_size) + c0 = x.new_zeros(*state_size) + packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, _ = nn.utils.rnn.pad_packed_sequence( + packed_outs, padding_value=self.padding_idx * 1.0 + ) + x = self.dropout_out_module(x) + assert list(x.size()) == [seqlen, bsz, self.output_units] + + if self.bidirectional: + final_hiddens = self.combine_bidir(final_hiddens, bsz) + final_cells = self.combine_bidir(final_cells, bsz) + + encoder_padding_mask = src_tokens.eq(self.padding_idx).t() + + return tuple( + ( + x, # seq_len x batch x hidden + final_hiddens, # num_layers x batch x num_directions*hidden + final_cells, # num_layers x batch x num_directions*hidden + encoder_padding_mask, # seq_len x batch + ) + ) + + def combine_bidir(self, outs, bsz: int): + out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() + return out.view(self.num_layers, bsz, -1) + + def reorder_encoder_out( + self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order + ): + return tuple( + ( + encoder_out[0].index_select(1, new_order), + encoder_out[1].index_select(1, new_order), + encoder_out[2].index_select(1, new_order), + encoder_out[3].index_select(1, new_order), + ) + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return self.max_source_positions + + +class AttentionLayer(nn.Module): + def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False): + super().__init__() + + self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias) + self.output_proj = Linear( + input_embed_dim + source_embed_dim, output_embed_dim, bias=bias + ) + + def forward(self, input, source_hids, encoder_padding_mask): + # input: bsz x input_embed_dim + # source_hids: srclen x bsz x source_embed_dim + + # x: bsz x source_embed_dim + x = self.input_proj(input) + + # compute attention + attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2) + + # don't attend over padding + if encoder_padding_mask is not None: + attn_scores = ( + attn_scores.float() + .masked_fill_(encoder_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz + + # sum weighted sources + x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0) + + x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) + return x, attn_scores + + +class LSTMDecoder(FairseqIncrementalDecoder): + """LSTM decoder.""" + + def __init__( + self, + dictionary, + embed_dim=512, + hidden_size=512, + out_embed_dim=512, + num_layers=1, + dropout_in=0.1, + dropout_out=0.1, + attention=True, + encoder_output_units=512, + pretrained_embed=None, + share_input_output_embed=False, + adaptive_softmax_cutoff=None, + max_target_positions=DEFAULT_MAX_TARGET_POSITIONS, + residuals=False, + ): + super().__init__(dictionary) + self.dropout_in_module = FairseqDropout( + dropout_in * 1.0, module_name=self.__class__.__name__ + ) + self.dropout_out_module = FairseqDropout( + dropout_out * 1.0, module_name=self.__class__.__name__ + ) + self.hidden_size = hidden_size + self.share_input_output_embed = share_input_output_embed + self.need_attn = True + self.max_target_positions = max_target_positions + self.residuals = residuals + self.num_layers = num_layers + + self.adaptive_softmax = None + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + if pretrained_embed is None: + self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + else: + self.embed_tokens = pretrained_embed + + self.encoder_output_units = encoder_output_units + if encoder_output_units != hidden_size and encoder_output_units != 0: + self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size) + self.encoder_cell_proj = Linear(encoder_output_units, hidden_size) + else: + self.encoder_hidden_proj = self.encoder_cell_proj = None + + # disable input feeding if there is no encoder + # input feeding is described in arxiv.org/abs/1508.04025 + input_feed_size = 0 if encoder_output_units == 0 else hidden_size + self.layers = nn.ModuleList( + [ + LSTMCell( + input_size=input_feed_size + embed_dim + if layer == 0 + else hidden_size, + hidden_size=hidden_size, + ) + for layer in range(num_layers) + ] + ) + + if attention: + # TODO make bias configurable + self.attention = AttentionLayer( + hidden_size, encoder_output_units, hidden_size, bias=False + ) + else: + self.attention = None + + if hidden_size != out_embed_dim: + self.additional_fc = Linear(hidden_size, out_embed_dim) + + if adaptive_softmax_cutoff is not None: + # setting adaptive_softmax dropout to dropout_out for now but can be redefined + self.adaptive_softmax = AdaptiveSoftmax( + num_embeddings, + hidden_size, + adaptive_softmax_cutoff, + dropout=dropout_out, + ) + elif not self.share_input_output_embed: + self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + src_lengths: Optional[Tensor] = None, + ): + x, attn_scores = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + return self.output_layer(x), attn_scores + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Tuple[Tensor, Tensor, Tensor, Tensor]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + """ + Similar to *forward* but only return features. + """ + # get outputs from encoder + if encoder_out is not None: + encoder_outs = encoder_out[0] + encoder_hiddens = encoder_out[1] + encoder_cells = encoder_out[2] + encoder_padding_mask = encoder_out[3] + else: + encoder_outs = torch.empty(0) + encoder_hiddens = torch.empty(0) + encoder_cells = torch.empty(0) + encoder_padding_mask = torch.empty(0) + srclen = encoder_outs.size(0) + + if incremental_state is not None and len(incremental_state) > 0: + prev_output_tokens = prev_output_tokens[:, -1:] + + bsz, seqlen = prev_output_tokens.size() + + # embed tokens + x = self.embed_tokens(prev_output_tokens) + x = self.dropout_in_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental generation) + if incremental_state is not None and len(incremental_state) > 0: + prev_hiddens, prev_cells, input_feed = self.get_cached_state( + incremental_state + ) + elif encoder_out is not None: + # setup recurrent cells + prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)] + prev_cells = [encoder_cells[i] for i in range(self.num_layers)] + if self.encoder_hidden_proj is not None: + prev_hiddens = [self.encoder_hidden_proj(y) for y in prev_hiddens] + prev_cells = [self.encoder_cell_proj(y) for y in prev_cells] + input_feed = x.new_zeros(bsz, self.hidden_size) + else: + # setup zero cells, since there is no encoder + zero_state = x.new_zeros(bsz, self.hidden_size) + prev_hiddens = [zero_state for i in range(self.num_layers)] + prev_cells = [zero_state for i in range(self.num_layers)] + input_feed = None + + assert ( + srclen > 0 or self.attention is None + ), "attention is not supported if there are no encoder outputs" + attn_scores: Optional[Tensor] = ( + x.new_zeros(srclen, seqlen, bsz) if self.attention is not None else None + ) + outs = [] + for j in range(seqlen): + # input feeding: concatenate context vector from previous time step + if input_feed is not None: + input = torch.cat((x[j, :, :], input_feed), dim=1) + else: + input = x[j] + + for i, rnn in enumerate(self.layers): + # recurrent cell + hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i])) + + # hidden state becomes the input to the next layer + input = self.dropout_out_module(hidden) + if self.residuals: + input = input + prev_hiddens[i] + + # save state for next time step + prev_hiddens[i] = hidden + prev_cells[i] = cell + + # apply attention using the last layer's hidden state + if self.attention is not None: + assert attn_scores is not None + out, attn_scores[:, j, :] = self.attention( + hidden, encoder_outs, encoder_padding_mask + ) + else: + out = hidden + out = self.dropout_out_module(out) + + # input feeding + if input_feed is not None: + input_feed = out + + # save final output + outs.append(out) + + # Stack all the necessary tensors together and store + prev_hiddens_tensor = torch.stack(prev_hiddens) + prev_cells_tensor = torch.stack(prev_cells) + cache_state = torch.jit.annotate( + Dict[str, Optional[Tensor]], + { + "prev_hiddens": prev_hiddens_tensor, + "prev_cells": prev_cells_tensor, + "input_feed": input_feed, + }, + ) + self.set_incremental_state(incremental_state, "cached_state", cache_state) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + + # T x B x C -> B x T x C + x = x.transpose(1, 0) + + if hasattr(self, "additional_fc") and self.adaptive_softmax is None: + x = self.additional_fc(x) + x = self.dropout_out_module(x) + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + if not self.training and self.need_attn and self.attention is not None: + assert attn_scores is not None + attn_scores = attn_scores.transpose(0, 2) + else: + attn_scores = None + return x, attn_scores + + def output_layer(self, x): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + if self.share_input_output_embed: + x = F.linear(x, self.embed_tokens.weight) + else: + x = self.fc_out(x) + return x + + def get_cached_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + ) -> Tuple[List[Tensor], List[Tensor], Optional[Tensor]]: + cached_state = self.get_incremental_state(incremental_state, "cached_state") + assert cached_state is not None + prev_hiddens_ = cached_state["prev_hiddens"] + assert prev_hiddens_ is not None + prev_cells_ = cached_state["prev_cells"] + assert prev_cells_ is not None + prev_hiddens = [prev_hiddens_[i] for i in range(self.num_layers)] + prev_cells = [prev_cells_[j] for j in range(self.num_layers)] + input_feed = cached_state[ + "input_feed" + ] # can be None for decoder-only language models + return prev_hiddens, prev_cells, input_feed + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + new_order: Tensor, + ): + if incremental_state is None or len(incremental_state) == 0: + return + prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state) + prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens] + prev_cells = [p.index_select(0, new_order) for p in prev_cells] + if input_feed is not None: + input_feed = input_feed.index_select(0, new_order) + cached_state_new = torch.jit.annotate( + Dict[str, Optional[Tensor]], + { + "prev_hiddens": torch.stack(prev_hiddens), + "prev_cells": torch.stack(prev_cells), + "input_feed": input_feed, + }, + ) + self.set_incremental_state(incremental_state, "cached_state", cached_state_new), + return + + def max_positions(self): + """Maximum output length supported by the decoder.""" + return self.max_target_positions + + def make_generation_fast_(self, need_attn=False, **kwargs): + self.need_attn = need_attn + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.uniform_(m.weight, -0.1, 0.1) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def LSTMCell(input_size, hidden_size, **kwargs): + m = nn.LSTMCell(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if "weight" in name or "bias" in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True, dropout=0.0): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m + + +@register_model_architecture("lstm", "lstm") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_freeze_embed = getattr(args, "encoder_freeze_embed", False) + args.encoder_hidden_size = getattr( + args, "encoder_hidden_size", args.encoder_embed_dim + ) + args.encoder_layers = getattr(args, "encoder_layers", 1) + args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_freeze_embed = getattr(args, "decoder_freeze_embed", False) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "1") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + + +@register_model_architecture("lstm", "lstm_wiseman_iwslt_de_en") +def lstm_wiseman_iwslt_de_en(args): + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_dropout_in = getattr(args, "encoder_dropout_in", 0) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 256) + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", 0) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + base_architecture(args) + + +@register_model_architecture("lstm", "lstm_luong_wmt_en_de") +def lstm_luong_wmt_en_de(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000) + args.encoder_layers = getattr(args, "encoder_layers", 4) + args.encoder_dropout_out = getattr(args, "encoder_dropout_out", 0) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1000) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 1000) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", 0) + base_architecture(args) diff --git a/fairseq/models/lstm_lm.py b/fairseq/models/lstm_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..454f0ac36fab78bf02a8e2f07ed9607d1da87e34 --- /dev/null +++ b/fairseq/models/lstm_lm.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq import utils +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.lstm import Embedding, LSTMDecoder + + +DEFAULT_MAX_TARGET_POSITIONS = 1e5 + + +@register_model("lstm_lm") +class LSTMLanguageModel(FairseqLanguageModel): + def __init__(self, decoder): + super().__init__(decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--dropout', type=float, metavar='D', + help='dropout probability') + parser.add_argument('--decoder-embed-dim', type=int, metavar='N', + help='decoder embedding dimension') + parser.add_argument('--decoder-embed-path', type=str, metavar='STR', + help='path to pre-trained decoder embedding') + parser.add_argument('--decoder-hidden-size', type=int, metavar='N', + help='decoder hidden size') + parser.add_argument('--decoder-layers', type=int, metavar='N', + help='number of decoder layers') + parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N', + help='decoder output embedding dimension') + parser.add_argument('--decoder-attention', type=str, metavar='BOOL', + help='decoder attention') + parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', + help='comma separated list of adaptive softmax cutoff points. ' + 'Must be used with adaptive_loss criterion') + parser.add_argument('--residuals', default=False, + action='store_true', + help='applying residuals between LSTM layers') + + # Granular dropout settings (if not specified these default to --dropout) + parser.add_argument('--decoder-dropout-in', type=float, metavar='D', + help='dropout probability for decoder input embedding') + parser.add_argument('--decoder-dropout-out', type=float, metavar='D', + help='dropout probability for decoder output') + parser.add_argument('--share-decoder-input-output-embed', default=False, + action='store_true', + help='share decoder input and output embeddings') + # fmt: on + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if getattr(args, "max_target_positions", None) is not None: + max_target_positions = args.max_target_positions + else: + max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) + embed_dict = utils.parse_embedding(embed_path) + utils.print_embed_overlap(embed_dict, dictionary) + return utils.load_embedding(embed_dict, dictionary, embed_tokens) + + pretrained_decoder_embed = None + if args.decoder_embed_path: + pretrained_decoder_embed = load_pretrained_embedding_from_file( + args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim + ) + + if args.share_decoder_input_output_embed: + # double check all parameters combinations are valid + if task.source_dictionary != task.target_dictionary: + raise ValueError( + "--share-decoder-input-output-embeddings requires a joint dictionary" + ) + + if args.decoder_embed_dim != args.decoder_out_embed_dim: + raise ValueError( + "--share-decoder-input-output-embeddings requires " + "--decoder-embed-dim to match --decoder-out-embed-dim" + ) + + decoder = LSTMDecoder( + dictionary=task.dictionary, + embed_dim=args.decoder_embed_dim, + hidden_size=args.decoder_hidden_size, + out_embed_dim=args.decoder_out_embed_dim, + num_layers=args.decoder_layers, + dropout_in=args.decoder_dropout_in, + dropout_out=args.decoder_dropout_out, + attention=False, # decoder-only language model doesn't support attention + encoder_output_units=0, + pretrained_embed=pretrained_decoder_embed, + share_input_output_embed=args.share_decoder_input_output_embed, + adaptive_softmax_cutoff=( + utils.eval_str_list(args.adaptive_softmax_cutoff, type=int) + if args.criterion == "adaptive_loss" + else None + ), + max_target_positions=max_target_positions, + residuals=args.residuals, + ) + + return cls(decoder) + + +@register_model_architecture("lstm_lm", "lstm_lm") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_hidden_size = getattr( + args, "decoder_hidden_size", args.decoder_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 1) + args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512) + args.decoder_attention = getattr(args, "decoder_attention", "0") + args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout) + args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.adaptive_softmax_cutoff = getattr( + args, "adaptive_softmax_cutoff", "10000,50000,200000" + ) + args.residuals = getattr(args, "residuals", False) diff --git a/fairseq/models/masked_lm.py b/fairseq/models/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..b71254cef8c5a4b40b0152c1b4615f1941d6c346 --- /dev/null +++ b/fairseq/models/masked_lm.py @@ -0,0 +1,398 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.modules import ( + LayerNorm, + SinusoidalPositionalEmbedding, + TransformerSentenceEncoder, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import safe_hasattr + + +logger = logging.getLogger(__name__) + + +@register_model("masked_lm") +class MaskedLMModel(FairseqEncoderModel): + """ + Class for training a Masked Language Model. It also supports an + additional sentence level prediction if the sent-loss argument is set. + """ + + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + + # if specified then apply bert initialization on the model. We need + # to explictly call this to make sure that the output embeddings + # and projection layers are also correctly initialized + if getattr(args, "apply_bert_init", False): + self.apply(init_bert_params) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # Arguments related to dropout + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for" " attention weights", + ) + parser.add_argument( + "--act-dropout", + type=float, + metavar="D", + help="dropout probability after" " activation in FFN", + ) + + # Arguments related to hidden states and self-attention + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + + # Arguments related to input and output embeddings + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--share-encoder-input-output-embed", + action="store_true", + help="share encoder input" " and output embeddings", + ) + parser.add_argument( + "--encoder-learned-pos", + action="store_true", + help="use learned positional embeddings in the encoder", + ) + parser.add_argument( + "--no-token-positional-embeddings", + action="store_true", + help="if set, disables positional embeddings" " (outside self attention)", + ) + parser.add_argument( + "--num-segment", type=int, metavar="N", help="num segment in the input" + ) + parser.add_argument( + "--max-positions", type=int, help="number of positional embeddings to learn" + ) + + # Arguments related to sentence level prediction + parser.add_argument( + "--sentence-class-num", + type=int, + metavar="N", + help="number of classes for sentence task", + ) + parser.add_argument( + "--sent-loss", + action="store_true", + help="if set," " calculate sentence level predictions", + ) + + # Arguments related to parameter initialization + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) + + # misc params + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="Which activation function to use for pooler layer.", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + + def forward(self, src_tokens, segment_labels=None, **kwargs): + return self.encoder(src_tokens, segment_labels=segment_labels, **kwargs) + + def max_positions(self): + return self.encoder.max_positions + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_architecture(args) + + if not safe_hasattr(args, "max_positions"): + args.max_positions = args.tokens_per_sample + + logger.info(args) + + encoder = MaskedLMEncoder(args, task.dictionary) + return cls(args, encoder) + + +class MaskedLMEncoder(FairseqEncoder): + """ + Encoder for Masked Language Modelling. + """ + + def __init__(self, args, dictionary): + super().__init__(dictionary) + + self.padding_idx = dictionary.pad() + self.vocab_size = dictionary.__len__() + self.max_positions = args.max_positions + + self.sentence_encoder = TransformerSentenceEncoder( + padding_idx=self.padding_idx, + vocab_size=self.vocab_size, + num_encoder_layers=args.encoder_layers, + embedding_dim=args.encoder_embed_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.act_dropout, + max_seq_len=self.max_positions, + num_segments=args.num_segment, + use_position_embeddings=not args.no_token_positional_embeddings, + encoder_normalize_before=args.encoder_normalize_before, + apply_bert_init=args.apply_bert_init, + activation_fn=args.activation_fn, + learned_pos_embedding=args.encoder_learned_pos, + ) + + self.share_input_output_embed = args.share_encoder_input_output_embed + self.embed_out = None + self.sentence_projection_layer = None + self.sentence_out_dim = args.sentence_class_num + self.lm_output_learned_bias = None + + # Remove head is set to true during fine-tuning + self.load_softmax = not getattr(args, "remove_head", False) + + self.masked_lm_pooler = nn.Linear( + args.encoder_embed_dim, args.encoder_embed_dim + ) + self.pooler_activation = utils.get_activation_fn(args.pooler_activation_fn) + + self.lm_head_transform_weight = nn.Linear( + args.encoder_embed_dim, args.encoder_embed_dim + ) + self.activation_fn = utils.get_activation_fn(args.activation_fn) + self.layer_norm = LayerNorm(args.encoder_embed_dim) + + self.lm_output_learned_bias = None + if self.load_softmax: + self.lm_output_learned_bias = nn.Parameter(torch.zeros(self.vocab_size)) + + if not self.share_input_output_embed: + self.embed_out = nn.Linear( + args.encoder_embed_dim, self.vocab_size, bias=False + ) + + if args.sent_loss: + self.sentence_projection_layer = nn.Linear( + args.encoder_embed_dim, self.sentence_out_dim, bias=False + ) + + def forward(self, src_tokens, segment_labels=None, masked_tokens=None, **unused): + """ + Forward pass for Masked LM encoder. This first computes the token + embedding using the token embedding matrix, position embeddings (if + specified) and segment embeddings (if specified). + + Here we assume that the sentence representation corresponds to the + output of the classification_token (see bert_task or cross_lingual_lm + task for more details). + Args: + - src_tokens: B x T matrix representing sentences + - segment_labels: B x T matrix representing segment label for tokens + Returns: + - a tuple of the following: + - logits for predictions in format B x T x C to be used in + softmax afterwards + - a dictionary of additional data, where 'pooled_output' contains + the representation for classification_token and 'inner_states' + is a list of internal model states used to compute the + predictions (similar in ELMO). 'sentence_logits' + is the prediction logit for NSP task and is only computed if + this is specified in the input arguments. + """ + + inner_states, sentence_rep = self.sentence_encoder( + src_tokens, + segment_labels=segment_labels, + ) + + x = inner_states[-1].transpose(0, 1) + # project masked tokens only + if masked_tokens is not None: + x = x[masked_tokens, :] + x = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(x))) + + pooled_output = self.pooler_activation(self.masked_lm_pooler(sentence_rep)) + + # project back to size of vocabulary + if self.share_input_output_embed and hasattr( + self.sentence_encoder.embed_tokens, "weight" + ): + x = F.linear(x, self.sentence_encoder.embed_tokens.weight) + elif self.embed_out is not None: + x = self.embed_out(x) + if self.lm_output_learned_bias is not None: + x = x + self.lm_output_learned_bias + sentence_logits = None + if self.sentence_projection_layer: + sentence_logits = self.sentence_projection_layer(pooled_output) + + return x, { + "inner_states": inner_states, + "pooled_output": pooled_output, + "sentence_logits": sentence_logits, + } + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + if not self.load_softmax: + for k in list(state_dict.keys()): + if ( + "embed_out.weight" in k + or "sentence_projection_layer.weight" in k + or "lm_output_learned_bias" in k + ): + del state_dict[k] + return state_dict + + +@register_model_architecture("masked_lm", "masked_lm") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.act_dropout = getattr(args, "act_dropout", 0.0) + + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.num_segment = getattr(args, "num_segment", 2) + + args.sentence_class_num = getattr(args, "sentence_class_num", 2) + args.sent_loss = getattr(args, "sent_loss", False) + + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.activation_fn = getattr(args, "activation_fn", "relu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + + +@register_model_architecture("masked_lm", "bert_base") +def bert_base_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", True + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.num_segment = getattr(args, "num_segment", 2) + + args.encoder_layers = getattr(args, "encoder_layers", 12) + + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) + + args.sentence_class_num = getattr(args, "sentence_class_num", 2) + args.sent_loss = getattr(args, "sent_loss", True) + + args.apply_bert_init = getattr(args, "apply_bert_init", True) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + base_architecture(args) + + +@register_model_architecture("masked_lm", "bert_large") +def bert_large_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + bert_base_architecture(args) + + +@register_model_architecture("masked_lm", "xlm_base") +def xlm_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.share_encoder_input_output_embed = getattr( + args, "share_encoder_input_output_embed", True + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True) + args.num_segment = getattr(args, "num_segment", 1) + + args.encoder_layers = getattr(args, "encoder_layers", 6) + + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + + args.sent_loss = getattr(args, "sent_loss", False) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.apply_bert_init = getattr(args, "apply_bert_init", True) + base_architecture(args) diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..732d66b1d5f695151c26d29eb7f6b53179c269f1 --- /dev/null +++ b/fairseq/models/model_utils.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +from torch import Tensor + + +@torch.jit.script +def script_skip_tensor_list(x: List[Tensor], mask): + res = [xi[mask] if xi.size(0) == mask.size(0) else xi[:, mask] for xi in x] + outputs = [] + for i, t in enumerate(res): + if t.numel() != 0: + outputs.append(t) + else: + outputs.append(x[i]) + return outputs + + +@torch.jit.script +def script_skip_tensor(x: Tensor, mask): + # None case + if x.size(0) == 0: + return x + res = x[mask] if x.size(0) == mask.size(0) else x[:, mask] + if res.numel() == 0: + return x + else: + return res + + +@torch.jit.script +def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int): + """ + Expand 2D/3D tensor on dim=1 + """ + if x is None: + return None + + assert x.dim() == 2 or x.dim() == 3 + assert trg_dim >= x.size(1), (trg_dim, x.size()) + if trg_dim == x.size(1): + return x + + dims = [x.size(0), trg_dim - x.size(1)] + if x.dim() == 3: + dims.append(x.size(2)) + x = torch.cat([x, torch.zeros(dims).to(x).fill_(padding_idx)], 1) + + return x + + +@torch.jit.script +def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor: + return x if x is not None else y + + +@torch.jit.script +def fill_tensors( + x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int +) -> Optional[Tensor]: + """ + Filling tensor x with y at masked positions (dim=0). + """ + if x is None or x.size()[0] == 0 or y is None: + return x + assert x.dim() == y.dim() and mask.size(0) == x.size(0) + assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) + + n_selected = mask.sum() + if n_selected == 0: + return x + assert n_selected == y.size(0) + if n_selected == x.size(0): + return y + + if x.size(1) < y.size(1): + x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx) + x[mask] = y + elif x.size(1) > y.size(1): + x[mask] = torch.tensor(padding_idx).type_as(x) + if x.dim() == 2: + x[mask, : y.size(1)] = y + else: + x[mask, : y.size(1), :] = y + else: + x[mask] = y + return x diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e722b647edd92c95a3e93489031ae331f90e0463 --- /dev/null +++ b/fairseq/models/multilingual_transformer.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +from fairseq import utils +from fairseq.models import ( + FairseqMultiModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerEncoder, + TransformerModel, + base_architecture, +) +from fairseq.utils import safe_hasattr + + +@register_model("multilingual_transformer") +class MultilingualTransformerModel(FairseqMultiModel): + """Train Transformer models for multiple language pairs simultaneously. + + Requires `--task multilingual_translation`. + + We inherit all arguments from TransformerModel and assume that all language + pairs use a single Transformer architecture. In addition, we provide several + options that are specific to the multilingual setting. + + Args: + --share-encoder-embeddings: share encoder embeddings across all source languages + --share-decoder-embeddings: share decoder embeddings across all target languages + --share-encoders: share all encoder params (incl. embeddings) across all source languages + --share-decoders: share all decoder params (incl. embeddings) across all target languages + """ + + def __init__(self, encoders, decoders): + super().__init__(encoders, decoders) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--share-encoder-embeddings", + action="store_true", + help="share encoder embeddings across languages", + ) + parser.add_argument( + "--share-decoder-embeddings", + action="store_true", + help="share decoder embeddings across languages", + ) + parser.add_argument( + "--share-encoders", + action="store_true", + help="share encoders across languages", + ) + parser.add_argument( + "--share-decoders", + action="store_true", + help="share decoders across languages", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + from fairseq.tasks.multilingual_translation import MultilingualTranslationTask + + assert isinstance(task, MultilingualTranslationTask) + + # make sure all arguments are present in older models + base_multilingual_architecture(args) + + if not safe_hasattr(args, "max_source_positions"): + args.max_source_positions = 1024 + if not safe_hasattr(args, "max_target_positions"): + args.max_target_positions = 1024 + + src_langs = [lang_pair.split("-")[0] for lang_pair in task.model_lang_pairs] + tgt_langs = [lang_pair.split("-")[1] for lang_pair in task.model_lang_pairs] + + if args.share_encoders: + args.share_encoder_embeddings = True + if args.share_decoders: + args.share_decoder_embeddings = True + + def build_embedding(dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + # build shared embeddings (if applicable) + shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None + if args.share_all_embeddings: + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=task.langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + shared_decoder_embed_tokens = shared_encoder_embed_tokens + args.share_decoder_input_output_embed = True + else: + if args.share_encoder_embeddings: + shared_encoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=src_langs, + embed_dim=args.encoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.encoder_embed_path, + ) + if args.share_decoder_embeddings: + shared_decoder_embed_tokens = FairseqMultiModel.build_shared_embeddings( + dicts=task.dicts, + langs=tgt_langs, + embed_dim=args.decoder_embed_dim, + build_embedding=build_embedding, + pretrained_embed_path=args.decoder_embed_path, + ) + + # encoders/decoders for each language + lang_encoders, lang_decoders = {}, {} + + def get_encoder(lang): + if lang not in lang_encoders: + if shared_encoder_embed_tokens is not None: + encoder_embed_tokens = shared_encoder_embed_tokens + else: + encoder_embed_tokens = build_embedding( + task.dicts[lang], + args.encoder_embed_dim, + args.encoder_embed_path, + ) + lang_encoders[lang] = cls._get_module_class( + True, args, task.dicts[lang], encoder_embed_tokens, src_langs + ) + return lang_encoders[lang] + + def get_decoder(lang): + if lang not in lang_decoders: + if shared_decoder_embed_tokens is not None: + decoder_embed_tokens = shared_decoder_embed_tokens + else: + decoder_embed_tokens = build_embedding( + task.dicts[lang], + args.decoder_embed_dim, + args.decoder_embed_path, + ) + lang_decoders[lang] = cls._get_module_class( + False, args, task.dicts[lang], decoder_embed_tokens, tgt_langs + ) + return lang_decoders[lang] + + # shared encoders/decoders (if applicable) + shared_encoder, shared_decoder = None, None + if args.share_encoders: + shared_encoder = get_encoder(src_langs[0]) + if args.share_decoders: + shared_decoder = get_decoder(tgt_langs[0]) + + encoders, decoders = OrderedDict(), OrderedDict() + for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs): + encoders[lang_pair] = ( + shared_encoder if shared_encoder is not None else get_encoder(src) + ) + decoders[lang_pair] = ( + shared_decoder if shared_decoder is not None else get_decoder(tgt) + ) + + return MultilingualTransformerModel(encoders, decoders) + + @classmethod + def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs): + module_class = TransformerEncoder if is_encoder else TransformerDecoder + return module_class(args, lang_dict, embed_tokens) + + def load_state_dict(self, state_dict, strict=True, model_cfg=None): + state_dict_subset = state_dict.copy() + for k, _ in state_dict.items(): + assert k.startswith("models.") + lang_pair = k.split(".")[1] + if lang_pair not in self.models: + del state_dict_subset[k] + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) + + +@register_model_architecture("multilingual_transformer", "multilingual_transformer") +def base_multilingual_architecture(args): + base_architecture(args) + args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", False) + args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", False) + args.share_encoders = getattr(args, "share_encoders", False) + args.share_decoders = getattr(args, "share_decoders", False) + + +@register_model_architecture( + "multilingual_transformer", "multilingual_transformer_iwslt_de_en" +) +def multilingual_transformer_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + base_multilingual_architecture(args) diff --git a/fairseq/models/nat/__init__.py b/fairseq/models/nat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05fe822487c3bcde8346648d5826f1669c6bc1ca --- /dev/null +++ b/fairseq/models/nat/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .fairseq_nat_model import * +from .nonautoregressive_transformer import * +from .nat_crf_transformer import * +from .iterative_nonautoregressive_transformer import * +from .cmlm_transformer import * +from .levenshtein_transformer import * +from .insertion_transformer import * diff --git a/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc b/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e011f766245e97770247151f97505d8a015fe291 Binary files /dev/null and b/fairseq/models/nat/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..734a07a051b958d23376caab2c2abc1c71449ddd Binary files /dev/null and b/fairseq/models/nat/__pycache__/cmlm_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc b/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ba074cb76c84376c1d79d5b8f0686017723106 Binary files /dev/null and b/fairseq/models/nat/__pycache__/fairseq_nat_model.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbae70f6dbafb578abd1f2aa0378bed17219f07a Binary files /dev/null and b/fairseq/models/nat/__pycache__/insertion_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fbef0c93786019c8a3d299e07ce9fc429a3e4c9 Binary files /dev/null and b/fairseq/models/nat/__pycache__/iterative_nonautoregressive_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb907549cbf12613f922a71b8a9d1fc31659441 Binary files /dev/null and b/fairseq/models/nat/__pycache__/levenshtein_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc b/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe17dcfe5c48c1d8f3c7c159abfa4b738770ab5 Binary files /dev/null and b/fairseq/models/nat/__pycache__/levenshtein_utils.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e789363a67a39fb0fbba3db4d878424581835e02 Binary files /dev/null and b/fairseq/models/nat/__pycache__/nat_crf_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc b/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8817ef3f2b906f1a0216eb81f2eabd2002b7ebc Binary files /dev/null and b/fairseq/models/nat/__pycache__/nonautoregressive_transformer.cpython-310.pyc differ diff --git a/fairseq/models/nat/cmlm_transformer.py b/fairseq/models/nat/cmlm_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c876e9453c101c00bd8e93e6e6f1fb48dc26f993 --- /dev/null +++ b/fairseq/models/nat/cmlm_transformer.py @@ -0,0 +1,162 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file implements: +Ghazvininejad, Marjan, et al. +"Constant-time machine translation with conditional masked language models." +arXiv preprint arXiv:1904.09324 (2019). +""" + +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import NATransformerModel +from fairseq.utils import new_arange + + +def _skeptical_unmasking(output_scores, output_masks, p): + sorted_index = output_scores.sort(-1)[1] + boundary_len = ( + (output_masks.sum(1, keepdim=True).type_as(output_scores) - 2) * p + ).long() + skeptical_mask = new_arange(output_masks) < boundary_len + return skeptical_mask.scatter(1, sorted_index, skeptical_mask) + + +@register_model("cmlm_transformer") +class CMLMNATransformerModel(NATransformerModel): + @staticmethod + def add_args(parser): + NATransformerModel.add_args(parser) + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + assert not self.decoder.src_embedding_copy, "do not support embedding copy." + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + # length prediction + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) + + # decoding + word_ins_out = self.decoder( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + ) + word_ins_mask = prev_output_tokens.eq(self.unk) + + return { + "word_ins": { + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, + }, + "length": { + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, + } + + def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): + + step = decoder_out.step + max_step = decoder_out.max_step + + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + history = decoder_out.history + + # execute the decoder + output_masks = output_tokens.eq(self.unk) + _scores, _tokens = self.decoder( + normalize=True, + prev_output_tokens=output_tokens, + encoder_out=encoder_out, + ).max(-1) + output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) + output_scores.masked_scatter_(output_masks, _scores[output_masks]) + + if history is not None: + history.append(output_tokens.clone()) + + # skeptical decoding (depend on the maximum decoding steps.) + if (step + 1) < max_step: + skeptical_mask = _skeptical_unmasking( + output_scores, output_tokens.ne(self.pad), 1 - (step + 1) / max_step + ) + + output_tokens.masked_fill_(skeptical_mask, self.unk) + output_scores.masked_fill_(skeptical_mask, 0.0) + + if history is not None: + history.append(output_tokens.clone()) + + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + history=history, + ) + + +@register_model_architecture("cmlm_transformer", "cmlm_transformer") +def cmlm_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", True) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.ngram_predictor = getattr(args, "ngram_predictor", 1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + + +@register_model_architecture("cmlm_transformer", "cmlm_transformer_wmt_en_de") +def cmlm_wmt_en_de(args): + cmlm_base_architecture(args) diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5594a4ed9b395a1665c70fa7410d5caa16a879d --- /dev/null +++ b/fairseq/models/nat/fairseq_nat_model.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) +from fairseq.modules.transformer_sentence_encoder import init_bert_params + + +def ensemble_encoder(func): + def wrapper(self, *args, **kwargs): + if self.ensemble_models is None or len(self.ensemble_models) == 1: + return func(self, *args, **kwargs) + encoder_outs = [ + func(model, *args, **kwargs, return_all_hiddens=True) + for model in self.ensemble_models + ] + _encoder_out = encoder_outs[0].copy() + + def stack(key): + outs = [e[key][0] for e in encoder_outs] + return [torch.stack(outs, -1) if outs[0] is not None else None] + + _encoder_out["encoder_out"] = stack("encoder_out") + _encoder_out["encoder_embedding"] = stack("encoder_embedding") + + num_layers = len(_encoder_out["encoder_states"]) + if num_layers > 0: + _encoder_out["encoder_states"] = [ + torch.stack([e["encoder_states"][i] for e in encoder_outs], -1) + for i in range(num_layers) + ] + return _encoder_out + + return wrapper + + +def ensemble_decoder(func): + def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs): + if self.ensemble_models is None or len(self.ensemble_models) == 1: + return func( + self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs + ) + + def _replace(encoder_out, new_val): + new_encoder_out = encoder_out.copy() + new_encoder_out["encoder_out"] = [new_val] + return new_encoder_out + + action_outs = [ + func( + model, + normalize=normalize, + encoder_out=_replace( + encoder_out, encoder_out["encoder_out"][0][:, :, :, i] + ), + *args, + **kwargs + ) + for i, model in enumerate(self.ensemble_models) + ] + + if not isinstance(action_outs[0], tuple): # return multiple values + action_outs = [[a] for a in action_outs] + else: + action_outs = [list(a) for a in action_outs] + + ensembled_outs = [] + for i in range(len(action_outs[0])): + if i == 0 and normalize: + ensembled_outs += [ + torch.logsumexp( + torch.stack([a[i] for a in action_outs], -1), dim=-1 + ) + - math.log(len(self.ensemble_models)) + ] + elif action_outs[0][i] is not None: + ensembled_outs += [torch.stack([a[i] for a in action_outs], -1)] + else: + ensembled_outs += [None] + + if len(ensembled_outs) == 1: + return ensembled_outs[0] + return tuple(ensembled_outs) + + return wrapper + + +class FairseqNATModel(TransformerModel): + """ + Abstract class for all nonautoregressive-based models + """ + + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + self.tgt_dict = decoder.dictionary + self.bos = decoder.dictionary.bos() + self.eos = decoder.dictionary.eos() + self.pad = decoder.dictionary.pad() + self.unk = decoder.dictionary.unk() + + self.ensemble_models = None + + @property + def allow_length_beam(self): + return False + + @property + def allow_ensemble(self): + return True + + def enable_ensemble(self, models): + self.encoder.ensemble_models = [m.encoder for m in models] + self.decoder.ensemble_models = [m.decoder for m in models] + + @staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + parser.add_argument( + "--apply-bert-init", + action="store_true", + help="use custom param initialization for BERT", + ) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = FairseqNATDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + encoder = FairseqNATEncoder(args, src_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + encoder.apply(init_bert_params) + return encoder + + def forward_encoder(self, encoder_inputs): + return self.encoder(*encoder_inputs) + + def forward_decoder(self, *args, **kwargs): + return NotImplementedError + + def initialize_output_tokens(self, *args, **kwargs): + return NotImplementedError + + def forward(self, *args, **kwargs): + return NotImplementedError + + +class FairseqNATEncoder(TransformerEncoder): + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + self.ensemble_models = None + + @ensemble_encoder + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + +class FairseqNATDecoder(TransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(args, dictionary, embed_tokens, no_encoder_attn) + self.ensemble_models = None diff --git a/fairseq/models/nat/insertion_transformer.py b/fairseq/models/nat/insertion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc28000f59a3b9e8098f9fe710cc8335d39eea3e --- /dev/null +++ b/fairseq/models/nat/insertion_transformer.py @@ -0,0 +1,280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn.functional as F +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import ( + FairseqNATModel, + LevenshteinTransformerDecoder, + LevenshteinTransformerModel, + ensemble_decoder, +) +from fairseq.models.transformer import Linear +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import new_arange + + +class NegativeDistanceScore(object): + def __init__(self): + + # pre-compute some values + self.scores = {} + + self.scores[0.5] = self.compute_score_full(50, 0.5) + self.scores[1.0] = self.compute_score_full(50, 1.0) + self.scores[2.0] = self.compute_score_full(50, 2.0) + + def __call__(self, i, L, tau): + if (tau is None) or (tau > 1000): + return 1 / L + + if tau in self.scores: + if L < self.scores[tau].shape[0]: + return self.scores[tau][L - 1, i] + return self.compute_score(L, tau)[i] + + def compute_score(self, L, tau): + s = np.array([-abs(L / 2 - i) / tau for i in range(L)]) + s = np.exp(s - s.max()) + return s / s.sum() + + def compute_score_full(self, L, tau): + s = -abs(np.arange(0, L - 1)[:, None] / 2 - np.arange(L)[None, :]) / tau + s = np.tril(s, 0) + np.triu(s - float("inf"), 1) + s = np.exp(s - s.max(1, keepdims=True)) + return s / s.sum(1, keepdims=True) + + +neg_scorer = NegativeDistanceScore() + + +def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx, vocab_size, tau=None): + try: + from fairseq import libnat + except ImportError as e: + import sys + + sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n") + raise e + + B = in_tokens.size(0) + T = in_tokens.size(1) + V = vocab_size + + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + insert_labels = [a[:-1] for a in full_labels] + + # numericalize1 + insert_label_tensors = in_tokens.new_zeros(B * (T - 1) * V).float() + insert_index, insert_labels = zip( + *[ + (w + (j + i * (T - 1)) * V, neg_scorer(k, len(label), tau)) + for i, labels in enumerate(insert_labels) + for j, label in enumerate(labels[1:-1]) + for k, w in enumerate(label) + ] + ) # HACK 1:-1 + insert_index, insert_labels = [ + torch.tensor(list(a), device=in_tokens.device) + for a in [insert_index, insert_labels] + ] + insert_label_tensors.scatter_(0, insert_index.long(), insert_labels) + insert_label_tensors = insert_label_tensors.view(B, T - 1, V) + + return insert_label_tensors + + +def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, padding_idx): + + padding_masks = in_tokens[:, 1:].eq(padding_idx) + word_ins_scores.masked_fill_(padding_masks, 0.0) + word_ins_pred.masked_fill_(padding_masks, padding_idx) + + in_coords = new_arange(in_tokens).type_as(in_scores) + + # shift all padding predictions to infinite + out_coords = (in_coords[:, 1:] - 0.5).masked_fill( + word_ins_pred.eq(padding_idx), float("inf") + ) + out_coords = torch.cat([in_coords, out_coords], 1).sort(-1)[1] + out_tokens = torch.cat([in_tokens, word_ins_pred], 1).gather(1, out_coords) + out_scores = torch.cat([in_scores, word_ins_scores], 1).gather(1, out_coords) + return out_tokens, out_scores + + +@register_model("insertion_transformer") +class InsertionTransformerModel(LevenshteinTransformerModel): + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + + @staticmethod + def add_args(parser): + FairseqNATModel.add_args(parser) + parser.add_argument("--label-tau", default=None, type=float) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = InsertionTransformerDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + + assert tgt_tokens is not None, "forward function only supports training." + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # generate training labels for insertion + word_ins_out = self.decoder.forward_word_ins( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + ) + + word_ins_tgt = _get_ins_targets( + prev_output_tokens, + tgt_tokens, + self.pad, + self.unk, + len(self.tgt_dict), + tau=self.decoder.label_tau, + ).type_as(word_ins_out) + word_ins_masks = prev_output_tokens[:, 1:].ne(self.pad) + + return { + "word_ins": { + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_masks, + "ls": self.args.label_smoothing, + "nll_loss": True, + } + } + + def forward_decoder( + self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs + ): + + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + history = decoder_out.history + + # TODO: decoding for InsertionTransformer + word_ins_score = self.decoder.forward_word_ins( + normalize=True, prev_output_tokens=output_tokens, encoder_out=encoder_out + ) + + if eos_penalty > 0.0: + word_ins_score[:, :, self.pad] -= eos_penalty + word_ins_score, word_ins_pred = word_ins_score.max(-1) + output_tokens, output_scores = _apply_ins_words( + output_tokens, output_scores, word_ins_pred, word_ins_score, self.pad + ) + + # delete some unnecessary paddings + cut_off = output_tokens.ne(self.pad).sum(1).max() + output_tokens = output_tokens[:, :cut_off] + output_scores = output_scores[:, :cut_off] + + if history is not None: + history.append(output_tokens.clone()) + + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + history=history, + ) + + +class InsertionTransformerDecoder(LevenshteinTransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + # use the TransformerDecoder's __init__ + super(LevenshteinTransformerDecoder, self).__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + self.pool_out = Linear(self.output_embed_dim * 2, self.output_embed_dim) + + self.label_tau = getattr(args, "label_tau", None) + + @ensemble_decoder + def forward_word_ins(self, normalize, encoder_out, prev_output_tokens): + features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0] + features = self.pool_out( + torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) + ) + decoder_out = self.output_layer(features) + return F.log_softmax(decoder_out, -1) if normalize else decoder_out + + def forward_mask_ins(self, *args, **kwargs): + raise NotImplementedError + + def forward_word_del(self, *args, **kwargs): + raise NotImplementedError + + +@register_model_architecture("insertion_transformer", "insertion_transformer") +def insertion_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # special for insertion transformer + args.label_tau = getattr(args, "label_tau", None) diff --git a/fairseq/models/nat/iterative_nonautoregressive_transformer.py b/fairseq/models/nat/iterative_nonautoregressive_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..bc39509980a80eb8c21e0bfdb304649ad3acc4d0 --- /dev/null +++ b/fairseq/models/nat/iterative_nonautoregressive_transformer.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import NATransformerModel + + +def _sequential_poisoning(s, V, beta=0.33, bos=2, eos=3, pad=1): + # s: input batch + # V: vocabulary size + rand_words = torch.randint(low=4, high=V, size=s.size(), device=s.device) + choices = torch.rand(size=s.size(), device=s.device) + choices.masked_fill_((s == pad) | (s == bos) | (s == eos), 1) + + replace = choices < beta / 3 + repeat = (choices >= beta / 3) & (choices < beta * 2 / 3) + swap = (choices >= beta * 2 / 3) & (choices < beta) + safe = choices >= beta + + for i in range(s.size(1) - 1): + rand_word = rand_words[:, i] + next_word = s[:, i + 1] + self_word = s[:, i] + + replace_i = replace[:, i] + swap_i = swap[:, i] & (next_word != 3) + repeat_i = repeat[:, i] & (next_word != 3) + safe_i = safe[:, i] | ((next_word == 3) & (~replace_i)) + + s[:, i] = ( + self_word * (safe_i | repeat_i).long() + + next_word * swap_i.long() + + rand_word * replace_i.long() + ) + s[:, i + 1] = ( + next_word * (safe_i | replace_i).long() + + self_word * (swap_i | repeat_i).long() + ) + return s + + +def gumbel_noise(input, TINY=1e-8): + return ( + input.new_zeros(*input.size()) + .uniform_() + .add_(TINY) + .log_() + .neg_() + .add_(TINY) + .log_() + .neg_() + ) + + +@register_model("iterative_nonautoregressive_transformer") +class IterNATransformerModel(NATransformerModel): + @staticmethod + def add_args(parser): + NATransformerModel.add_args(parser) + parser.add_argument( + "--train-step", + type=int, + help="number of refinement iterations during training", + ) + parser.add_argument( + "--dae-ratio", + type=float, + help="the probability of switching to the denoising auto-encoder loss", + ) + parser.add_argument( + "--stochastic-approx", + action="store_true", + help="sampling from the decoder as the inputs for next iteration", + ) + + @classmethod + def build_model(cls, args, task): + model = super().build_model(args, task) + model.train_step = getattr(args, "train_step", 4) + model.dae_ratio = getattr(args, "dae_ratio", 0.5) + model.stochastic_approx = getattr(args, "stochastic_approx", False) + return model + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + + B, T = prev_output_tokens.size() + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # length prediction + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) + + # decoding + word_ins_outs, word_ins_tgts, word_ins_masks = [], [], [] + for t in range(self.train_step): + word_ins_out = self.decoder( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + step=t, + ) + word_ins_tgt = tgt_tokens + word_ins_mask = word_ins_tgt.ne(self.pad) + + word_ins_outs.append(word_ins_out) + word_ins_tgts.append(word_ins_tgt) + word_ins_masks.append(word_ins_mask) + + if t < (self.train_step - 1): + # prediction for next iteration + if self.stochastic_approx: + word_ins_prediction = ( + word_ins_out + gumbel_noise(word_ins_out) + ).max(-1)[1] + else: + word_ins_prediction = word_ins_out.max(-1)[1] + + prev_output_tokens = prev_output_tokens.masked_scatter( + word_ins_mask, word_ins_prediction[word_ins_mask] + ) + + if self.dae_ratio > 0: + # we do not perform denoising for the first iteration + corrputed = ( + torch.rand(size=(B,), device=prev_output_tokens.device) + < self.dae_ratio + ) + corrputed_tokens = _sequential_poisoning( + tgt_tokens[corrputed], + len(self.tgt_dict), + 0.33, + self.bos, + self.eos, + self.pad, + ) + prev_output_tokens[corrputed] = corrputed_tokens + + # concat everything + word_ins_out = torch.cat(word_ins_outs, 0) + word_ins_tgt = torch.cat(word_ins_tgts, 0) + word_ins_mask = torch.cat(word_ins_masks, 0) + + return { + "word_ins": { + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, + }, + "length": { + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, + } + + +@register_model_architecture( + "iterative_nonautoregressive_transformer", "iterative_nonautoregressive_transformer" +) +def inat_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.ngram_predictor = getattr(args, "ngram_predictor", 1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + + args.train_step = getattr(args, "train_step", 4) + args.dae_ratio = getattr(args, "dae_ratio", 0.5) + args.stochastic_approx = getattr(args, "stochastic_approx", False) + + +@register_model_architecture( + "iterative_nonautoregressive_transformer", + "iterative_nonautoregressive_transformer_wmt_en_de", +) +def iter_nat_wmt_en_de(args): + inat_base_architecture(args) diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d60d3c52d50b1f20957039a75622ffb95d5eea24 --- /dev/null +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -0,0 +1,510 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.iterative_refinement_generator import DecoderOut +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder +from fairseq.models.transformer import Embedding +from fairseq.modules import TransformerDecoderLayer +from fairseq.modules.transformer_sentence_encoder import init_bert_params + +from .levenshtein_utils import ( + _apply_del_words, + _apply_ins_masks, + _apply_ins_words, + _fill, + _get_del_targets, + _get_ins_targets, + _skip, + _skip_encoder_out, +) + + +@register_model("levenshtein_transformer") +class LevenshteinTransformerModel(FairseqNATModel): + @property + def allow_length_beam(self): + return False + + @staticmethod + def add_args(parser): + FairseqNATModel.add_args(parser) + parser.add_argument( + "--early-exit", + default="6,6,6", + type=str, + help="number of decoder layers before word_del, mask_ins, word_ins", + ) + parser.add_argument( + "--no-share-discriminator", + action="store_true", + help="separate parameters for discriminator", + ) + parser.add_argument( + "--no-share-maskpredictor", + action="store_true", + help="separate parameters for mask-predictor", + ) + parser.add_argument( + "--share-discriminator-maskpredictor", + action="store_true", + help="share the parameters for both mask-predictor and discriminator", + ) + parser.add_argument( + "--sampling-for-deletion", + action="store_true", + help="instead of argmax, use sampling to predict the tokens", + ) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = LevenshteinTransformerDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + + assert tgt_tokens is not None, "forward function only supports training." + + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # generate training labels for insertion + masked_tgt_masks, masked_tgt_tokens, mask_ins_targets = _get_ins_targets( + prev_output_tokens, tgt_tokens, self.pad, self.unk + ) + mask_ins_targets = mask_ins_targets.clamp(min=0, max=255) # for safe prediction + mask_ins_masks = prev_output_tokens[:, 1:].ne(self.pad) + + mask_ins_out, _ = self.decoder.forward_mask_ins( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + ) + word_ins_out, _ = self.decoder.forward_word_ins( + normalize=False, + prev_output_tokens=masked_tgt_tokens, + encoder_out=encoder_out, + ) + + # make online prediction + if self.decoder.sampling_for_deletion: + word_predictions = torch.multinomial( + F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1 + ).view(word_ins_out.size(0), -1) + else: + word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1] + + word_predictions.masked_scatter_( + ~masked_tgt_masks, tgt_tokens[~masked_tgt_masks] + ) + + # generate training labels for deletion + word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad) + word_del_out, _ = self.decoder.forward_word_del( + normalize=False, + prev_output_tokens=word_predictions, + encoder_out=encoder_out, + ) + word_del_masks = word_predictions.ne(self.pad) + + return { + "mask_ins": { + "out": mask_ins_out, + "tgt": mask_ins_targets, + "mask": mask_ins_masks, + "ls": 0.01, + }, + "word_ins": { + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": masked_tgt_masks, + "ls": self.args.label_smoothing, + "nll_loss": True, + }, + "word_del": { + "out": word_del_out, + "tgt": word_del_targets, + "mask": word_del_masks, + }, + } + + def forward_decoder( + self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs + ): + + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + attn = decoder_out.attn + history = decoder_out.history + + bsz = output_tokens.size(0) + if max_ratio is None: + max_lens = torch.zeros_like(output_tokens).fill_(255) + else: + if not encoder_out["encoder_padding_mask"]: + max_src_len = encoder_out["encoder_out"].size(0) + src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len) + else: + src_lens = (~encoder_out["encoder_padding_mask"][0]).sum(1) + max_lens = (src_lens * max_ratio).clamp(min=10).long() + + # delete words + # do not delete tokens if it is + can_del_word = output_tokens.ne(self.pad).sum(1) > 2 + if can_del_word.sum() != 0: # we cannot delete, skip + word_del_score, word_del_attn = self.decoder.forward_word_del( + normalize=True, + prev_output_tokens=_skip(output_tokens, can_del_word), + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_del_word), + ) + word_del_pred = word_del_score.max(-1)[1].bool() + + _tokens, _scores, _attn = _apply_del_words( + output_tokens[can_del_word], + output_scores[can_del_word], + word_del_attn, + word_del_pred, + self.pad, + self.bos, + self.eos, + ) + output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad) + output_scores = _fill(output_scores, can_del_word, _scores, 0) + attn = _fill(attn, can_del_word, _attn, 0.0) + + if history is not None: + history.append(output_tokens.clone()) + + # insert placeholders + can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens + if can_ins_mask.sum() != 0: + mask_ins_score, _ = self.decoder.forward_mask_ins( + normalize=True, + prev_output_tokens=_skip(output_tokens, can_ins_mask), + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_mask), + ) + if eos_penalty > 0.0: + mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty + mask_ins_pred = mask_ins_score.max(-1)[1] + mask_ins_pred = torch.min( + mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred) + ) + + _tokens, _scores = _apply_ins_masks( + output_tokens[can_ins_mask], + output_scores[can_ins_mask], + mask_ins_pred, + self.pad, + self.unk, + self.eos, + ) + output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_mask, _scores, 0) + + if history is not None: + history.append(output_tokens.clone()) + + # insert words + can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 + if can_ins_word.sum() != 0: + word_ins_score, word_ins_attn = self.decoder.forward_word_ins( + normalize=True, + prev_output_tokens=_skip(output_tokens, can_ins_word), + encoder_out=_skip_encoder_out(self.encoder, encoder_out, can_ins_word), + ) + word_ins_score, word_ins_pred = word_ins_score.max(-1) + _tokens, _scores = _apply_ins_words( + output_tokens[can_ins_word], + output_scores[can_ins_word], + word_ins_pred, + word_ins_score, + self.unk, + ) + + output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_word, _scores, 0) + attn = _fill(attn, can_ins_word, word_ins_attn, 0.0) + + if history is not None: + history.append(output_tokens.clone()) + + # delete some unnecessary paddings + cut_off = output_tokens.ne(self.pad).sum(1).max() + output_tokens = output_tokens[:, :cut_off] + output_scores = output_scores[:, :cut_off] + attn = None if attn is None else attn[:, :cut_off, :] + + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=attn, + history=history, + ) + + def initialize_output_tokens(self, encoder_out, src_tokens): + initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2) + initial_output_tokens[:, 0] = self.bos + initial_output_tokens[:, 1] = self.eos + + initial_output_scores = initial_output_tokens.new_zeros( + *initial_output_tokens.size() + ).type_as(encoder_out["encoder_out"][0]) + + return DecoderOut( + output_tokens=initial_output_tokens, + output_scores=initial_output_scores, + attn=None, + step=0, + max_step=0, + history=None, + ) + + +class LevenshteinTransformerDecoder(FairseqNATDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + self.sampling_for_deletion = getattr(args, "sampling_for_deletion", False) + self.embed_mask_ins = Embedding(256, self.output_embed_dim * 2, None) + self.embed_word_del = Embedding(2, self.output_embed_dim, None) + + # del_word, ins_mask, ins_word + self.early_exit = [int(i) for i in args.early_exit.split(",")] + assert len(self.early_exit) == 3 + + # copy layers for mask-predict/deletion + self.layers_msk = None + if getattr(args, "no_share_maskpredictor", False): + self.layers_msk = nn.ModuleList( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(self.early_exit[1]) + ] + ) + self.layers_del = None + if getattr(args, "no_share_discriminator", False): + self.layers_del = nn.ModuleList( + [ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(self.early_exit[0]) + ] + ) + + if getattr(args, "share_discriminator_maskpredictor", False): + assert getattr( + args, "no_share_discriminator", False + ), "must set saperate discriminator" + self.layers_msk = self.layers_del + + def extract_features( + self, + prev_output_tokens, + encoder_out=None, + early_exit=None, + layers=None, + **unused + ): + """ + Similar to *forward* but only return features. + Inputs: + prev_output_tokens: Tensor(B, T) + encoder_out: a dictionary of hidden states and masks + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + the LevenshteinTransformer decoder has full-attention to all generated tokens + """ + # embed positions + positions = ( + self.embed_positions(prev_output_tokens) + if self.embed_positions is not None + else None + ) + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + inner_states = [x] + + # decoder layers + decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) + layers = self.layers if layers is None else layers + early_exit = len(layers) if early_exit is None else early_exit + for _, layer in enumerate(layers[:early_exit]): + x, attn, _ = layer( + x, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + self_attn_mask=None, + self_attn_padding_mask=decoder_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": attn, "inner_states": inner_states} + + @ensemble_decoder + def forward_mask_ins(self, normalize, encoder_out, prev_output_tokens, **unused): + features, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[1], + layers=self.layers_msk, + **unused + ) + features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) + decoder_out = F.linear(features_cat, self.embed_mask_ins.weight) + if normalize: + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] + + @ensemble_decoder + def forward_word_ins(self, normalize, encoder_out, prev_output_tokens, **unused): + features, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[2], + layers=self.layers, + **unused + ) + decoder_out = self.output_layer(features) + if normalize: + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] + + @ensemble_decoder + def forward_word_del(self, normalize, encoder_out, prev_output_tokens, **unused): + features, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + early_exit=self.early_exit[0], + layers=self.layers_del, + **unused + ) + decoder_out = F.linear(features, self.embed_word_del.weight) + if normalize: + return F.log_softmax(decoder_out, -1), extra["attn"] + return decoder_out, extra["attn"] + + +@register_model_architecture("levenshtein_transformer", "levenshtein_transformer") +def levenshtein_base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.sampling_for_deletion = getattr(args, "sampling_for_deletion", False) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.early_exit = getattr(args, "early_exit", "6,6,6") + args.no_share_discriminator = getattr(args, "no_share_discriminator", False) + args.no_share_maskpredictor = getattr(args, "no_share_maskpredictor", False) + args.share_discriminator_maskpredictor = getattr( + args, "share_discriminator_maskpredictor", False + ) + args.no_share_last_layer = getattr(args, "no_share_last_layer", False) + + +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_wmt_en_de" +) +def levenshtein_transformer_wmt_en_de(args): + levenshtein_base_architecture(args) + + +# similar parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_vaswani_wmt_en_de_big" +) +def levenshtein_transformer_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + levenshtein_base_architecture(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture( + "levenshtein_transformer", "levenshtein_transformer_wmt_en_de_big" +) +def levenshtein_transformer_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + levenshtein_transformer_vaswani_wmt_en_de_big(args) diff --git a/fairseq/models/nat/levenshtein_utils.py b/fairseq/models/nat/levenshtein_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..375a98c2e11354de085f0a7926f407bd1a6a2ad4 --- /dev/null +++ b/fairseq/models/nat/levenshtein_utils.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.utils import new_arange + + +# -------------- Helper Functions --------------------------------------------------- # + + +def load_libnat(): + try: + from fairseq import libnat_cuda + + return libnat_cuda, True + + except ImportError as e: + print(str(e) + "... fall back to CPU version") + + try: + from fairseq import libnat + + return libnat, False + + except ImportError as e: + import sys + + sys.stderr.write( + "ERROR: missing libnat_cuda. run `python setup.py build_ext --inplace`\n" + ) + raise e + + +def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): + libnat, use_cuda = load_libnat() + + def _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx): + in_masks = in_tokens.ne(padding_idx) + out_masks = out_tokens.ne(padding_idx) + mask_ins_targets, masked_tgt_masks = libnat.generate_insertion_labels( + out_tokens.int(), + libnat.levenshtein_distance( + in_tokens.int(), + out_tokens.int(), + in_masks.sum(1).int(), + out_masks.sum(1).int(), + ), + ) + masked_tgt_masks = masked_tgt_masks.bool() & out_masks + mask_ins_targets = mask_ins_targets.type_as(in_tokens)[ + :, 1 : in_masks.size(1) + ].masked_fill_(~in_masks[:, 1:], 0) + masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) + return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets + + def _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx): + in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) + + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + mask_inputs = [ + [len(c) if c[0] != padding_idx else 0 for c in a[:-1]] for a in full_labels + ] + + # generate labels + masked_tgt_masks = [] + for mask_input in mask_inputs: + mask_label = [] + for beam_size in mask_input[1:-1]: # HACK 1:-1 + mask_label += [0] + [1 for _ in range(beam_size)] + masked_tgt_masks.append( + mask_label + [0 for _ in range(out_seq_len - len(mask_label))] + ) + mask_ins_targets = [ + mask_input[1:-1] + + [0 for _ in range(in_seq_len - 1 - len(mask_input[1:-1]))] + for mask_input in mask_inputs + ] + + # transform to tensor + masked_tgt_masks = torch.tensor( + masked_tgt_masks, device=out_tokens.device + ).bool() + mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) + masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) + return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets + + if use_cuda: + return _get_ins_targets_cuda(in_tokens, out_tokens, padding_idx, unk_idx) + return _get_ins_targets_cpu(in_tokens, out_tokens, padding_idx, unk_idx) + + +def _get_del_targets(in_tokens, out_tokens, padding_idx): + libnat, use_cuda = load_libnat() + + def _get_del_targets_cuda(in_tokens, out_tokens, padding_idx): + in_masks = in_tokens.ne(padding_idx) + out_masks = out_tokens.ne(padding_idx) + + word_del_targets = libnat.generate_deletion_labels( + in_tokens.int(), + libnat.levenshtein_distance( + in_tokens.int(), + out_tokens.int(), + in_masks.sum(1).int(), + out_masks.sum(1).int(), + ), + ) + word_del_targets = word_del_targets.type_as(in_tokens).masked_fill_( + ~in_masks, 0 + ) + return word_del_targets + + def _get_del_targets_cpu(in_tokens, out_tokens, padding_idx): + out_seq_len = out_tokens.size(1) + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] + + full_labels = libnat.suggested_ed2_path( + in_tokens_list, out_tokens_list, padding_idx + ) + word_del_targets = [b[-1] for b in full_labels] + word_del_targets = [ + labels + [0 for _ in range(out_seq_len - len(labels))] + for labels in word_del_targets + ] + + # transform to tensor + word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device) + return word_del_targets + + if use_cuda: + return _get_del_targets_cuda(in_tokens, out_tokens, padding_idx) + return _get_del_targets_cpu(in_tokens, out_tokens, padding_idx) + + +def _apply_ins_masks( + in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx +): + + in_masks = in_tokens.ne(padding_idx) + in_lengths = in_masks.sum(1) + + # HACK: hacky way to shift all the paddings to eos first. + in_tokens.masked_fill_(~in_masks, eos_idx) + mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0) + + out_lengths = in_lengths + mask_ins_pred.sum(1) + out_max_len = out_lengths.max() + out_masks = new_arange(out_lengths, out_max_len)[None, :] < out_lengths[:, None] + + reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1) + out_tokens = ( + in_tokens.new_zeros(in_tokens.size(0), out_max_len) + .fill_(padding_idx) + .masked_fill_(out_masks, unk_idx) + ) + out_tokens[:, 0] = in_tokens[:, 0] + out_tokens.scatter_(1, reordering, in_tokens[:, 1:]) + + out_scores = None + if in_scores is not None: + in_scores.masked_fill_(~in_masks, 0) + out_scores = in_scores.new_zeros(*out_tokens.size()) + out_scores[:, 0] = in_scores[:, 0] + out_scores.scatter_(1, reordering, in_scores[:, 1:]) + + return out_tokens, out_scores + + +def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx): + word_ins_masks = in_tokens.eq(unk_idx) + out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks]) + + if in_scores is not None: + out_scores = in_scores.masked_scatter( + word_ins_masks, word_ins_scores[word_ins_masks] + ) + else: + out_scores = None + + return out_tokens, out_scores + + +def _apply_del_words( + in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx +): + # apply deletion to a tensor + in_masks = in_tokens.ne(padding_idx) + bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx) + + max_len = in_tokens.size(1) + word_del_pred.masked_fill_(~in_masks, 1) + word_del_pred.masked_fill_(bos_eos_masks, 0) + + reordering = new_arange(in_tokens).masked_fill_(word_del_pred, max_len).sort(1)[1] + + out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering) + + out_scores = None + if in_scores is not None: + out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering) + + out_attn = None + if in_attn is not None: + _mask = word_del_pred[:, :, None].expand_as(in_attn) + _reordering = reordering[:, :, None].expand_as(in_attn) + out_attn = in_attn.masked_fill(_mask, 0.0).gather(1, _reordering) + + return out_tokens, out_scores, out_attn + + +def _skip(x, mask): + """ + Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors. + """ + if isinstance(x, int): + return x + + if x is None: + return None + + if isinstance(x, torch.Tensor): + if x.size(0) == mask.size(0): + return x[mask] + elif x.size(1) == mask.size(0): + return x[:, mask] + + if isinstance(x, list): + return [_skip(x_i, mask) for x_i in x] + + if isinstance(x, dict): + return {k: _skip(v, mask) for k, v in x.items()} + + raise NotImplementedError + + +def _skip_encoder_out(encoder, encoder_out, mask): + if not mask.any(): + return encoder_out + else: + return encoder.reorder_encoder_out( + encoder_out, mask.nonzero(as_tuple=False).squeeze() + ) + + +def _fill(x, mask, y, padding_idx): + """ + Filling tensor x with y at masked positions (dim=0). + """ + if x is None: + return y + assert x.dim() == y.dim() and mask.size(0) == x.size(0) + assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) + n_selected = mask.sum() + assert n_selected == y.size(0) + + if n_selected == x.size(0): + return y + + if x.size(1) < y.size(1): + dims = [x.size(0), y.size(1) - x.size(1)] + if x.dim() == 3: + dims.append(x.size(2)) + x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1) + x[mask] = y + elif x.size(1) > y.size(1): + x[mask] = padding_idx + if x.dim() == 2: + x[mask, : y.size(1)] = y + else: + x[mask, : y.size(1), :] = y + else: + x[mask] = y + return x diff --git a/fairseq/models/nat/nat_crf_transformer.py b/fairseq/models/nat/nat_crf_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b3cd931ceb077eb30db73df1d5d6cd714a86c2 --- /dev/null +++ b/fairseq/models/nat/nat_crf_transformer.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import NATransformerModel, base_architecture +from fairseq.modules import DynamicCRF + + +@register_model("nacrf_transformer") +class NACRFTransformerModel(NATransformerModel): + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + self.crf_layer = DynamicCRF( + num_embedding=len(self.tgt_dict), + low_rank=args.crf_lowrank_approx, + beam_size=args.crf_beam_approx, + ) + + @property + def allow_ensemble(self): + return False + + @staticmethod + def add_args(parser): + NATransformerModel.add_args(parser) + parser.add_argument( + "--crf-lowrank-approx", + type=int, + help="the dimension of low-rank approximation of transition", + ) + parser.add_argument( + "--crf-beam-approx", + type=int, + help="the beam size for apporixmating the normalizing factor", + ) + parser.add_argument( + "--word-ins-loss-factor", + type=float, + help="weights on NAT loss used to co-training with CRF loss.", + ) + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # length prediction + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) + + # decoding + word_ins_out = self.decoder( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + ) + word_ins_tgt, word_ins_mask = tgt_tokens, tgt_tokens.ne(self.pad) + + # compute the log-likelihood of CRF + crf_nll = -self.crf_layer(word_ins_out, word_ins_tgt, word_ins_mask) + crf_nll = (crf_nll / word_ins_mask.type_as(crf_nll).sum(-1)).mean() + + return { + "word_ins": { + "out": word_ins_out, + "tgt": word_ins_tgt, + "mask": word_ins_mask, + "ls": self.args.label_smoothing, + "nll_loss": True, + "factor": self.args.word_ins_loss_factor, + }, + "word_crf": {"loss": crf_nll}, + "length": { + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, + } + + def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + history = decoder_out.history + + # execute the decoder and get emission scores + output_masks = output_tokens.ne(self.pad) + word_ins_out = self.decoder( + normalize=False, prev_output_tokens=output_tokens, encoder_out=encoder_out + ) + + # run viterbi decoding through CRF + _scores, _tokens = self.crf_layer.forward_decoder(word_ins_out, output_masks) + output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) + output_scores.masked_scatter_(output_masks, _scores[output_masks]) + if history is not None: + history.append(output_tokens.clone()) + + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + history=history, + ) + + +@register_model_architecture("nacrf_transformer", "nacrf_transformer") +def nacrf_base_architecture(args): + args.crf_lowrank_approx = getattr(args, "crf_lowrank_approx", 32) + args.crf_beam_approx = getattr(args, "crf_beam_approx", 64) + args.word_ins_loss_factor = getattr(args, "word_ins_loss_factor", 0.5) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + base_architecture(args) diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0221f9c437afc547afa1ad019e5e4a6aaaef17 --- /dev/null +++ b/fairseq/models/nat/nonautoregressive_ensembles.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn.functional as F +from fairseq.models.nat import ( + _apply_del_words, + _apply_ins_masks, + _apply_ins_words, + _fill, + _skip, + _skip_encoder_out, +) + + +class _EnsembleModelEncoder(object): + def __init__(self, models): + self.models = models + + def reorder_encoder_out(self, encoder_outs, new_order): + encoder_outs = [ + model.encoder.reorder_encoder_out(encoder_out, new_order) + for model, encoder_out in zip(self.models, encoder_outs) + ] + return encoder_outs + + +class BasicEnsembleModel(torch.nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models = torch.nn.ModuleList(models) + self.bos = self.models[0].decoder.dictionary.bos() + self.eos = self.models[0].decoder.dictionary.eos() + self.pad = self.models[0].decoder.dictionary.pad() + self.unk = self.models[0].decoder.dictionary.unk() + self.encoder = _EnsembleModelEncoder(self.models) + + def has_encoder(self): + return hasattr(self.models[0], "encoder") + + def max_decoder_positions(self): + return min(m.max_decoder_positions() for m in self.models) + + @torch.no_grad() + def forward_encoder(self, encoder_input): + if not self.has_encoder(): + return None + return [model.forward_encoder(encoder_input) for model in self.models] + + @torch.no_grad() + def forward_decoder(self, *inputs): + raise NotImplementedError + + def initialize_output_tokens(self, *inputs): + raise NotImplementedError + + +class EnsembleLevT(BasicEnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + @torch.no_grad() + def forward_decoder( + self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs + ): + # LevT ensembling + # A pipeline of three steps: deletion, placeholder, and word insertion. + # We need to average scores in each step in a pipeline way because of dependence. + # deletion + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + attn = decoder_out.attn + + bsz = output_tokens.size(0) + if max_ratio is None: + max_lens = output_tokens.new().fill_(255) + else: + if not encoder_outs[0]["encoder_padding_mask"]: + src_lens = ( + encoder_outs[0]["encoder_out"][0] + .new(bsz) + .fill_(encoder_outs[0]["encoder_out"][0].size(1)) + ) + else: + src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1) + max_lens = (src_lens * max_ratio).clamp(min=10).long() + + # delete words + # do not delete tokens if it is + can_del_word = output_tokens.ne(self.pad).sum(1) > 2 + if can_del_word.sum() != 0: # we cannot delete, skip + output_tokens, output_scores, attn = self.forward_word_del( + encoder_outs, + output_tokens, + output_scores, + attn, + can_del_word, + ) + + # insert placeholders + can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens + if can_ins_mask.sum() != 0: + output_tokens, output_scores = self.forward_mask_ins( + encoder_outs, + output_tokens, + output_scores, + can_ins_mask, + eos_penalty, + max_lens, + ) + + # insert words + can_ins_word = output_tokens.eq(self.unk).sum(1) > 0 + if can_ins_word.sum() != 0: + output_tokens, output_scores, attn = self.forward_word_ins( + encoder_outs, + output_tokens, + output_scores, + attn, + can_ins_word, + ) + + # delete some unnecessary paddings + cut_off = output_tokens.ne(self.pad).sum(1).max() + output_tokens = output_tokens[:, :cut_off] + output_scores = output_scores[:, :cut_off] + attn = None if attn is None else attn[:, :cut_off, :] + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=attn, + history=None, + ) + + def forward_word_del( + self, encoder_outs, output_tokens, output_scores, attn, can_del_word + ): + word_del_score_avg = [] + word_del_attn_avg = [] + for model, encoder_out in zip(self.models, encoder_outs): + word_del_out, word_del_attn = model.decoder.forward_word_del( + _skip(output_tokens, can_del_word), + _skip_encoder_out(model.encoder, encoder_out, can_del_word), + ) + word_del_score = F.log_softmax(word_del_out, 2) + word_del_score_avg.append(word_del_score) + word_del_attn_avg.append(word_del_attn) + word_del_score_avg = torch.logsumexp( + torch.stack(word_del_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) + word_del_pred = word_del_score_avg.max(-1)[1].bool() + if word_del_attn_avg[0] is not None: + word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0) / len(self.models) + else: + word_del_attn_avg = None + + _tokens, _scores, _attn = _apply_del_words( + output_tokens[can_del_word], + output_scores[can_del_word], + word_del_attn_avg, + word_del_pred, + self.pad, + self.bos, + self.eos, + ) + output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad) + output_scores = _fill(output_scores, can_del_word, _scores, 0) + attn = _fill(attn, can_del_word, _attn, 0.0) + return output_tokens, output_scores, attn + + def forward_mask_ins( + self, + encoder_outs, + output_tokens, + output_scores, + can_ins_mask, + eos_penalty, + max_lens, + ): + mask_ins_score_avg = [] + for model, encoder_out in zip(self.models, encoder_outs): + mask_ins_out, _ = model.decoder.forward_mask_ins( + _skip(output_tokens, can_ins_mask), + _skip_encoder_out(model.encoder, encoder_out, can_ins_mask), + ) + mask_ins_score = F.log_softmax(mask_ins_out, 2) + if eos_penalty > 0.0: + mask_ins_score[:, :, 0] -= eos_penalty + mask_ins_score_avg.append(mask_ins_score) + mask_ins_score_avg = torch.logsumexp( + torch.stack(mask_ins_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) + mask_ins_pred = mask_ins_score_avg.max(-1)[1] + mask_ins_pred = torch.min( + mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred) + ) + _tokens, _scores = _apply_ins_masks( + output_tokens[can_ins_mask], + output_scores[can_ins_mask], + mask_ins_pred, + self.pad, + self.unk, + self.eos, + ) + output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_mask, _scores, 0) + return output_tokens, output_scores + + def forward_word_ins( + self, encoder_outs, output_tokens, output_scores, attn, can_ins_word + ): + word_ins_score_avg = [] + word_ins_attn_avg = [] + for model, encoder_out in zip(self.models, encoder_outs): + word_ins_out, word_ins_attn = model.decoder.forward_word_ins( + _skip(output_tokens, can_ins_word), + _skip_encoder_out(model.encoder, encoder_out, can_ins_word), + ) + word_ins_score = F.log_softmax(word_ins_out, 2) + word_ins_score_avg.append(word_ins_score) + word_ins_attn_avg.append(word_ins_attn) + word_ins_score_avg = torch.logsumexp( + torch.stack(word_ins_score_avg, dim=0), dim=0 + ) - math.log(len(self.models)) + if word_ins_attn_avg[0] is not None: + word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0) / len(self.models) + else: + word_ins_attn_avg = None + word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1) + + _tokens, _scores = _apply_ins_words( + output_tokens[can_ins_word], + output_scores[can_ins_word], + word_ins_pred, + word_ins_score_max, + self.unk, + ) + + output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad) + output_scores = _fill(output_scores, can_ins_word, _scores, 0) + attn = _fill(attn, can_ins_word, word_ins_attn, 0.0) + return output_tokens, output_scores, attn + + def initialize_output_tokens(self, encoder_outs, src_tokens): + # LevT doesn't do length prediction. + return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens) diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d114202d25fbd1dca66c7abebb0b0a8bffbe094d --- /dev/null +++ b/fairseq/models/nat/nonautoregressive_transformer.py @@ -0,0 +1,456 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.iterative_refinement_generator import DecoderOut +from fairseq.models import register_model, register_model_architecture +from fairseq.models.nat import FairseqNATDecoder, FairseqNATModel, ensemble_decoder +from fairseq.models.transformer import Embedding +from fairseq.modules.transformer_sentence_encoder import init_bert_params + + +def _mean_pooling(enc_feats, src_masks): + # enc_feats: T x B x C + # src_masks: B x T or None + if src_masks is None: + enc_feats = enc_feats.mean(0) + else: + src_masks = (~src_masks).transpose(0, 1).type_as(enc_feats) + enc_feats = ( + (enc_feats / src_masks.sum(0)[None, :, None]) * src_masks[:, :, None] + ).sum(0) + return enc_feats + + +def _argmax(x, dim): + return (x == x.max(dim, keepdim=True)[0]).type_as(x) + + +def _uniform_assignment(src_lens, trg_lens): + max_trg_len = trg_lens.max() + steps = (src_lens.float() - 1) / (trg_lens.float() - 1) # step-size + # max_trg_len + index_t = utils.new_arange(trg_lens, max_trg_len).float() + index_t = steps[:, None] * index_t[None, :] # batch_size X max_trg_len + index_t = torch.round(index_t).long().detach() + return index_t + + +@register_model("nonautoregressive_transformer") +class NATransformerModel(FairseqNATModel): + @property + def allow_length_beam(self): + return True + + @staticmethod + def add_args(parser): + FairseqNATModel.add_args(parser) + + # length prediction + parser.add_argument( + "--src-embedding-copy", + action="store_true", + help="copy encoder word embeddings as the initial input of the decoder", + ) + parser.add_argument( + "--pred-length-offset", + action="store_true", + help="predicting the length difference between the target and source sentences", + ) + parser.add_argument( + "--sg-length-pred", + action="store_true", + help="stop the gradients back-propagated from the length predictor", + ) + parser.add_argument( + "--length-loss-factor", + type=float, + help="weights on the length prediction loss", + ) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + decoder = NATransformerDecoder(args, tgt_dict, embed_tokens) + if getattr(args, "apply_bert_init", False): + decoder.apply(init_bert_params) + return decoder + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, tgt_tokens, **kwargs + ): + # encoding + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + # length prediction + length_out = self.decoder.forward_length( + normalize=False, encoder_out=encoder_out + ) + length_tgt = self.decoder.forward_length_prediction( + length_out, encoder_out, tgt_tokens + ) + + # decoding + word_ins_out = self.decoder( + normalize=False, + prev_output_tokens=prev_output_tokens, + encoder_out=encoder_out, + ) + + return { + "word_ins": { + "out": word_ins_out, + "tgt": tgt_tokens, + "mask": tgt_tokens.ne(self.pad), + "ls": self.args.label_smoothing, + "nll_loss": True, + }, + "length": { + "out": length_out, + "tgt": length_tgt, + "factor": self.decoder.length_loss_factor, + }, + } + + def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): + step = decoder_out.step + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores + history = decoder_out.history + + # execute the decoder + output_masks = output_tokens.ne(self.pad) + _scores, _tokens = self.decoder( + normalize=True, + prev_output_tokens=output_tokens, + encoder_out=encoder_out, + step=step, + ).max(-1) + + output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) + output_scores.masked_scatter_(output_masks, _scores[output_masks]) + if history is not None: + history.append(output_tokens.clone()) + + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + history=history, + ) + + def initialize_output_tokens(self, encoder_out, src_tokens): + # length prediction + length_tgt = self.decoder.forward_length_prediction( + self.decoder.forward_length(normalize=True, encoder_out=encoder_out), + encoder_out=encoder_out, + ) + + max_length = length_tgt.clamp_(min=2).max() + idx_length = utils.new_arange(src_tokens, max_length) + + initial_output_tokens = src_tokens.new_zeros( + src_tokens.size(0), max_length + ).fill_(self.pad) + initial_output_tokens.masked_fill_( + idx_length[None, :] < length_tgt[:, None], self.unk + ) + initial_output_tokens[:, 0] = self.bos + initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos) + + initial_output_scores = initial_output_tokens.new_zeros( + *initial_output_tokens.size() + ).type_as(encoder_out["encoder_out"][0]) + + return DecoderOut( + output_tokens=initial_output_tokens, + output_scores=initial_output_scores, + attn=None, + step=0, + max_step=0, + history=None, + ) + + def regenerate_length_beam(self, decoder_out, beam_size): + output_tokens = decoder_out.output_tokens + length_tgt = output_tokens.ne(self.pad).sum(1) + length_tgt = ( + length_tgt[:, None] + + utils.new_arange(length_tgt, 1, beam_size) + - beam_size // 2 + ) + length_tgt = length_tgt.view(-1).clamp_(min=2) + max_length = length_tgt.max() + idx_length = utils.new_arange(length_tgt, max_length) + + initial_output_tokens = output_tokens.new_zeros( + length_tgt.size(0), max_length + ).fill_(self.pad) + initial_output_tokens.masked_fill_( + idx_length[None, :] < length_tgt[:, None], self.unk + ) + initial_output_tokens[:, 0] = self.bos + initial_output_tokens.scatter_(1, length_tgt[:, None] - 1, self.eos) + + initial_output_scores = initial_output_tokens.new_zeros( + *initial_output_tokens.size() + ).type_as(decoder_out.output_scores) + + return decoder_out._replace( + output_tokens=initial_output_tokens, output_scores=initial_output_scores + ) + + +class NATransformerDecoder(FairseqNATDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + self.dictionary = dictionary + self.bos = dictionary.bos() + self.unk = dictionary.unk() + self.eos = dictionary.eos() + + self.encoder_embed_dim = args.encoder_embed_dim + self.sg_length_pred = getattr(args, "sg_length_pred", False) + self.pred_length_offset = getattr(args, "pred_length_offset", False) + self.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + self.src_embedding_copy = getattr(args, "src_embedding_copy", False) + self.embed_length = Embedding(256, self.encoder_embed_dim, None) + + @ensemble_decoder + def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused): + features, _ = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + embedding_copy=(step == 0) & self.src_embedding_copy, + ) + decoder_out = self.output_layer(features) + return F.log_softmax(decoder_out, -1) if normalize else decoder_out + + @ensemble_decoder + def forward_length(self, normalize, encoder_out): + enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None + enc_feats = _mean_pooling(enc_feats, src_masks) + if self.sg_length_pred: + enc_feats = enc_feats.detach() + length_out = F.linear(enc_feats, self.embed_length.weight) + return F.log_softmax(length_out, -1) if normalize else length_out + + def extract_features( + self, + prev_output_tokens, + encoder_out=None, + early_exit=None, + embedding_copy=False, + **unused + ): + """ + Similar to *forward* but only return features. + + Inputs: + prev_output_tokens: Tensor(B, T) + encoder_out: a dictionary of hidden states and masks + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + the LevenshteinTransformer decoder has full-attention to all generated tokens + """ + # embedding + if embedding_copy: + src_embd = encoder_out["encoder_embedding"][0] + if len(encoder_out["encoder_padding_mask"]) > 0: + src_mask = encoder_out["encoder_padding_mask"][0] + else: + src_mask = None + src_mask = ( + ~src_mask + if src_mask is not None + else prev_output_tokens.new_ones(*src_embd.size()[:2]).bool() + ) + + x, decoder_padding_mask = self.forward_embedding( + prev_output_tokens, + self.forward_copying_source( + src_embd, src_mask, prev_output_tokens.ne(self.padding_idx) + ), + ) + + else: + + x, decoder_padding_mask = self.forward_embedding(prev_output_tokens) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + inner_states = [x] + + # decoder layers + for i, layer in enumerate(self.layers): + + # early exit from the decoder. + if (early_exit is not None) and (i >= early_exit): + break + + x, attn, _ = layer( + x, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + self_attn_mask=None, + self_attn_padding_mask=decoder_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": attn, "inner_states": inner_states} + + def forward_embedding(self, prev_output_tokens, states=None): + # embed positions + positions = ( + self.embed_positions(prev_output_tokens) + if self.embed_positions is not None + else None + ) + + # embed tokens and positions + if states is None: + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + if self.project_in_dim is not None: + x = self.project_in_dim(x) + else: + x = states + + if positions is not None: + x += positions + x = self.dropout_module(x) + decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) + return x, decoder_padding_mask + + def forward_copying_source(self, src_embeds, src_masks, tgt_masks): + length_sources = src_masks.sum(1) + length_targets = tgt_masks.sum(1) + mapped_inputs = _uniform_assignment(length_sources, length_targets).masked_fill( + ~tgt_masks, 0 + ) + copied_embedding = torch.gather( + src_embeds, + 1, + mapped_inputs.unsqueeze(-1).expand( + *mapped_inputs.size(), src_embeds.size(-1) + ), + ) + return copied_embedding + + def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None): + enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None + if self.pred_length_offset: + if src_masks is None: + src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_( + enc_feats.size(0) + ) + else: + src_lengs = (~src_masks).transpose(0, 1).type_as(enc_feats).sum(0) + src_lengs = src_lengs.long() + + if tgt_tokens is not None: + # obtain the length target + tgt_lengs = tgt_tokens.ne(self.padding_idx).sum(1).long() + if self.pred_length_offset: + length_tgt = tgt_lengs - src_lengs + 128 + else: + length_tgt = tgt_lengs + length_tgt = length_tgt.clamp(min=0, max=255) + + else: + # predict the length target (greedy for now) + # TODO: implementing length-beam + pred_lengs = length_out.max(-1)[1] + if self.pred_length_offset: + length_tgt = pred_lengs - 128 + src_lengs + else: + length_tgt = pred_lengs + + return length_tgt + + +@register_model_architecture( + "nonautoregressive_transformer", "nonautoregressive_transformer" +) +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.apply_bert_init = getattr(args, "apply_bert_init", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # --- special arguments --- + args.sg_length_pred = getattr(args, "sg_length_pred", False) + args.pred_length_offset = getattr(args, "pred_length_offset", False) + args.length_loss_factor = getattr(args, "length_loss_factor", 0.1) + args.src_embedding_copy = getattr(args, "src_embedding_copy", False) + + +@register_model_architecture( + "nonautoregressive_transformer", "nonautoregressive_transformer_wmt_en_de" +) +def nonautoregressive_transformer_wmt_en_de(args): + base_architecture(args) diff --git a/fairseq/models/roberta/__init__.py b/fairseq/models/roberta/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd723ae96aec8e3182773483f123109d23b620e --- /dev/null +++ b/fairseq/models/roberta/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hub_interface import * # noqa +from .model import * # noqa +from .enc_dec import * # noqa +from .model_camembert import * # noqa +from .model_gottbert import * # noqa +from .model_xlmr import * # noqa diff --git a/fairseq/models/roberta/__pycache__/__init__.cpython-310.pyc b/fairseq/models/roberta/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2eae8abdae446f65ee91d4f5427b3d1aa915541 Binary files /dev/null and b/fairseq/models/roberta/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/enc_dec.cpython-310.pyc b/fairseq/models/roberta/__pycache__/enc_dec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cbb7b1d90d3d5dfe1611c21f73b9e77d5721fa5 Binary files /dev/null and b/fairseq/models/roberta/__pycache__/enc_dec.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1bd49ae05943b2ea3c5927235ce0e92e1d5475e Binary files /dev/null and b/fairseq/models/roberta/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/model.cpython-310.pyc b/fairseq/models/roberta/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5541b3a239be889fa67040bc1cf02405547521e Binary files /dev/null and b/fairseq/models/roberta/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc b/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13745776d1521ba2e5805d36831484b664e12eec Binary files /dev/null and b/fairseq/models/roberta/__pycache__/model_camembert.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/model_gottbert.cpython-310.pyc b/fairseq/models/roberta/__pycache__/model_gottbert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b04bf1f7241d4c631c2d61520a930fb4a77a36fa Binary files /dev/null and b/fairseq/models/roberta/__pycache__/model_gottbert.cpython-310.pyc differ diff --git a/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc b/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8394ca5df0c0b09be39dc5129f1ee27edc7bcf39 Binary files /dev/null and b/fairseq/models/roberta/__pycache__/model_xlmr.cpython-310.pyc differ diff --git a/fairseq/models/roberta/alignment_utils.py b/fairseq/models/roberta/alignment_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc7f74cb94d5b8baa2d4e9dfd44f653d47ee43e --- /dev/null +++ b/fairseq/models/roberta/alignment_utils.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Counter +from typing import List + +import torch + + +def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List[str]): + """ + Helper to align GPT-2 BPE to other tokenization formats (e.g., spaCy). + + Args: + roberta (RobertaHubInterface): RoBERTa instance + bpe_tokens (torch.LongTensor): GPT-2 BPE tokens of shape `(T_bpe)` + other_tokens (List[str]): other tokens of shape `(T_words)` + + Returns: + List[str]: mapping from *other_tokens* to corresponding *bpe_tokens*. + """ + assert bpe_tokens.dim() == 1 + assert bpe_tokens[0] == 0 + + def clean(text): + return text.strip() + + # remove whitespaces to simplify alignment + bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens] + bpe_tokens = [ + clean(roberta.bpe.decode(x) if x not in {"", ""} else x) for x in bpe_tokens + ] + other_tokens = [clean(str(o)) for o in other_tokens] + + # strip leading + bpe_tokens = bpe_tokens[1:] + assert "".join(bpe_tokens) == "".join(other_tokens) + + # create alignment from every word to a list of BPE tokens + alignment = [] + bpe_toks = filter(lambda item: item[1] != "", enumerate(bpe_tokens, start=1)) + j, bpe_tok = next(bpe_toks) + for other_tok in other_tokens: + bpe_indices = [] + while True: + if other_tok.startswith(bpe_tok): + bpe_indices.append(j) + other_tok = other_tok[len(bpe_tok) :] + try: + j, bpe_tok = next(bpe_toks) + except StopIteration: + j, bpe_tok = None, None + elif bpe_tok.startswith(other_tok): + # other_tok spans multiple BPE tokens + bpe_indices.append(j) + bpe_tok = bpe_tok[len(other_tok) :] + other_tok = "" + else: + raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok)) + if other_tok == "": + break + assert len(bpe_indices) > 0 + alignment.append(bpe_indices) + assert len(alignment) == len(other_tokens) + + return alignment + + +def align_features_to_words(roberta, features, alignment): + """ + Align given features to words. + + Args: + roberta (RobertaHubInterface): RoBERTa instance + features (torch.Tensor): features to align of shape `(T_bpe x C)` + alignment: alignment between BPE tokens and words returned by + func:`align_bpe_to_words`. + """ + assert features.dim() == 2 + + bpe_counts = Counter(j for bpe_indices in alignment for j in bpe_indices) + assert bpe_counts[0] == 0 # shouldn't be aligned + denom = features.new([bpe_counts.get(j, 1) for j in range(len(features))]) + weighted_features = features / denom.unsqueeze(-1) + + output = [weighted_features[0]] + largest_j = -1 + for bpe_indices in alignment: + output.append(weighted_features[bpe_indices].sum(dim=0)) + largest_j = max(largest_j, *bpe_indices) + for j in range(largest_j + 1, len(features)): + output.append(weighted_features[j]) + output = torch.stack(output) + assert torch.all(torch.abs(output.sum(dim=0) - features.sum(dim=0)) < 1e-4) + return output + + +def spacy_nlp(): + if getattr(spacy_nlp, "_nlp", None) is None: + try: + from spacy.lang.en import English + + spacy_nlp._nlp = English() + except ImportError: + raise ImportError("Please install spacy with: pip install spacy") + return spacy_nlp._nlp + + +def spacy_tokenizer(): + if getattr(spacy_tokenizer, "_tokenizer", None) is None: + try: + nlp = spacy_nlp() + spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp) + except ImportError: + raise ImportError("Please install spacy with: pip install spacy") + return spacy_tokenizer._tokenizer diff --git a/fairseq/models/roberta/enc_dec.py b/fairseq/models/roberta/enc_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..e538dee0aa5984b1a3d02ce81117d2046c030593 --- /dev/null +++ b/fairseq/models/roberta/enc_dec.py @@ -0,0 +1,192 @@ +import argparse +import logging + +import torch.nn as nn +import fairseq.checkpoint_utils +from fairseq.models import ( + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import TransformerDecoder +from fairseq.models.roberta import model as roberta + +logger = logging.getLogger(__name__) + + +@register_model("roberta_enc_dec") +class RobertaEncDecModel(FairseqEncoderDecoderModel): + @staticmethod + def add_args(parser): + parser.add_argument( + "--pretrained-mlm-checkpoint", + default=None, + type=str, + metavar="PRETRAINED", + help="path to pretrained mlm checkpoint", + ) + parser.add_argument( + "--pretrained-decoder", action="store_true", help="reload decoder" + ) + parser.add_argument( + "--hack-layernorm-embedding", + action="store_true", + help="hack to reload old models trained with encoder-normalize-before=False (no equivalent to encoder-normalize-before=False and layernorm_embedding=False", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--share-all-embeddings", + action="store_true", + help="share encoder, decoder and output embeddings" + " (requires shared dictionary and embed dim)", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present + base_enc_dec_architecture(args) + if args.pretrained_mlm_checkpoint: + arg_overrides = None + if args.hack_layernorm_embedding: + arg_overrides = {"layernorm_embedding": False} + loaded = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [args.pretrained_mlm_checkpoint], arg_overrides=arg_overrides + ) + ([roberta_enc], _cfg, _task) = loaded + else: + # Do we need to edit untie_weights here ? + share_in_out = ( + args.share_decoder_input_output_embed or args.share_all_embeddings + ) + args.untie_weights_roberta = not share_in_out + if args.hack_layernorm_embedding: + args.layernorm_embedding = False + args.encoder_normalize_before = False + roberta_enc = roberta.RobertaModel.build_model(args, task) + + return cls.from_roberta(roberta_enc, args, task.source_dictionary) + + @staticmethod + def from_roberta(roberta_enc: roberta.RobertaModel, args, dictionary): + encoder = roberta_enc.encoder.sentence_encoder + vocab_size, embed_dim = encoder.embed_tokens.weight.shape + + if args.share_all_embeddings: + lm_head = roberta_enc.encoder.lm_head + assert encoder.embed_tokens.weight is lm_head.weight, ( + "Can't use --share-all-embeddings with a model " + "that was pretraiend with --untie-weights-roberta_enc" + ) + else: + lm_head = roberta.RobertaLMHead( + embed_dim, vocab_size, roberta_enc.args.activation_fn + ) + + dec_embs = nn.Embedding(vocab_size, embed_dim, dictionary.pad()) + if args.share_all_embeddings or args.share_decoder_input_output_embed: + # Note: I wasn't able to use Embedding _weight parameter to achive this sharing. + dec_embs.weight = lm_head.weight + + decoder = TransformerDecoder( + RobertaEncDecModel.read_args_from_roberta(roberta_enc.args), + dictionary, + dec_embs, + no_encoder_attn=False, + output_projection=lm_head, + ) + if getattr(args, "pretrained_decoder", False): + decoder_dict = encoder.state_dict() + + # TODO: hide setting "encoder_attn" layers behind a flag. + for k, w in list(decoder_dict.items()): + if ".self_attn" in k: + k_enc_attn = k.replace(".self_attn", ".encoder_attn") + decoder_dict[k_enc_attn] = w.detach().clone() + + for k, w in lm_head.state_dict().items(): + decoder_dict["output_projection." + k] = w + + missing_keys, unexpected_keys = decoder.load_state_dict( + decoder_dict, strict=False + ) + # missing_keys = [m for m in missing_keys if ".encoder_attn" not in m] + assert not missing_keys and not unexpected_keys, ( + "Failed to load state dict. " + f"Missing keys: {missing_keys}. " + f"Unexpected keys: {unexpected_keys}." + ) + + if args.share_all_embeddings: + assert decoder.output_projection.weight is decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is decoder.embed_tokens.weight + elif args.share_decoder_input_output_embed: + assert decoder.output_projection.weight is decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight + else: + assert decoder.output_projection.weight is not decoder.embed_tokens.weight + assert encoder.embed_tokens.weight is not decoder.embed_tokens.weight + + return RobertaEncDecModel(encoder, decoder) + + @staticmethod + def read_args_from_roberta(roberta_args: argparse.Namespace): + # TODO: this would become easier if encoder/decoder where using a similar + # TransformerConfig object + args = argparse.Namespace(**vars(roberta_args)) + attr_map = [ + ("encoder_attention_heads", "decoder_attention_heads"), + ("encoder_embed_dim", "decoder_embed_dim"), + ("encoder_embed_dim", "decoder_output_dim"), + ("encoder_normalize_before", "decoder_normalize_before"), + ("encoder_layers_to_keep", "decoder_layers_to_keep"), + ("encoder_ffn_embed_dim", "decoder_ffn_embed_dim"), + ("encoder_layerdrop", "decoder_layerdrop"), + ("encoder_layers", "decoder_layers"), + ("encoder_learned_pos", "decoder_learned_pos"), + # should this be set from here ? + ("max_positions", "max_target_positions"), + ] + for k1, k2 in attr_map: + setattr(args, k2, getattr(roberta_args, k1)) + + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = not roberta_args.untie_weights_roberta + return args + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + super().upgrade_state_dict_named(state_dict, name) + old_keys = list(state_dict.keys()) + + # rename decoder -> encoder before upgrading children modules + for k in old_keys: + if k.startswith(prefix + "encoder.lm_head"): + state_dict.pop(k) + continue + new_k = k + new_k = new_k.replace(".sentence_encoder.", ".") + new_k = new_k.replace("decoder.lm_head.", "decoder.output_projection.") + if k == new_k: + continue + # print(k, "->", new_k) + state_dict[new_k] = state_dict.pop(k) + + +@register_model_architecture("roberta_enc_dec", "roberta_enc_dec") +def base_enc_dec_architecture(args): + args.hack_layernorm_embedding = getattr(args, "hack_layernorm_embedding", False) + args.pretrained_mlm_checkpoint = getattr(args, "pretrained_mlm_checkpoint", None) + args.pretrained_decoder = getattr(args, "pretrained_decoder", None) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + + roberta.base_architecture(args) diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..ba298d63ba5da2a5b2f1a44d0384a6b249277ef4 --- /dev/null +++ b/fairseq/models/roberta/hub_interface.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.data import encoders + + +class RobertaHubInterface(nn.Module): + """A simple PyTorch Hub interface to RoBERTa. + + Usage: https://github.com/pytorch/fairseq/tree/main/examples/roberta + """ + + def __init__(self, cfg, task, model): + super().__init__() + self.cfg = cfg + self.task = task + self.model = model + + self.bpe = encoders.build_bpe(cfg.bpe) + + # this is useful for determining the device + self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) + + @property + def device(self): + return self._float_tensor.device + + def encode( + self, sentence: str, *addl_sentences, no_separator=False + ) -> torch.LongTensor: + """ + BPE-encode a sentence (or multiple sentences). + + Every sequence begins with a beginning-of-sentence (``) symbol. + Every sentence ends with an end-of-sentence (``) and we use an + extra end-of-sentence (``) as a separator. + + Example (single sentence): ` a b c ` + Example (sentence pair): ` d e f 1 2 3 ` + + The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE + requires leading spaces. For example:: + + >>> roberta.encode('Hello world').tolist() + [0, 31414, 232, 2] + >>> roberta.encode(' world').tolist() + [0, 232, 2] + >>> roberta.encode('world').tolist() + [0, 8331, 2] + """ + bpe_sentence = " " + self.bpe.encode(sentence) + " " + for s in addl_sentences: + bpe_sentence += " " if not no_separator else "" + bpe_sentence += " " + self.bpe.encode(s) + " " + tokens = self.task.source_dictionary.encode_line( + bpe_sentence, append_eos=False, add_if_not_exist=False + ) + return tokens.long() + + def decode(self, tokens: torch.LongTensor): + assert tokens.dim() == 1 + tokens = tokens.numpy() + if tokens[0] == self.task.source_dictionary.bos(): + tokens = tokens[1:] # remove + eos_mask = tokens == self.task.source_dictionary.eos() + doc_mask = eos_mask[1:] & eos_mask[:-1] + sentences = np.split(tokens, doc_mask.nonzero()[0] + 1) + sentences = [ + self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences + ] + if len(sentences) == 1: + return sentences[0] + return sentences + + def extract_features( + self, tokens: torch.LongTensor, return_all_hiddens: bool = False + ) -> torch.Tensor: + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + if tokens.size(-1) > self.model.max_positions(): + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) + features, extra = self.model( + tokens.to(device=self.device), + features_only=True, + return_all_hiddens=return_all_hiddens, + ) + if return_all_hiddens: + # convert from T x B x C -> B x T x C + inner_states = extra["inner_states"] + return [inner_state.transpose(0, 1) for inner_state in inner_states] + else: + return features # just the last layer's features + + def register_classification_head( + self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs + ): + self.model.register_classification_head( + name, num_classes=num_classes, embedding_size=embedding_size, **kwargs + ) + + def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False): + features = self.extract_features(tokens.to(device=self.device)) + logits = self.model.classification_heads[head](features) + if return_logits: + return logits + return F.log_softmax(logits, dim=-1) + + def extract_features_aligned_to_words( + self, sentence: str, return_all_hiddens: bool = False + ) -> torch.Tensor: + """Extract RoBERTa features, aligned to spaCy's word-level tokenizer.""" + from fairseq.models.roberta import alignment_utils + from spacy.tokens import Doc + + nlp = alignment_utils.spacy_nlp() + tokenizer = alignment_utils.spacy_tokenizer() + + # tokenize both with GPT-2 BPE and spaCy + bpe_toks = self.encode(sentence) + spacy_toks = tokenizer(sentence) + spacy_toks_ws = [t.text_with_ws for t in tokenizer(sentence)] + alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws) + + # extract features and align them + features = self.extract_features( + bpe_toks, return_all_hiddens=return_all_hiddens + ) + features = features.squeeze(0) + aligned_feats = alignment_utils.align_features_to_words( + self, features, alignment + ) + + # wrap in spaCy Doc + doc = Doc( + nlp.vocab, + words=[""] + [x.text for x in spacy_toks] + [""], + spaces=[True] + + [x.endswith(" ") for x in spacy_toks_ws[:-1]] + + [True, False], + ) + assert len(doc) == aligned_feats.size(0) + doc.user_token_hooks["vector"] = lambda token: aligned_feats[token.i] + return doc + + def fill_mask(self, masked_input: str, topk: int = 5): + masked_token = "" + assert ( + masked_token in masked_input and masked_input.count(masked_token) == 1 + ), "Please add one {0} token for the input, eg: 'He is a {0} guy'".format( + masked_token + ) + + text_spans = masked_input.split(masked_token) + text_spans_bpe = ( + (" {0} ".format(masked_token)) + .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans]) + .strip() + ) + tokens = self.task.source_dictionary.encode_line( + " " + text_spans_bpe + " ", + append_eos=False, + add_if_not_exist=False, + ) + + masked_index = (tokens == self.task.mask_idx).nonzero(as_tuple=False) + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + + with utils.model_eval(self.model): + features, extra = self.model( + tokens.long().to(device=self.device), + features_only=False, + return_all_hiddens=False, + ) + logits = features[0, masked_index, :].squeeze() + prob = logits.softmax(dim=0) + values, index = prob.topk(k=topk, dim=0) + topk_predicted_token_bpe = self.task.source_dictionary.string(index) + + topk_filled_outputs = [] + for index, predicted_token_bpe in enumerate( + topk_predicted_token_bpe.split(" ") + ): + predicted_token = self.bpe.decode(predicted_token_bpe) + # Quick hack to fix https://github.com/pytorch/fairseq/issues/1306 + if predicted_token_bpe.startswith("\u2581"): + predicted_token = " " + predicted_token + if " {0}".format(masked_token) in masked_input: + topk_filled_outputs.append( + ( + masked_input.replace( + " {0}".format(masked_token), predicted_token + ), + values[index].item(), + predicted_token, + ) + ) + else: + topk_filled_outputs.append( + ( + masked_input.replace(masked_token, predicted_token), + values[index].item(), + predicted_token, + ) + ) + return topk_filled_outputs + + def disambiguate_pronoun(self, sentence: str) -> bool: + """ + Usage:: + + >>> disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.') + True + + >>> disambiguate_pronoun('The trophy would not fit in the brown suitcase because [it] was too big.') + 'The trophy' + """ + assert hasattr( + self.task, "disambiguate_pronoun" + ), "roberta.disambiguate_pronoun() requires a model trained with the WSC task." + with utils.model_eval(self.model): + return self.task.disambiguate_pronoun( + self.model, sentence, use_cuda=self.device.type == "cuda" + ) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ced9190ca60a922e289958b9f636875f41b125 --- /dev/null +++ b/fairseq/models/roberta/model.py @@ -0,0 +1,700 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +RoBERTa: A Robustly Optimized BERT Pretraining Approach. +""" + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder +from fairseq.modules import LayerNorm +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import safe_getattr, safe_hasattr + +from .hub_interface import RobertaHubInterface + +logger = logging.getLogger(__name__) + + +@register_model("roberta") +class RobertaModel(FairseqEncoderModel): + @classmethod + def hub_models(cls): + return { + "roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz", + "roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz", + "roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz", + "roberta.large.wsc": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz", + } + + def __init__(self, args, encoder): + super().__init__(encoder) + self.args = args + + # We follow BERT's random weight initialization + self.apply(init_bert_params) + + self.classification_heads = nn.ModuleDict() + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--encoder-layers", type=int, metavar="L", help="num encoder layers" + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use for pooler layer", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + parser.add_argument( + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--max-positions", type=int, help="number of positional embeddings to learn" + ) + parser.add_argument( + "--load-checkpoint-heads", + action="store_true", + help="(re-)register and load heads when loading checkpoints", + ) + parser.add_argument( + "--untie-weights-roberta", + action="store_true", + help="Untie weights between embeddings and classifiers in RoBERTa", + ) + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + parser.add_argument( + "--encoder-layerdrop", + type=float, + metavar="D", + default=0, + help="LayerDrop probability for encoder", + ) + parser.add_argument( + "--encoder-layers-to-keep", + default=None, + help="which layers to *keep* when pruning as a comma-separated list", + ) + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + parser.add_argument( + "--quant-noise-pq", + type=float, + metavar="D", + default=0, + help="iterative PQ quantization noise at training time", + ) + parser.add_argument( + "--quant-noise-pq-block-size", + type=int, + metavar="D", + default=8, + help="block size of quantization noise at training time", + ) + parser.add_argument( + "--quant-noise-scalar", + type=float, + metavar="D", + default=0, + help="scalar quantization noise and scalar quantization at training time", + ) + # args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 2020) + parser.add_argument( + "--spectral-norm-classification-head", + action="store_true", + default=False, + help="Apply spectral normalization on the classification head", + ) + # args for Fully Sharded Data Parallel (FSDP) training + parser.add_argument( + "--min-params-to-wrap", + type=int, + metavar="D", + default=DEFAULT_MIN_PARAMS_TO_WRAP, + help=( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + ), + ) + # args for AdaPruning + # In short, it adds regularizarion for the multihead attention module and feed forward neural nets + # For more details, please refer to the paper https://openreview.net/forum?id=_CMSV7FTzGI + parser.add_argument( + "--mha-reg-scale-factor", + type=float, + metavar="D", + default=0.0, + help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375", + ) + parser.add_argument( + "--ffn-reg-scale-factor", + type=float, + metavar="D", + default=0.0, + help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375", + ) + parser.add_argument( + "--mha-heads-to-keep", + type=int, + metavar="D", + default=-1, + help="number of heads to keep in each multi-head attention module, -1 means keeping all heads", + ) + parser.add_argument( + "--ffn-blocks-to-remove", + type=int, + metavar="D", + default=-1, + help="number of feedforward blocks to remove in each transformer layer, -1 means keeping all ffn blocks", + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + from omegaconf import OmegaConf + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, False) + + # make sure all arguments are present + base_architecture(args) + + if not safe_hasattr(args, "max_positions"): + if not safe_hasattr(args, "tokens_per_sample"): + args.tokens_per_sample = task.max_positions() + args.max_positions = args.tokens_per_sample + + encoder = RobertaEncoder(args, task.source_dictionary) + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, True) + + return cls(args, encoder) + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + **kwargs, + ): + if classification_head_name is not None: + features_only = True + + x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs) + + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + return x, extra + + def _get_adaptive_head_loss(self): + norm_loss = 0 + scaling = float(self.args.mha_reg_scale_factor) + for layer in self.encoder.sentence_encoder.layers: + norm_loss_layer = 0 + for i in range(layer.self_attn.num_heads): + start_idx = i * layer.self_attn.head_dim + end_idx = (i + 1) * layer.self_attn.head_dim + norm_loss_layer += scaling * ( + torch.sum( + torch.abs( + layer.self_attn.q_proj.weight[ + start_idx:end_idx, + ] + ) + ) + + torch.sum( + torch.abs(layer.self_attn.q_proj.bias[start_idx:end_idx]) + ) + ) + norm_loss_layer += scaling * ( + torch.sum( + torch.abs( + layer.self_attn.k_proj.weight[ + start_idx:end_idx, + ] + ) + ) + + torch.sum( + torch.abs(layer.self_attn.k_proj.bias[start_idx:end_idx]) + ) + ) + norm_loss_layer += scaling * ( + torch.sum( + torch.abs( + layer.self_attn.v_proj.weight[ + start_idx:end_idx, + ] + ) + ) + + torch.sum( + torch.abs(layer.self_attn.v_proj.bias[start_idx:end_idx]) + ) + ) + + norm_loss += norm_loss_layer + return norm_loss + + def _get_adaptive_ffn_loss(self): + ffn_scale_factor = float(self.args.ffn_reg_scale_factor) + filter_loss = 0 + for layer in self.encoder.sentence_encoder.layers: + filter_loss += torch.sum( + torch.abs(layer.fc1.weight * ffn_scale_factor) + ) + torch.sum(torch.abs(layer.fc2.weight * ffn_scale_factor)) + filter_loss += torch.sum( + torch.abs(layer.fc1.bias * ffn_scale_factor) + ) + torch.sum(torch.abs(layer.fc2.bias * ffn_scale_factor)) + return filter_loss + + def get_normalized_probs(self, net_output, log_probs, sample=None): + """Get normalized probabilities (or log probs) from a net's output.""" + logits = net_output[0].float() + if log_probs: + return F.log_softmax(logits, dim=-1) + else: + return F.softmax(logits, dim=-1) + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = RobertaClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + q_noise=self.args.quant_noise_pq, + qn_block_size=self.args.quant_noise_pq_block_size, + do_spectral_norm=self.args.spectral_norm_classification_head, + ) + + @property + def supported_targets(self): + return {"self"} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="gpt2", + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + + logger.info(x["args"]) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + + # rename decoder -> encoder before upgrading children modules + for k in list(state_dict.keys()): + if k.startswith(prefix + "decoder"): + new_k = prefix + "encoder" + k[len(prefix + "decoder") :] + state_dict[new_k] = state_dict[k] + del state_dict[k] + + # rename emb_layer_norm -> layernorm_embedding + for k in list(state_dict.keys()): + if ".emb_layer_norm." in k: + new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.") + state_dict[new_k] = state_dict[k] + del state_dict[k] + + # upgrade children modules + super().upgrade_state_dict_named(state_dict, name) + + # Handle new classification heads present in the state dict. + current_head_names = ( + [] + if not hasattr(self, "classification_heads") + else self.classification_heads.keys() + ) + keys_to_delete = [] + for k in state_dict.keys(): + if not k.startswith(prefix + "classification_heads."): + continue + + head_name = k[len(prefix + "classification_heads.") :].split(".")[0] + num_classes = state_dict[ + prefix + "classification_heads." + head_name + ".out_proj.weight" + ].size(0) + inner_dim = state_dict[ + prefix + "classification_heads." + head_name + ".dense.weight" + ].size(0) + + if getattr(self.args, "load_checkpoint_heads", False): + if head_name not in current_head_names: + self.register_classification_head(head_name, num_classes, inner_dim) + else: + if head_name not in current_head_names: + logger.warning( + "deleting classification head ({}) from checkpoint " + "not present in current model: {}".format(head_name, k) + ) + keys_to_delete.append(k) + elif ( + num_classes + != self.classification_heads[head_name].out_proj.out_features + or inner_dim + != self.classification_heads[head_name].dense.out_features + ): + logger.warning( + "deleting classification head ({}) from checkpoint " + "with different dimensions than current model: {}".format( + head_name, k + ) + ) + keys_to_delete.append(k) + for k in keys_to_delete: + del state_dict[k] + + # Copy any newly-added classification heads into the state dict + # with their current weights. + if hasattr(self, "classification_heads"): + cur_state = self.classification_heads.state_dict() + for k, v in cur_state.items(): + if prefix + "classification_heads." + k not in state_dict: + logger.info("Overwriting " + prefix + "classification_heads." + k) + state_dict[prefix + "classification_heads." + k] = v + + # adapt data2vec models + if ( + "encoder._ema" in state_dict + and "encoder.lm_head.weight" not in state_dict + ): + lm_state = self.encoder.lm_head.state_dict() + for k, v in lm_state.items(): + state_dict["encoder.lm_head." + k] = v + + for k in list(state_dict.keys()): + if k.startswith("encoder.regression_head") or k == "encoder._ema": + del state_dict[k] + + +class RobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the masked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + self.bias + return x + + +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + q_noise=0, + qn_block_size=8, + do_spectral_norm=False, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = apply_quant_noise_( + nn.Linear(inner_dim, num_classes), q_noise, qn_block_size + ) + if do_spectral_norm: + if q_noise != 0: + raise NotImplementedError( + "Attempting to use Spectral Normalization with Quant Noise. This is not officially supported" + ) + self.out_proj = torch.nn.utils.spectral_norm(self.out_proj) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class RobertaEncoder(FairseqEncoder): + """RoBERTa encoder.""" + + def __init__(self, args, dictionary): + super().__init__(dictionary) + + # set any missing default values + base_architecture(args) + self.args = args + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + + embed_tokens = self.build_embedding( + len(dictionary), args.encoder_embed_dim, dictionary.pad() + ) + + self.sentence_encoder = self.build_encoder(args, dictionary, embed_tokens) + + self.lm_head = self.build_lm_head( + embed_dim=args.encoder_embed_dim, + output_dim=len(dictionary), + activation_fn=args.activation_fn, + weight=( + self.sentence_encoder.embed_tokens.weight + if not args.untie_weights_roberta + else None + ), + ) + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_encoder(self, args, dictionary, embed_tokens): + encoder = TransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder + + def build_lm_head(self, embed_dim, output_dim, activation_fn, weight): + return RobertaLMHead(embed_dim, output_dim, activation_fn, weight) + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + masked_tokens=None, + **unused, + ): + """ + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + features_only (bool, optional): skip LM head and just return + features. If True, the output will be of shape + `(batch, src_len, embed_dim)`. + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + + Returns: + tuple: + - the LM output of shape `(batch, src_len, vocab)` + - a dictionary of additional data, where 'inner_states' + is a list of hidden states. Note that the hidden + states have shape `(src_len, batch, vocab)`. + """ + x, extra = self.extract_features( + src_tokens, return_all_hiddens=return_all_hiddens + ) + if not features_only: + x = self.output_layer(x, masked_tokens=masked_tokens) + return x, extra + + def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs): + encoder_out = self.sentence_encoder( + src_tokens, + return_all_hiddens=return_all_hiddens, + token_embeddings=kwargs.get("token_embeddings", None), + ) + # T x B x C -> B x T x C + features = encoder_out["encoder_out"][0].transpose(0, 1) + inner_states = encoder_out["encoder_states"] if return_all_hiddens else None + return features, {"inner_states": inner_states} + + def output_layer(self, features, masked_tokens=None, **unused): + return self.lm_head(features, masked_tokens) + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + +@register_model_architecture("roberta", "roberta") +def base_architecture(args): + args.encoder_layers = safe_getattr(args, "encoder_layers", 12) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 3072) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 12) + + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_dropout = safe_getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = safe_getattr(args, "pooler_dropout", 0.0) + + args.max_source_positions = safe_getattr(args, "max_positions", 512) + args.no_token_positional_embeddings = safe_getattr( + args, "no_token_positional_embeddings", False + ) + + # BERT has a few structural differences compared to the original Transformer + args.encoder_learned_pos = safe_getattr(args, "encoder_learned_pos", True) + args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True) + args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + args.encoder_normalize_before = safe_getattr( + args, "encoder_normalize_before", False + ) + args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh") + args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False) + + # Adaptive input config + args.adaptive_input = safe_getattr(args, "adaptive_input", False) + + # LayerDrop config + args.encoder_layerdrop = safe_getattr(args, "encoder_layerdrop", 0.0) + args.encoder_layers_to_keep = safe_getattr(args, "encoder_layers_to_keep", None) + + # Quantization noise config + args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0) + + # R4F config + args.spectral_norm_classification_head = safe_getattr( + args, "spectral_norm_classification_head", False + ) + + +@register_model_architecture("roberta", "roberta_prenorm") +def roberta_prenorm_architecture(args): + args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False) + args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", True) + base_architecture(args) + + +@register_model_architecture("roberta", "roberta_base") +def roberta_base_architecture(args): + base_architecture(args) + + +@register_model_architecture("roberta", "roberta_large") +def roberta_large_architecture(args): + args.encoder_layers = safe_getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16) + base_architecture(args) + + +@register_model_architecture("roberta", "xlm") +def xlm_architecture(args): + args.encoder_layers = safe_getattr(args, "encoder_layers", 16) + args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1280) + args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 1280 * 4) + args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16) + base_architecture(args) diff --git a/fairseq/models/roberta/model_camembert.py b/fairseq/models/roberta/model_camembert.py new file mode 100644 index 0000000000000000000000000000000000000000..46447546fafb4a0a887b481022cac07631047c80 --- /dev/null +++ b/fairseq/models/roberta/model_camembert.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +CamemBERT: a Tasty French Language Model +""" + +from fairseq.models import register_model + +from .hub_interface import RobertaHubInterface +from .model import RobertaModel + + +@register_model("camembert") +class CamembertModel(RobertaModel): + @classmethod + def hub_models(cls): + return { + "camembert": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert.v0": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert-base": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz", + "camembert-large": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz", + "camembert-base-ccnet": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz", + "camembert-base-ccnet-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz", + "camembert-base-wikipedia-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz", + "camembert-base-oscar-4gb": "http://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz", + } + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="sentencepiece", + **kwargs + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/roberta/model_gottbert.py b/fairseq/models/roberta/model_gottbert.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7a019b33387cce2e222138a339be1c904335c7 --- /dev/null +++ b/fairseq/models/roberta/model_gottbert.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +GottBERT: a pure German Language Model +""" + +from fairseq.models import register_model + +from .hub_interface import RobertaHubInterface +from .model import RobertaModel + + +@register_model("gottbert") +class GottbertModel(RobertaModel): + @classmethod + def hub_models(cls): + return { + "gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz", + } + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="hf_byte_bpe", + bpe_vocab="vocab.json", + bpe_merges="merges.txt", + bpe_add_prefix_space=False, + **kwargs + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + bpe_vocab=bpe_vocab, + bpe_merges=bpe_merges, + bpe_add_prefix_space=bpe_add_prefix_space, + **kwargs, + ) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/roberta/model_xlmr.py b/fairseq/models/roberta/model_xlmr.py new file mode 100644 index 0000000000000000000000000000000000000000..cf6e354d53b918dd4c7c78bfcd38ac0d63cab3bd --- /dev/null +++ b/fairseq/models/roberta/model_xlmr.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Unsupervised Cross-lingual Representation Learning at Scale +""" + +from fairseq.models import register_model + +from .hub_interface import RobertaHubInterface +from .model import RobertaModel + + +@register_model("xlmr") +class XLMRModel(RobertaModel): + @classmethod + def hub_models(cls): + return { + "xlmr.base": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz", + "xlmr.large": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.tar.gz", + "xlmr.xl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xl.tar.gz", + "xlmr.xxl": "http://dl.fbaipublicfiles.com/fairseq/models/xlmr/xlmr.xxl.tar.gz", + } + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="sentencepiece", + **kwargs + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/speech_dlm/__init__.py b/fairseq/models/speech_dlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ea914d6a578651fecd18cc7f352382623de303a --- /dev/null +++ b/fairseq/models/speech_dlm/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .speech_dlm import * # noqa +from .hub_interface import * # noqa diff --git a/fairseq/models/speech_dlm/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_dlm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..831f8702fd42151fe9b2143bd3f43de407192340 Binary files /dev/null and b/fairseq/models/speech_dlm/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/speech_dlm/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e73381ba3e5ce2ff318bcea21654fe652bff48 Binary files /dev/null and b/fairseq/models/speech_dlm/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/__pycache__/speech_dlm.cpython-310.pyc b/fairseq/models/speech_dlm/__pycache__/speech_dlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b18b12ae9b96b2e5154d9a618a82fff7234ae470 Binary files /dev/null and b/fairseq/models/speech_dlm/__pycache__/speech_dlm.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/hub_interface.py b/fairseq/models/speech_dlm/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..11bc0f50bbb8cf3d146741c1f02b1ebfe8d4b7f6 --- /dev/null +++ b/fairseq/models/speech_dlm/hub_interface.py @@ -0,0 +1,192 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +from typing import Any, Dict, Iterator, List + +import torch +from fairseq import utils +from omegaconf import open_dict +from torch import nn + +from tqdm import tqdm + +from fairseq.hub_utils import GeneratorHubInterface + + +logger = logging.getLogger(__name__) + + +class MultichannelGeneratorHubInterface(GeneratorHubInterface): + """Pytorch Hub interface for generating sequences from a pre-trained + multichannel language model. + """ + + def __init__(self, cfg, task, models): + super().__init__(cfg, task, models) + self.cfg = cfg + self.task = task + self.models = nn.ModuleList(models) + self.src_dicts = task.source_dictionaries + self.tgt_dicts = task.target_dictionaries + self.channels = task.channels + + # optimize model for generation + for model in self.models: + model.prepare_for_inference_(cfg) + + def sample( + self, + sentences: List[Dict[str, str]], + beam: int = 1, + verbose: bool = False, + **kwargs + ) -> List[str]: + if isinstance(sentences, dict): + return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0] + tokenized_sentences = [self.encode(sentence) for sentence in sentences] + batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs) + return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos] + + def score(self, sentences: List[Dict[str, str]], **kwargs): + raise NotImplementedError( + "MultichannelGeneratorHubInterface doesn't support score() method" + ) + + def generate( + self, + tokenized_sentences: List[Dict[str, torch.LongTensor]], + beam: int = 5, + verbose: bool = False, + skip_invalid_size_inputs=False, + inference_step_args=None, + **kwargs + ) -> List[List[Dict[str, torch.Tensor]]]: + if isinstance(tokenized_sentences, dict): + return self.generate( + [tokenized_sentences], beam=beam, verbose=verbose, **kwargs + )[0] + + # build generator using current args as well as any kwargs + gen_args = copy.deepcopy(self.cfg.generation) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) + generator = self.task.build_generator(self.models, gen_args) + + inference_step_args = inference_step_args or {} + results = [] + for batch in tqdm( + self._build_batches(tokenized_sentences, skip_invalid_size_inputs) + ): + batch = utils.apply_to_sample(lambda t: t.to(self.device), batch) + translations = self.task.inference_step( + generator, self.models, batch, **inference_step_args + ) + for id, hypos in zip(batch["id"].tolist(), translations): + # The output of the generator is supposed to be a tensor of size (bsz x max_len x n_channels) + # So we need to convert it to dictionary form + for i in range(len(hypos)): + hypos[i]["tokens"] = { + channel: hypos[i]["tokens"][..., j] + for j, channel in enumerate(self.channels) + } + results.append((id, hypos)) + + # sort output to match input order + outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])] + + if verbose: + + def getarg(name, default): + return getattr(gen_args, name, getattr(self.cfg, name, default)) + + for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): + src_str_with_unk = { + channel: self.string(source_tokens[channel], channel) + for channel in source_tokens + } + logger.info("S\t{}".format(src_str_with_unk)) + for hypo in target_hypotheses: + hypo_str = self.decode(hypo["tokens"]) + logger.info("H\t{}\t{}".format(hypo["score"], hypo_str)) + # hypo["positional_scores"]: T x n_channels + pos_scores = {} + for c, channel in enumerate(source_tokens): + pos_scores[channel] = " ".join( + map( + lambda x: "{:.4f}".format(x), + hypo["positional_scores"][:, c].tolist(), + ) + ) + logger.info("P\t{}".format(pos_scores)) + + return outputs + + def encode(self, sentence: Dict[str, str]) -> Dict[str, torch.LongTensor]: + assert isinstance( + sentence, dict + ), "Input sentence is expected to be a dictionary over channels" + assert set(sentence.keys()) == set( + self.channels + ), "Mismatch between input sentence keys and model channels ({} vs {})".format( + set(sentence.keys()), set(self.channels) + ) + encoded_sentence = {} + for channel in sentence: + sentence_channel = sentence[channel] + sentence_channel = self.tokenize(sentence_channel) + sentence_channel = self.apply_bpe(sentence_channel) + sentence_channel = self.binarize(sentence_channel, channel) + encoded_sentence[channel] = sentence_channel + sentence_size = encoded_sentence[self.channels[0]].size() + assert all( + encoded_sentence[channel].size() == sentence_size + for channel in encoded_sentence + ), "Input tensors are expected to have the same size in all channels" + return encoded_sentence + + def decode(self, tokens: Dict[str, torch.LongTensor]) -> Dict[str, str]: + assert isinstance( + tokens, dict + ), "Input tokens are expected to be a dictionary over channels" + assert set(tokens.keys()) == set( + self.channels + ), "Mismatch between input tokens keys and model channels ({} vs {})".format( + set(tokens.keys()), set(self.channels) + ) + decoded_sentence = {} + for channel in tokens: + tokens_channel = tokens[channel] + sentence_channel = self.string(tokens_channel, channel) + sentence_channel = self.remove_bpe(sentence_channel) + sentence_channel = self.detokenize(sentence_channel) + decoded_sentence[channel] = sentence_channel + return decoded_sentence + + def binarize(self, sentence: str, channel: str) -> torch.LongTensor: + return ( + self.src_dicts[channel].encode_line(sentence, add_if_not_exist=False).long() + ) + + def string(self, tokens: torch.LongTensor, channel: str) -> str: + return self.tgt_dicts[channel].string(tokens) + + def _build_batches( + self, tokens: List[Dict[str, List[int]]], skip_invalid_size_inputs: bool + ) -> Iterator[Dict[str, Any]]: + lengths = torch.LongTensor([next(iter(d.values())).numel() for d in tokens]) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.build_dataset_for_inference(tokens, lengths), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=self.max_positions, + ignore_invalid_inputs=skip_invalid_size_inputs, + disable_iterator_cache=True, + ).next_epoch_itr(shuffle=False) + return batch_iterator diff --git a/fairseq/models/speech_dlm/modules/__init__.py b/fairseq/models/speech_dlm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/models/speech_dlm/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_dlm/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4851c2c418eea4c9289038b9f66b44fee62eb69 Binary files /dev/null and b/fairseq/models/speech_dlm/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder.cpython-310.pyc b/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761585c0c3ac60194d4d1506a64d13f6f15d64ff Binary files /dev/null and b/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder_layer.cpython-310.pyc b/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fbaa67500967e088fd206c8322b81de6f919483 Binary files /dev/null and b/fairseq/models/speech_dlm/modules/__pycache__/speech_dlm_decoder_layer.cpython-310.pyc differ diff --git a/fairseq/models/speech_dlm/modules/speech_dlm_decoder.py b/fairseq/models/speech_dlm/modules/speech_dlm_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a14a1d64a87055666b017e885f0a370536f823f8 --- /dev/null +++ b/fairseq/models/speech_dlm/modules/speech_dlm_decoder.py @@ -0,0 +1,572 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.models import FairseqIncrementalDecoder +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, +) +from .speech_dlm_decoder_layer import ( + CrossChannelTransformerDecoderLayer, + StandardTransformerDecoderLayer, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ +from torch import Tensor + + +class CrossChannelTransformerDecoder(FairseqIncrementalDecoder): + """ + Cross-channel Transformer Decoder Block for parallel spoken dialogue units + as described in the paper: https://arxiv.org/pdf/2203.16502.pdf; + consisting of *args.decoder_layers* layers. Each layer is a + :class:`StandardTransformerDecoderLayer` or + :class:`CrossChannelTransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + channels (list): list of channel names (string) + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__(self, args, dictionary, embed_tokens, channels, no_encoder_attn=False): + self.args = args + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.decoder_layerdrop = args.decoder_layerdrop + self.share_input_output_embed = args.share_decoder_input_output_embed + self.channels = channels + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = args.decoder_embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = args.decoder_output_dim + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = args.max_target_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + + if args.quant_noise_pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + nn.Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + PositionalEmbedding( + self.max_target_positions, + embed_dim, + self.padding_idx, + learned=args.decoder_learned_pos, + ) + if not args.no_token_positional_embeddings + else None + ) + + if getattr(args, "layernorm_embedding", False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + assert 0 <= args.decoder_cross_layers <= args.decoder_layers, ( + "The number of cross-channel attention decoder layers must be non-negative" + f"and not exceeds the number of decoder layers (found {args.decoder_cross_layers})" + ) + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(args, no_encoder_attn) + if i < args.decoder_layers - args.decoder_cross_layers + else self.build_cross_decoder_layer(args, no_encoder_attn) + for i in range(args.decoder_layers) + ] + ) + self.num_layers = len(self.layers) + self.non_cross_layers = args.decoder_layers - args.decoder_cross_layers + + if args.decoder_normalize_before and not getattr( + args, "no_decoder_final_norm", False + ): + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + self.project_out_dim = ( + nn.Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim + else None + ) + + self.output_projection = None + self.is_cross_prediction = bool( + float(args.main_and_cross_weights.split(",")[1]) != 0 + ) + self.n_output_projections = ( + 1 if not self.is_cross_prediction else len(self.channels) + ) + + if self.share_input_output_embed: + # Output projection is a list of projections + # where the first proj is for the main-channel, + # then roll in a cicular way. + # For example: if the main channel has index i + # the second proj is for channel i+1 (mod N_channels), etc. + self.output_projection = nn.ModuleList( + [ + nn.Linear( + embed_tokens.weight.shape[1], # embed_dim + embed_tokens.weight.shape[0], # n_dictionaries + bias=False, + ) + for _ in range(self.n_output_projections) + ] + ) + # Only share the main-channel projection + self.output_projection[0].weight = embed_tokens.weight + for i in range(1, self.n_output_projections): + nn.init.normal_( + self.output_projection[i].weight, + mean=0, + std=embed_tokens.weight.shape[1] ** -0.5, + ) + else: + self.output_projection = nn.ModuleList( + [ + nn.Linear(self.output_embed_dim, len(dictionary), bias=False) + for _ in range(self.n_output_projections) + ] + ) + for i in range(self.n_output_projections): + nn.init.normal_( + self.output_projection[i].weight, + mean=0, + std=self.output_embed_dim**-0.5, + ) + self.output_duration_prediction = ( + None + if str(args.duration_prediction).lower() == "false" + else nn.ModuleList( + [ + nn.Linear(self.output_embed_dim, 1) + for _ in range(self.n_output_projections) + ] + ) + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + layer = StandardTransformerDecoderLayer(args, no_encoder_attn) + if getattr(args, "checkpoint_activations", False): + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + return layer + + def build_cross_decoder_layer(self, args, no_encoder_attn=False): + layer = CrossChannelTransformerDecoderLayer(args, no_encoder_attn) + if getattr(args, "checkpoint_activations", False): + offload_to_cpu = getattr(args, "offload_activations", False) + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + return layer + + def forward( + self, + prev_output_tokens: Dict[str, Tensor], + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[ + List[Dict[str, Dict[str, Optional[Tensor]]]] + ] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + # return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (dict[str, LongTensor]): previous decoder outputs, + dictionary over all channels with the values being the tensors + of shape `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): list of dictionaries used for storing state + during :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output, dict over channels of tensors + of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens: Dict[str, Tensor], + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[ + List[Dict[str, Dict[str, Optional[Tensor]]]] + ] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens: Dict[str, Tensor], + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[ + List[Dict[str, Dict[str, Optional[Tensor]]]] + ] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + The core function of *forward* but only return features. + + The input (prev_output_tokens) is a dictionary over all channels, + expected to have the following form: + { + 'channel1' : Tensor((batch x tgt_len)), + 'channel2' : Tensor((batch x tgt_len)), + } + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features, dict over channels of tensors + of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + x_list = [] + for i, channel in enumerate(self.channels): + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens[channel], + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + ) + + if incremental_state is not None: + prev_output_tokens[channel] = prev_output_tokens[channel][:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_tokens(prev_output_tokens[channel]) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + x = self.embed_scale * x + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x_list.append(x) + + self_attn_padding_mask: Optional[Tensor] = None + if ( + self.cross_self_attention + or prev_output_tokens[self.channels[0]].eq(self.padding_idx).any() + ): + self_attn_padding_mask = prev_output_tokens[self.channels[0]].eq( + self.padding_idx + ) + + # decoder layers + attn: Optional[Dict[Tensor]] = None + inner_states: List[Optional[Dict[str, Tensor]]] = [ + {channel: x_list[i] for i, channel in enumerate(self.channels)} + ] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x_list[0]) + else: + self_attn_mask = None + + # need to change to tensor for the checkpoint activation to work + if isinstance(x_list, list): + x_list = torch.stack(x_list) + x_list, layer_attn_list, _ = layer( + x_list, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + + inner_states.append( + {channel: x_list[i] for i, channel in enumerate(self.channels)} + ) + if idx == alignment_layer and all( + layer_attn is not None for layer_attn in layer_attn_list + ): + attn = { + channel: layer_attn_list[i].float().to(x_list[0]) + for i, channel in enumerate(self.channels) + } + # change back from tensor to list + if not isinstance(x_list, list): + x_list = list(torch.unbind(x_list)) + + if attn is not None: + for channel in attn: + if alignment_heads is not None: + attn[channel] = attn[channel][:alignment_heads] + + # average probabilities over heads + attn[channel] = attn[channel].mean(dim=0) + + for i, x in enumerate(x_list): + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + x_list[i] = x + + x = {channel: x_list[i] for i, channel in enumerate(self.channels)} + + return x, {"attn": [attn], "inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size. + Return a dictionary of the form: + { + 'input-channel': { + 'predicted-channel': token prediction tensor of shape `(batch, tgt_len, vocab)`, + } + } + + if duration_prediction is enabled + { + 'input-channel': { + 'predicted-channel': { + 'pred_token': token prediction tensor of shape `(batch, tgt_len, vocab)`, + 'pred_duration': duration prediction tensor + } + } + } + """ + # project back to size of vocabulary + if self.output_duration_prediction is None: + if self.is_cross_prediction: + return { + channel: { + pred_channel: self.output_projection[j - i](features[channel]) + for j, pred_channel in enumerate(self.channels) + } + for i, channel in enumerate(self.channels) + } + else: + return { + channel: {channel: self.output_projection[0](features[channel])} + for i, channel in enumerate(self.channels) + } + else: + if self.is_cross_prediction: + return { + channel: { + pred_channel: { + "pred_token": self.output_projection[j - i]( + features[channel] + ), + "pred_duration": self.output_duration_prediction[j - i]( + features[channel] + ), + } + for j, pred_channel in enumerate(self.channels) + } + for i, channel in enumerate(self.channels) + } + else: + return { + channel: { + channel: { + "pred_token": self.output_projection[0](features[channel]), + "pred_duration": self.output_duration_prediction[0]( + features[channel] + ), + } + } + for i, channel in enumerate(self.channels) + } + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + def get_normalized_probs_scriptable( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits_dict = net_output[0] + out_dict = {} + for channel in logits_dict: + out_dict[channel] = {} + for pred_channel in logits_dict[channel]: + if isinstance(logits_dict[channel][pred_channel], dict): + pred_token_logits = logits_dict[channel][pred_channel]["pred_token"] + else: + pred_token_logits = logits_dict[channel][pred_channel] + if log_probs: + out = utils.log_softmax( + pred_token_logits, dim=-1, onnx_trace=self.onnx_trace + ) + else: + out = utils.softmax( + pred_token_logits, dim=-1, onnx_trace=self.onnx_trace + ) + if isinstance(logits_dict[channel][pred_channel], dict): + out_dict[channel][pred_channel] = { + "pred_token": out, + "pred_duration": logits_dict[channel][pred_channel][ + "pred_duration" + ].float(), + } # move to float32 to avoid inf loss + else: + out_dict[channel][pred_channel] = out + return out_dict + + def reorder_incremental_state_scripting( + self, + incremental_state: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order: Tensor, + ): + """Main entry point for reordering the incremental state. + + Due to limitations in TorchScript, we call this function in + :class:`fairseq.sequence_generator.SequenceGenerator` instead of + calling :func:`reorder_incremental_state` directly. + """ + for module in self.modules(): + if hasattr(module, "reorder_incremental_state"): + for i, incremental_state_channel in enumerate(incremental_state): + result = module.reorder_incremental_state( + incremental_state_channel, new_order + ) + if result is not None: + incremental_state[i] = result diff --git a/fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py b/fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb65fdf810d613b7d8615d3cbe56b633833367f9 --- /dev/null +++ b/fairseq/models/speech_dlm/modules/speech_dlm_decoder_layer.py @@ -0,0 +1,717 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Tuple, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.modules import LayerNorm, MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor + + +class CrossChannelTransformerDecoderLayer(nn.Module): + """Cross-Attention Transformer Decoder Layer block as described + in the paper: https://arxiv.org/pdf/2203.16502.pdf + + Composed of a Multi-head Self Attention block followed by a + Multi-head Cross-Attention block which attends to the self-attention + outputs of the other channels. The weights of the attention blocks + in all channels are shared. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + + # This cross_self_attention is used for encoder-decoder systems, + # It's not the cross-channel attention (defined below as cross_channel_attn) + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + self.self_attn = self.build_self_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.cross_channel_attn = self.build_cross_channel_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + self.cross_channel_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = self.build_fc1( + self.embed_dim, + args.decoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + args.decoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not getattr(args, "cross_self_attention", False), + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def build_cross_channel_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=False, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x_list_tensor: List[torch.Tensor], + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[ + List[Dict[str, Dict[str, Optional[Tensor]]]] + ] = None, + prev_self_attn_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x_list_tensor (List[Tensor]): list of input tensors in different channels, + each tensor is of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + incremental_state (optional): list of incremental_state dictionaries over + different channels (sequence generation mode) + prev_self_attn_state (List[Tuple[Tensor, Tensor]], optional): list of tuples + (self_attn_state, cross_channel_attn_state) over different channels + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + list of encoded output of shape `(seq_len, batch, embed_dim)` + """ + n_channels = len(x_list_tensor) + if need_head_weights: + need_attn = True + + # incremental_state is a list of dictionaries over different channels + if incremental_state is not None: + assert isinstance(incremental_state, list) + assert len(incremental_state) == n_channels + + # prev_self_attn_state is a list of tuples (self_attn_state, cross_channel_attn_state) over different channels + if prev_self_attn_state is not None: + assert isinstance(prev_self_attn_state, list) + assert len(prev_self_attn_state) == n_channels + for prev_self_attn_state_channel in prev_self_attn_state: + assert isinstance(prev_self_attn_state_channel, tuple) + assert len(prev_self_attn_state_channel) == 2 + + # Backup for other channels & cross channel attention + self_attn_mask_orin = self_attn_mask + self_attn_padding_mask_orin = self_attn_padding_mask + + x_list = [] + attn_list = [] + for i, x in enumerate(x_list_tensor): + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[i][0][:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state[i][0]) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[i][0][2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state[i], saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer( + incremental_state[i] if incremental_state is not None else None + ) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask_orin is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + ( + x.new_zeros(x.size(0), encoder_out.size(0)), + self_attn_mask_orin, + ), + dim=1, + ) + if self_attn_padding_mask_orin is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask_orin.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask_orin), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + need_weights=False, + attn_mask=self_attn_mask, + ) + + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer( + incremental_state[i], saved_state + ) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + x_list.append(x) + attn_list.append(attn) + + # Store attentions & new x(s) (bc the old x(s) are used in other channels) + x_list_new = [] + # Here comes the cross channel attention + for i, x in enumerate(x_list): + residual = x + if self.normalize_before: + x = self.cross_channel_attn_layer_norm(x) + + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[i][1][:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state[i][1]) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[i][1][2] + assert incremental_state is not None + self.cross_channel_attn._set_input_buffer( + incremental_state[i], saved_state + ) + + # The cross attention is computed with the concatenation of attentions from other channels + if len(x_list) > 1: + x_other = torch.cat( + [x_list[(i + j) % len(x_list)] for j in range(1, len(x_list))], + dim=0, + ) + else: + # Self-attention when having only one channel + x_other = x_list[i] + + x, attn = self.cross_channel_attn( + query=x, + key=x_other, + value=x_other, + key_padding_mask=self_attn_padding_mask_orin, + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + need_weights=False, + attn_mask=self_attn_mask_orin, + ) + + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.cross_channel_attn_layer_norm(x) + + x_list_new.append(x) + x_list = x_list_new + + for i, x in enumerate(x_list): + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + x_list[i] = x + # Trick for the checkpoint activation + x_list_tensor = torch.stack(x_list) + if self.onnx_trace and incremental_state is not None: + self_and_cross_attn_state_list = [] + for i in range(n_channels): + self_and_cross_attn_state = [] + for self_attn_module in [self.self_attn, self.cross_channel_attn]: + saved_state = self_attn_module._get_input_buffer( + incremental_state[i] + ) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_module_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_module_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + ] + self_and_cross_attn_state.append(self_attn_module_state) + self_and_cross_attn_state_list.append(tuple(self_and_cross_attn_state)) + return x_list_tensor, attn_list, self_and_cross_attn_state_list + return x_list_tensor, attn_list, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn + + +# Rewrite fairseq.modules.TransformerDecoderLayer +# to be compatible with checkpoint_activations +# (avoid forwarding model multiple times) +class StandardTransformerDecoderLayer(nn.Module): + """Rewrite fairseq.modules.TransformerDecoderLayer to avoid forwarding + model multiple times and be compatible with checkpoint_activations. + + The input is expected to be a list of tensors from different channels, + each is forwarded to the same model (shared attention weights). + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *args.decoder_normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = args.decoder_embed_dim + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = getattr(args, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + + self.cross_self_attention = getattr(args, "cross_self_attention", False) + + self.self_attn = self.build_self_attention( + self.embed_dim, + args, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + self.activation_fn = utils.get_activation_fn( + activation=str(args.activation_fn) + if getattr(args, "activation_fn", None) is not None + else "relu" + ) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use args.relu_dropout + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = args.decoder_normalize_before + + # use layerNorm rather than FusedLayerNorm for exporting. + # char_inputs can be used to determint this. + # TODO remove this once we update apex with the fix + export = getattr(args, "char_inputs", False) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export) + + self.fc1 = self.build_fc1( + self.embed_dim, + args.decoder_ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + args.decoder_ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=export) + self.need_attn = True + + self.onnx_trace = False + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + dropout=args.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not getattr(args, "cross_self_attention", False), + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def build_encoder_attention(self, embed_dim, args): + return MultiheadAttention( + embed_dim, + args.decoder_attention_heads, + kdim=getattr(args, "encoder_embed_dim", None), + vdim=getattr(args, "encoder_embed_dim", None), + dropout=args.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + ) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x_list_tensor: List[torch.Tensor], + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[ + List[Dict[str, Dict[str, Optional[Tensor]]]] + ] = None, + prev_self_attn_state: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x_list_tensor (List[Tensor]): list of input tensors in different channels, + each tensor is of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + incremental_state (optional): list of incremental_state dictionaries over + different channels (sequence generation mode) + prev_self_attn_state (List[Tuple[Tensor, Tensor]], optional): list of tuples + (self_attn_state, cross_channel_attn_state) over different channels + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + list of encoded output of shape `(seq_len, batch, embed_dim)` + """ + n_channels = len(x_list_tensor) + if need_head_weights: + need_attn = True + + # incremental_state is a list of dictionaries over different channels + if incremental_state is not None: + assert isinstance(incremental_state, list) + assert len(incremental_state) == n_channels + + # prev_self_attn_state is a list of self_attn_state over different channels + if prev_self_attn_state is not None: + assert isinstance(prev_self_attn_state, list) + assert len(prev_self_attn_state) == n_channels + + x_list = [] + attn_list = [] + for i, x in enumerate(x_list_tensor): + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[i][:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state[i]) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state[i], saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer( + incremental_state + ) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), + dim=1, + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state[i] + if incremental_state is not None + else None, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + x_list.append(x) + attn_list.append(attn) + + # Trick for the checkpoint activation + x_list_tensor = torch.stack(x_list) + if self.onnx_trace and incremental_state is not None: + self_attn_state_list = [] + for i in range(n_channels): + saved_state = self.self_attn._get_input_buffer(incremental_state[i]) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + ] + self_attn_state_list.append(self_attn_state) + return x_list_tensor, attn_list, self_attn_state_list + return x_list_tensor, attn_list, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn diff --git a/fairseq/models/speech_dlm/sequence_generator/__init__.py b/fairseq/models/speech_dlm/sequence_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a88e14457766bf30c916926f171b81fa60dd33ce --- /dev/null +++ b/fairseq/models/speech_dlm/sequence_generator/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .multichannel_sequence_generator import * # noqa diff --git a/fairseq/models/speech_dlm/sequence_generator/multichannel_search.py b/fairseq/models/speech_dlm/sequence_generator/multichannel_search.py new file mode 100644 index 0000000000000000000000000000000000000000..db4b77f3457ca2da089772e2c7e3df850c87442f --- /dev/null +++ b/fairseq/models/speech_dlm/sequence_generator/multichannel_search.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +class MultichannelSearch(nn.Module): + def __init__(self, tgt_dicts): + super().__init__() + tgt_dict = list(tgt_dicts.values())[0] + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + for tgt_dict in tgt_dicts.values(): + assert self.pad == tgt_dict.pad() + assert self.unk == tgt_dict.unk() + assert self.eos == tgt_dict.eos() + self.vocab_sizes = {channel: len(tgt_dicts[channel]) for channel in tgt_dicts} + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: dictionary of channels {channel : (bsz x input_beam_size x vocab_size_channel)} + the model's log-probabilities over the vocabulary at the current step + scores: {channel : (bsz x input_beam_size x step)} + the historical model scores of each hypothesis up to this point + prev_output_tokens: {channel : (bsz x step)} + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: {channel : (bsz x output_beam_size)} + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: {channel : (bsz x output_beam_size)} + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +def unravel_index(index, shape): + out = [] + for dim in reversed(shape): + out.append(index % dim) + index = index // dim + return torch.stack(tuple(reversed(out)), dim=-1) + + +def topk_sum(lprobs_list, k): + """ + lprobs_list = [lprobs_1,...,lprobs_n], where: + lprobs_1 : (batch_size x beam_size x vocab_1) + ... + lprobs_n : (batch_size x beam_size x vocab_n) + + Return: + - topk_values : (batch_size x k) + values of the topk sum of the form : + lprobs_1[bsz, beam_idx, vocab_1_idx] + ... + lprobs_n[bsz, beam_idx, vocab_n_idx] + - topk_idxs : (batch_size x k x n+1) + each (n+1)-tensor being [beam_idx, vocab_1_idx, ..., vocab_n_idx] + """ + # Reduce all lprobs to k candidates first to reduce later complexity + # We may assume that k << vocab + lprobs_topk_list = [] + lprobs_topk_indices_list = [] + for lprobs in lprobs_list: + k_i = min(k, lprobs.size(-1)) + topk_values, topk_indices = torch.topk(lprobs, k=k_i) + # topk_values : (batch_size x beam_size x k_i) + # topk_indices : (batch_size x beam_size x k_i) + lprobs_topk_list.append(topk_values) + lprobs_topk_indices_list.append(topk_indices) + + # Compute all possible sums + sum_lprobs_topk = lprobs_topk_list[0] + for i in range(1, len(lprobs_topk_list)): + unsqueezed_lprobs = lprobs_topk_list[i] + for _ in range(i): + unsqueezed_lprobs = unsqueezed_lprobs.unsqueeze(-2) + sum_lprobs_topk = sum_lprobs_topk.unsqueeze(-1) + unsqueezed_lprobs + # sum_lprobs : (batch_size x beam_size x k_1 x ... x k_n) + + # Get the top k sums and the (transformed indices) + topk_sum_values, topk_sum_indices = torch.topk( + sum_lprobs_topk.view(sum_lprobs_topk.size(0), -1), k=k + ) + # topk_sum_values : (batch_size x k) + # topk_sum_indices : (batch_size x k) + topk_sum_indices = unravel_index(topk_sum_indices, tuple(sum_lprobs_topk.shape[1:])) + # topk_sum_indices : (batch_size x k x n+1) + + # Convert the transformed indices to the true indices + for i_batch in range(topk_sum_indices.size(0)): + for i_cand in range(topk_sum_indices.size(1)): + i_beam, *transformed_vocab_indices = topk_sum_indices[i_batch, i_cand] + true_vocab_indices = [i_beam] + for j, transformed_vocab_j_idx in enumerate(transformed_vocab_indices): + true_vocab_j_idx = lprobs_topk_indices_list[j][ + i_batch, i_beam, transformed_vocab_j_idx + ] + true_vocab_indices.append(true_vocab_j_idx) + topk_sum_indices[i_batch, i_cand] = torch.tensor(true_vocab_indices) + + topk_sum_beams = topk_sum_indices[:, :, 0] + topk_sum_indices = topk_sum_indices[:, :, 1:] + + return topk_sum_values, topk_sum_indices, topk_sum_beams + + +class MultichannelBeamSearch(MultichannelSearch): + def __init__(self, tgt_dicts): + super().__init__(tgt_dicts) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Dict[str, Tensor]], + prev_output_tokens: Optional[Dict[str, Tensor]] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + channels = list(lprobs.keys()) + bsz, beam_size, _ = lprobs[channels[0]].size() + + lprobs_list = [] + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + for channel in channels: + lprobs_list.append(lprobs[channel][:, ::beam_size, :].contiguous()) + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + for channel in channels: + lprobs_list.append( + lprobs[channel] + scores[channel][:, :, step - 1].unsqueeze(-1) + ) + + topk_sum_values, topk_sum_indices, topk_sum_beams = topk_sum( + lprobs_list, k=beam_size * 2 + ) + + beams_buf = topk_sum_beams + scores_buf = {} + indices_buf = {} + for i, channel in enumerate(channels): + indices_buf[channel] = topk_sum_indices[:, :, i] + scores_buf[channel] = ( + torch.tensor( + [ + lprobs_list[i][i_batch, i_beam, i_index] + for i_batch in range(bsz) + for i_beam, i_index in zip( + beams_buf[i_batch], indices_buf[channel][i_batch] + ) + ] + ) + .view(bsz, -1) + .to(lprobs_list[i].device) + ) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class ContiguousMultichannelBeamSearch(MultichannelSearch): + def __init__(self, tgt_dicts): + super().__init__(tgt_dicts) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + n_channels = len(lprobs) + bsz, beam_size, _ = lprobs[0].size() + + lprobs_list = [] + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + for i in range(n_channels): + lprobs_list.append(lprobs[i][:, ::beam_size, :].contiguous()) + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + for i in range(n_channels): + lprobs_list.append(lprobs[i] + scores[:, :, step - 1, i].unsqueeze(-1)) + + topk_sum_values, topk_sum_indices, topk_sum_beams = topk_sum( + lprobs_list, k=beam_size * 2 + ) + + beams_buf = topk_sum_beams + indices_buf = topk_sum_indices + scores_buf = ( + torch.tensor( + [ + lprobs_list[i][i_batch, i_beam, i_index] + for i in range(len(lprobs_list)) + for i_batch in range(bsz) + for i_beam, i_index in zip( + beams_buf[i_batch], indices_buf[i_batch, :, i] + ) + ] + ) + .view(len(lprobs_list), bsz, -1) + .permute(1, 2, 0) + .to(lprobs_list[0].device) + ) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class ContiguousMultichannelSampling(MultichannelSearch): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dicts, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dicts) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, : max_dim + 1] + truncated_probs = sorted_probs[:, :, : max_dim + 1] + truncated_indices = sorted_indices[:, :, : max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + n_channels = len(lprobs) + bsz, beam_size, vocab_size = lprobs[0].size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + for i in range(n_channels): + lprobs[i] = lprobs[i][:, ::beam_size, :].contiguous() + + probs = [] + top_indices = [] + for i in range(n_channels): + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs_i, top_indices_i = self._sample_topp(lprobs[i]) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs[i], top_indices_i = lprobs[i].topk( + min(self.sampling_topk, lprobs[i].size(-1)) + ) + probs_i = lprobs[i].exp_() + else: + probs_i = lprobs[i].exp_() + + # dummy data to be consistent with true branch for type check + top_indices_i = torch.empty(0).to(probs_i) + probs.append(probs_i) + top_indices.append(top_indices_i) + # sample + indices_buf = [] + for i in range(n_channels): + if step == 0: + indices_buf.append( + torch.multinomial( + probs[i].view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + ) + else: + indices_buf.append( + torch.multinomial( + probs[i].view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + ) + + if step == 0: + for i in range(n_channels): + # expand to beam size + probs[i] = probs[i].expand(bsz, beam_size, -1) + + # gather scores + scores_buf = [] + for i in range(n_channels): + scores_buf.append( + torch.gather(probs[i], dim=2, index=indices_buf[i].unsqueeze(-1)) + ) + scores_buf[i] = scores_buf[i].log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + for i in range(n_channels): + indices_buf[i] = torch.gather( + top_indices[i].expand(bsz, beam_size, -1), + dim=2, + index=indices_buf[i].unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf[0].new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, beam_size).to(indices_buf[0]).repeat(bsz, 1) + # make scores cumulative + for i in range(n_channels): + scores_buf[i].add_( + torch.gather(scores[:, :, step - 1, i], dim=1, index=beams_buf) + ) + scores_buf = torch.stack(scores_buf, dim=-1) + indices_buf = torch.stack(indices_buf, dim=-1) + + return scores_buf, indices_buf, beams_buf diff --git a/fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py b/fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..24807b866de15661973af9e8079a4bd50bb8004b --- /dev/null +++ b/fairseq/models/speech_dlm/sequence_generator/multichannel_sequence_generator.py @@ -0,0 +1,1110 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Optional + +from omegaconf.listconfig import ListConfig +from omegaconf.dictconfig import DictConfig + +import torch +import torch.nn as nn +from fairseq.models import FairseqIncrementalDecoder +from torch import Tensor +from fairseq.ngram_repeat_block import NGramRepeatBlock +from .multichannel_search import ContiguousMultichannelBeamSearch +from fairseq.models.speech_dlm import SpeechDLM + + +class MultichannelSequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dicts, + beam_size=1, + max_len_a=0, + max_len_b=200, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + duration_temperature=1.0, + ): + """Generate multi-channel parallel units with the SpeechDLM model + as described in the paper: https://arxiv.org/pdf/2203.16502.pdf; + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + duration_temperature (float, optional): rate of the duration prediction, + higher rate induces a faster generated wav (default: 1.0) + """ + super().__init__() + if isinstance(models, MultichannelEnsembleModel): + self.model = models + else: + self.model = MultichannelEnsembleModel(models) + self.tgt_dicts = tgt_dicts + self.pad = list(tgt_dicts.values())[0].pad() + self.unk = list(tgt_dicts.values())[0].unk() + self.eos = list(tgt_dicts.values())[0].eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + self.channels = list(tgt_dicts.keys()) + self.n_channels = len(self.channels) + self.vocab_sizes = [len(tgt_dicts[channel]) for channel in self.channels] + # the max beam size is the dictionary size - 1, since we never select pad + max_possible_beam_size = 1 + for i in self.vocab_sizes: + max_possible_beam_size *= i - 1 + self.beam_size = min(beam_size, max_possible_beam_size) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + if isinstance(temperature, (int, float)): + temperature = {channel: temperature for channel in self.channels} + elif isinstance(temperature, ListConfig) or isinstance(temperature, list): + temperature = { + channel: temperature[i] for i, channel in enumerate(self.channels) + } + assert isinstance(temperature, DictConfig) or isinstance( + temperature, dict + ), f"temperature: expected dict, but found {type(temperature)}" + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + for channel in temperature: + assert temperature[channel] > 0, "--temperature must be greater than 0" + + if search_strategy is None: + self.search = ContiguousMultichannelBeamSearch(tgt_dicts) + else: + self.search = search_strategy + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + self.duration_prediction = bool( + str(getattr(models[0].decoder.args, "duration_prediction", "false")).lower() + == "true" + ) + self.delayed_duration = bool( + str( + getattr(models[0].decoder.args, "delayed_duration_target", "false") + ).lower() + == "true" + ) + self.duration_temperature = duration_temperature + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], # TODO: Modify this + prefix_tokens: Optional[Dict[str, Tensor]] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (dict of torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs): + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (dict of torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Dict[str, Tensor]] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """ + Here sample is expected to have the following form + { + 'id': index, + 'net_input': { + 'src_tokens': { + 'channel1' : tensor((batch x src_length)), + 'channel2' : tensor((batch x src_length)), + }, + ... + }, + } + and prefix_tokens + { + 'channel1' : tensor((batch x prefix_length)), + 'channel2' : tensor((batch x prefix_length)), + } + """ + if self.model.is_speech_dlm: + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [{} for _ in range(self.n_channels)], + ) + for i in range(self.model.models_size) + ], + ) + else: + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + # Convert from dict to tensor form + # shape of src_tokens : (bsz x src_len x n_channels) + src_tokens = torch.stack( + [net_input["src_tokens"][channel] for channel in self.channels], dim=-1 + ) + prefix_tokens = torch.stack( + [prefix_tokens[channel] for channel in self.channels], dim=-1 + ) + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens[..., 0].ne(self.eos) & src_tokens[..., 0].ne(self.pad)) + .long() + .sum(dim=1) + ) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + # exclude the EOS marker + self.model.max_decoder_positions() - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + # cumulative scores of hypotheses + scores = ( + torch.zeros(bsz * beam_size, max_len + 1, self.n_channels) + .to(src_tokens) + .float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2, self.n_channels) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + finished = [ + False for i in range(bsz) + ] # a boolean array indicating if the sentence at the index is finished or not + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + if self.duration_prediction: + dur_counter = torch.ones(bsz * beam_size, self.n_channels).to(src_tokens) + # save the indice where the dur_counter just copied from dur_pred + dur_counter_jump_indices = None + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + + input_tokens = { + channel: tokens[:, : step + 1, i] + for i, channel in enumerate(self.channels) + } + + lprobs_dict, avg_attn_scores = self.model.forward_decoder( + input_tokens, + encoder_outs, + incremental_states, + self.temperature, + ) + + # Because the sizes of vocab is different, we cannot concat the lprobs to form a single tensor + if not self.duration_prediction: + lprobs_list = list(lprobs_dict.values()) + else: + lprobs_list = [ + net_output["pred_token"] for net_output in lprobs_dict.values() + ] + + # non-positive predicted durations + dur_preds = ( + torch.stack( + [ + net_output["pred_duration"] + for net_output in lprobs_dict.values() + ] + ) + .squeeze(-1) + .T + ) + dur_preds = dur_preds / self.duration_temperature + dur_preds = dur_preds.round().long() + dur_preds[dur_preds < 1] = 1 + + # dur_preds & dur_counter needs to be modified when there isn't an edge + if step > 0: + non_edge_indices = tokens[:, step, :] == tokens[:, step - 1, :] + if self.delayed_duration: + dur_preds[non_edge_indices] = 1 + else: + if dur_counter_jump_indices is not None: + dur_counter[dur_counter_jump_indices & non_edge_indices] = 2 + + # update dur_counter + if step > 0: + if self.delayed_duration: + dur_counter -= ( + (dur_counter == 1) + | (tokens[:, step, :] == tokens[:, step - 1, :]) + ).int() + dur_counter[dur_counter < 0] = 0 + else: + dur_counter -= ( + tokens[:, step, :] == tokens[:, step - 1, :] + ).int() + dur_counter[dur_counter < 1] = 1 + + # whether to copy previous token (ie. if the counter is still on) + # and get get the new duration + if self.delayed_duration: + dur_counter_jump_indices = dur_counter == 0 + dur_counter[dur_counter_jump_indices] = dur_preds[ + dur_counter_jump_indices + ] + + # whether to copy previous token in this step + copy_prev_token = dur_counter != 1 + if self.delayed_duration is False: + dur_counter_jump_indices = dur_counter == 1 + dur_counter[dur_counter_jump_indices] = dur_preds[ + dur_counter_jump_indices + ] + # else: + # dur_counter[dur_counter==0] = dur_preds[dur_counter==0] - 1 + # copy_prev_token = (dur_counter > 0) + + if self.lm_model is not None: + assert False, "Currently not supported in multichannelLM case" + + for i in range(self.n_channels): + lprobs_list[i][lprobs_list[i] != lprobs_list[i]] = torch.tensor( + -math.inf + ).to(lprobs_list[i]) + + lprobs_list[i][:, self.pad] = -math.inf # never select pad + lprobs_list[i][:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs_list[i][:, : self.eos] = -math.inf + lprobs_list[i][:, self.eos + 1 :] = -math.inf + else: + lprobs_list[i][ + :, self.eos + ] = -math.inf # quick fix for short generation + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + ( + lprobs_list[i], + tokens[..., i], + scores[..., i], + ) = self._prefix_tokens( + step, + lprobs_list[i], + scores[..., i], + tokens[..., i], + prefix_tokens[..., i], + beam_size, + ) + if self.duration_prediction: + # Can copy previous token if the prefix token is padding or unk (1-channel conditionned case) + can_copy_mask = ( + prefix_tokens[:, step, i].eq(self.pad) + | prefix_tokens[:, step, i].eq(self.unk) + ).repeat_interleave(beam_size) + copy_prev_token[:, i] &= can_copy_mask + elif step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs_list[i][:, self.eos] = -math.inf + + if self.duration_prediction: + if step < max_len: + for j in range(copy_prev_token.size(0)): + if copy_prev_token[j, i]: + prev_token = tokens[j, step, i] + lprobs_list[i][j, :prev_token] = -math.inf + lprobs_list[i][j, prev_token + 1 :] = -math.inf + # lprobs_list[i][j, prev_token] = 0. + # dur_counter[j,i] -= 1 + # else: + # prev_token = tokens[j, step, i] + # if not (lprobs_list[i][j,:].ne(-math.inf).nonzero() == prev_token).all(): + # lprobs_list[i][j, prev_token] = -math.inf + # dur_counter[j,i] = 0. + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs_list[0]) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + for i in range(self.n_channels): + lprobs_list[i] = self.repeat_ngram_blocker( + tokens, lprobs_list[i], bsz, beam_size, step + ) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + [ + lprobs_list[i].view(bsz, -1, self.vocab_sizes[i]) + for i in range(self.n_channels) + ], + scores.view(bsz, beam_size, -1, self.n_channels)[:, :, :step, :], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask = torch.any(eos_mask, dim=-1, keepdim=False) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.stack( + [ + torch.masked_select( + cand_scores[:, :beam_size, i], mask=eos_mask[:, :beam_size] + ) + for i in range(self.n_channels) + ], + dim=-1, + ) + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, -1, self.n_channels + ) + tokens = tokens.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, -1, self.n_channels + ) + if self.duration_prediction: + dur_counter = dur_counter.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, self.n_channels + ) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_bbsz_idx = active_bbsz_idx.view(-1) + + # active_scores = torch.stack([ + # torch.gather(cand_scores[...,0], dim=1, index=active_hypos) + # for i in range(self.n_channels) + # ], dim = -1) + # active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + for i in range(self.n_channels): + tokens.view(bsz, beam_size, -1, self.n_channels)[ + :, :, step + 1, i + ] = torch.gather(cand_indices[..., i], dim=1, index=active_hypos) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + for i in range(self.n_channels): + scores.view(bsz, beam_size, -1, self.n_channels)[ + :, :, step, i + ] = torch.gather(cand_scores[..., i], dim=1, index=active_hypos) + + if self.duration_prediction: + dur_counter = torch.index_select( + dur_counter, dim=0, index=active_bbsz_idx + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + # used for 1-channel generation, do not force the unk token (i.e. unk tokens are changed) + prefix_mask &= prefix_toks.ne(self.unk) + # zeroing the copying tokens + # if step > 0: + # copy_mask = (prefix_tokens[:, step] == prefix_tokens[:, step-1]).unsqueeze(-1).repeat(1, beam_size).view(-1) + # prefix_lprobs[copy_mask & prefix_mask] = 0. + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # shouldn't stop at unk token + unk_mask = prefix_toks.eq(self.unk) + if len(lprobs[unk_mask]) > 0: + # otherwise it won't assign to lprobs, + # see: https://discuss.pytorch.org/t/how-to-mask-and-assign-a-value-to-tensor/18437 + copy_lprobs = lprobs[unk_mask][:, :] + copy_lprobs[:, self.eos] = -math.inf + lprobs[unk_mask] = copy_lprobs + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.size(0) + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step, :] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + # The keys here are of the form "{sent}_{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # set() is not supported in script export + sents_seen: Dict[str, Optional[Tensor]] = {} + + # For every finished beam item + for i in range(bbsz_idx.size()[0]): + idx = bbsz_idx[i] + score = eos_scores[i].sum() + # sentence index in the current (possibly reduced) batch + unfin_idx = idx // beam_size + # sentence index in the original (unreduced) batch + sent = unfin_idx + cum_unfin[unfin_idx] + # Cannot create dict for key type '(int, int)' in torchscript. + # The workaround is to cast int to string + seen = str(sent.item()) + "_" + str(unfin_idx.item()) + if seen not in sents_seen: + sents_seen[seen] = None + + if self.match_source_len and step > src_lengths[unfin_idx]: + score = torch.tensor(-math.inf).to(score) + + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent].append( + { + "tokens": tokens_clone[i], + "score": score, + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + + for seen in sents_seen.keys(): + # check termination conditions for this sentence + sent: int = int(float(seen.split("_")[0])) + unfin_idx: int = int(float(seen.split("_")[1])) + + if not finished[sent] and self.is_finished( + step, unfin_idx, max_len, len(finalized[sent]), beam_size + ): + finished[sent] = True + newly_finished.append(unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. + """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class MultichannelEnsembleModel(nn.Module): + """A wrapper around an ensemble of SpeechDLM models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + if isinstance(models[0], SpeechDLM): + self.is_speech_dlm = True + # Otherwise it's a multi-channel language model (without cross-prediction outputs) + else: + self.is_speech_dlm = False + + if getattr(models[0].decoder.args, "duration_prediction", False): + self.is_duration_prediction = True + else: + self.is_duration_prediction = False + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([m.max_decoder_positions() for m in self.models]) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: Dict[str, float] = 1.0, + ): + if isinstance(temperature, (float, int)): + temperature = {channel: temperature for channel in tokens} + log_probs = {channel: [] for channel in tokens} + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + if self.is_speech_dlm: + if self.is_duration_prediction: + decoder_out_divided_by_temperature = { + channel_src: { + channel_pred: { + "pred_token": decoder_out[0][channel_src][channel_pred][ + "pred_token" + ][:, -1:, :].div_(temperature[channel_pred]), + "pred_duration": decoder_out[0][channel_src][ + channel_pred + ]["pred_duration"][:, -1:, :], + } + for channel_pred in decoder_out[0][channel_src] + } + for channel_src in decoder_out[0] + } + else: + decoder_out_divided_by_temperature = { + channel_src: { + channel_pred: decoder_out[0][channel_src][channel_pred][ + :, -1:, : + ].div_(temperature[channel_pred]) + for channel_pred in decoder_out[0][channel_src] + } + for channel_src in decoder_out[0] + } + else: + decoder_out_divided_by_temperature = { + channel: decoder_out[0][channel][:, -1:, :].div_( + temperature[channel] + ) + for channel in decoder_out[0] + } + decoder_out_tuple = ( + decoder_out_divided_by_temperature, + None if decoder_len <= 1 else decoder_out[1], + ) + + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + + if self.is_speech_dlm: + if self.is_duration_prediction: + probs = { + channel: { + "pred_token": probs[channel][channel]["pred_token"][ + :, -1, : + ], + "pred_duration": probs[channel][channel]["pred_duration"][ + :, -1, : + ], + } + for channel in probs + } + else: + probs = { + channel: probs[channel][channel][:, -1, :] for channel in probs + } + else: + probs = {channel: probs[channel][:, -1, :] for channel in probs} + if self.models_size == 1: + return probs, attn + + for channel in probs: + log_probs[channel].append(probs[channel]) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = {} + for channel in log_probs: + avg_probs[channel] = torch.logsumexp( + torch.stack(log_probs[channel], dim=0), dim=0 + ) - math.log(self.models_size) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) diff --git a/fairseq/models/speech_dlm/speech_dlm.py b/fairseq/models/speech_dlm/speech_dlm.py new file mode 100644 index 0000000000000000000000000000000000000000..dc13f565f147229ff7d01c590184b14d52f308e3 --- /dev/null +++ b/fairseq/models/speech_dlm/speech_dlm.py @@ -0,0 +1,280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Optional + +from fairseq import utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import Embedding +from .modules.speech_dlm_decoder import CrossChannelTransformerDecoder +from omegaconf import II + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechDLMConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + relu_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + decoder_embed_dim: int = field( + default=512, metadata={"help": "decoder embedding dimension"} + ) + decoder_output_dim: int = field( + default=512, metadata={"help": "decoder output dimension"} + ) + decoder_input_dim: int = field( + default=512, metadata={"help": "decoder input dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=2048, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) + decoder_cross_layers: int = field( + default=-1, metadata={"help": "num self cross attention decoder layers"} + ) + decoder_attention_heads: int = field( + default=8, metadata={"help": "num decoder attention heads"} + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, dont scale embeddings"} + ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) + offload_activations: bool = field( + default=False, + metadata={"help": "move checkpointed activations to CPU after they are used."}, + ) + quant_noise_pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + quant_noise_pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + # TODO common var add to parent + quant_noise_scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + add_bos_token: bool = II("task.add_bos_token") + tokens_per_sample: int = II("task.tokens_per_sample") + max_target_positions: Optional[int] = II("task.max_target_positions") + tpu: bool = II("common.tpu") + duration_prediction: str = II("task.duration_prediction") + delayed_duration_target: str = II("task.delayed_duration_target") + main_and_cross_weights: str = II("criterion.main_and_cross_weights") + + +@register_model("speech_dlm", dataclass=SpeechDLMConfig) +class SpeechDLM(FairseqLanguageModel): + """Spoken Unit-based Dialogue Language Model model (SpeechDLM) as described + in the paper: https://arxiv.org/pdf/2203.16502.pdf + """ + + def __init__(self, decoder): + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + # make sure all arguments are present in older models + base_lm_architecture(args) + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if args.decoder_cross_layers < 0: + args.decoder_cross_layers = args.decoder_layers + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + # Assert all dictionary to be the same + assert all( + task.source_dictionaries[channel] == task.source_dictionary + for channel in task.channels + ), "Source dictionaries of all channels are expected to be the same!!!" + assert all( + task.target_dictionaries[channel] == task.target_dictionary + for channel in task.channels + ), "Target dictionaries of all channels are expected to be the same!!!" + # Build the unit embeddings + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + decoder = CrossChannelTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + channels=task.channels, + no_encoder_attn=True, + ) + return cls(decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) + return embed_tokens + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + **kwargs, + ): + """ + Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model + file. Downloads and caches the pre-trained model file if needed. + + The base implementation returns a + :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to + generate translations or sample from language models. The underlying + :class:`~fairseq.models.FairseqModel` can be accessed via the + *generator.models* attribute. + + This function return a class:`MultichannelGeneratorHubInterface` object, + which allows generation in multiple channels with a multichannel model. + + Args: + model_name_or_path (str): either the name of a pre-trained model to + load or a path/URL to a pre-trained model state dict + checkpoint_file (str, optional): colon-separated list of checkpoint + files in the model archive to ensemble (default: 'model.pt') + data_name_or_path (str, optional): point args.data to the archive + at the given path/URL. Can start with '.' or './' to reuse the + model archive path. + """ + from fairseq import hub_utils + from .hub_interface import MultichannelGeneratorHubInterface + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + **kwargs, + ) + logger.info(x["args"]) + return MultichannelGeneratorHubInterface(x["args"], x["task"], x["models"]) + + @property + def supported_targets(self): + return {"next", "edge", "duration"} + + +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_cross_layers = getattr(args, "decoder_cross_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + args.add_bos_token = getattr(args, "add_bos_token", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + args.offload_activations = getattr(args, "offload_activations", False) + if args.offload_activations: + args.checkpoint_activations = True + + +@register_model_architecture("speech_dlm", "speech_dlm_big") +def speech_dlm_big(args): + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_cross_layers = getattr(args, "decoder_cross_layers", 12) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + base_lm_architecture(args) diff --git a/fairseq/models/speech_to_speech/__init__.py b/fairseq/models/speech_to_speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f29215c2fe6eedd203d105703ed94576c625ba86 --- /dev/null +++ b/fairseq/models/speech_to_speech/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .s2s_conformer import * # noqa +from .s2s_conformer_translatotron2 import * # noqa +from .s2s_conformer_unity import * # noqa +from .s2s_transformer import * # noqa diff --git a/fairseq/models/speech_to_speech/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_to_speech/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7da1dd9c4b677bf6e09b70dba2ad375607e9786a Binary files /dev/null and b/fairseq/models/speech_to_speech/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/__pycache__/s2s_conformer.cpython-310.pyc b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aecf623945db7d47f96776cefd6e38c92ca97beb Binary files /dev/null and b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_translatotron2.cpython-310.pyc b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_translatotron2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..139e74c132d495943f6914c4e3cb8c58d8f2e51c Binary files /dev/null and b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_translatotron2.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_unity.cpython-310.pyc b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_unity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..790a32045667e4f7df0989254a36efe36e95093d Binary files /dev/null and b/fairseq/models/speech_to_speech/__pycache__/s2s_conformer_unity.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/__pycache__/s2s_transformer.cpython-310.pyc b/fairseq/models/speech_to_speech/__pycache__/s2s_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76457aabd485d027e8a70ecd2aeeb586dd245d88 Binary files /dev/null and b/fairseq/models/speech_to_speech/__pycache__/s2s_transformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/__init__.py b/fairseq/models/speech_to_speech/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/models/speech_to_speech/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_to_speech/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..770e98644cea2ac76c33139eb2b5ceadef95a513 Binary files /dev/null and b/fairseq/models/speech_to_speech/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/__pycache__/ctc_decoder.cpython-310.pyc b/fairseq/models/speech_to_speech/modules/__pycache__/ctc_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c24901c3d224f28280bb44602b7e80501c80961d Binary files /dev/null and b/fairseq/models/speech_to_speech/modules/__pycache__/ctc_decoder.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/__pycache__/stacked_embedding.cpython-310.pyc b/fairseq/models/speech_to_speech/modules/__pycache__/stacked_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88eeb9795b051e2d96da2b2c1cfad651e2f1c5dd Binary files /dev/null and b/fairseq/models/speech_to_speech/modules/__pycache__/stacked_embedding.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/__pycache__/transformer_decoder_aug.cpython-310.pyc b/fairseq/models/speech_to_speech/modules/__pycache__/transformer_decoder_aug.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e8d1233dec168d5ed4a0c8b11c62a7d795955e8 Binary files /dev/null and b/fairseq/models/speech_to_speech/modules/__pycache__/transformer_decoder_aug.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/__pycache__/transformer_encoder.cpython-310.pyc b/fairseq/models/speech_to_speech/modules/__pycache__/transformer_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79070be583731b620ae6ffafb11f184baa329dbf Binary files /dev/null and b/fairseq/models/speech_to_speech/modules/__pycache__/transformer_encoder.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_speech/modules/ctc_decoder.py b/fairseq/models/speech_to_speech/modules/ctc_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..721efbf61ae335e13fbf40a2fecfb5bc97f9638a --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/ctc_decoder.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn + +from fairseq.models import FairseqEncoder + + +class CTCDecoder(FairseqEncoder): + def __init__(self, dictionary, in_dim): + super().__init__(dictionary) + self.proj = nn.Linear(in_dim, len(dictionary)) + + def forward(self, src_tokens, src_lengths=None, **kwargs): + encoder_out = self.proj(src_tokens) + return {"encoder_out": encoder_out} diff --git a/fairseq/models/speech_to_speech/modules/stacked_embedding.py b/fairseq/models/speech_to_speech/modules/stacked_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..5955a08538f54eaac871d222378a0ed86071581f --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/stacked_embedding.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + +from fairseq.models.transformer import Linear + + +class StackedEmbedding(nn.Embedding): + """Embedding module that supports stacked units -> single embedding""" + + def __init__(self, num_embeddings, embed_dim, padding_idx, num_stacked=1): + super().__init__(num_embeddings, embed_dim, padding_idx) + # follow transformer.Embedding + nn.init.normal_(self.weight, mean=0, std=embed_dim**-0.5) + nn.init.constant_(self.weight[padding_idx], 0) + + self.offset = ( + 4 # skip , , , , specific to fairseq dictionary + ) + self.vocab_size = num_embeddings - self.offset + self.num_stacked = num_stacked + + if self.num_stacked > 1: + self.project_in_dim = Linear(embed_dim * num_stacked, embed_dim, bias=False) + + def forward(self, input): + if self.num_stacked == 1: + return super().forward(input) + + # expand input indices + mask = input >= self.offset + stacked_input = [] + cum_input = input.new_zeros(input.shape) + for i in range(1, self.num_stacked + 1): + div = pow(self.vocab_size, i) + next_input = torch.remainder(input - self.offset - cum_input, div) + cum_input += next_input + next_input = torch.floor_divide(next_input, div // self.vocab_size) + stacked_input.append((next_input + self.offset) * mask + input * ~mask) + + stacked_input = torch.stack(stacked_input[::-1], dim=2) + embed = super().forward(stacked_input).view(input.size(0), input.size(1), -1) + embed = self.project_in_dim(embed) + return embed diff --git a/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..68f42c2b3633a330a605ed169fd527cded73d9f0 --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +from torch import Tensor + +from fairseq.models.transformer import Linear +from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder + + +class AugTransformerUnitDecoder(AugTransformerDecoder): + """Based on Transformer decoder, with support to decoding stacked units""" + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn, output_projection + ) + self.n_frames_per_step = args.n_frames_per_step + + self.out_proj_n_frames = ( + Linear( + self.output_embed_dim, + self.output_embed_dim * self.n_frames_per_step, + bias=False, + ) + if self.n_frames_per_step > 1 + else None + ) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out_aug=encoder_out_aug, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + bsz, seq_len, d = x.size() + if self.out_proj_n_frames: + x = self.out_proj_n_frames(x) + x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d)) + x = x.view(bsz, seq_len * self.n_frames_per_step, -1) + if ( + incremental_state is None and self.n_frames_per_step > 1 + ): # teacher-forcing mode in training + x = x[ + :, : -(self.n_frames_per_step - 1), : + ] # remove extra frames after + + return x, extra + + def upgrade_state_dict_named(self, state_dict, name): + if self.n_frames_per_step > 1: + move_keys = [ + ( + f"{name}.project_in_dim.weight", + f"{name}.embed_tokens.project_in_dim.weight", + ) + ] + for from_k, to_k in move_keys: + if from_k in state_dict and to_k not in state_dict: + state_dict[to_k] = state_dict[from_k] + del state_dict[from_k] diff --git a/fairseq/models/speech_to_speech/modules/transformer_encoder.py b/fairseq/models/speech_to_speech/modules/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fb1af433d81972d2190218c8b4fd18ac2e946150 --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/transformer_encoder.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from fairseq.models import FairseqEncoder +from fairseq.modules import LayerNorm, TransformerEncoderLayer + + +class TransformerEncoderNoEmb(FairseqEncoder): + """Transformer encoder without token embeddings.""" + + def __init__(self, args): + super().__init__(None) + + self.layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def forward(self, x, encoder_padding_mask, return_all_hiddens=False): + + encoder_states = [] + + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None and encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } diff --git a/fairseq/models/speech_to_speech/s2s_conformer.py b/fairseq/models/speech_to_speech/s2s_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..636396d536689772e44c29506e4a49b683562f37 --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_conformer.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path + +import torch + +from fairseq import checkpoint_utils +from fairseq.models import register_model, register_model_architecture +from fairseq.models.speech_to_speech.s2s_transformer import ( + S2SpecTTransformerModel, + S2UTTransformerModel, + s2spect_architecture_base, + s2ut_architecture_base, +) +from fairseq.models.speech_to_text import S2TConformerEncoder +from fairseq.models.transformer import Linear + +logger = logging.getLogger(__name__) + + +def build_s2s_conformer_encoder(args): + encoder = S2SConformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + +class S2SConformerEncoder(S2TConformerEncoder): + """Based on S2T transformer encoder, with support + to incorporate target speaker embedding.""" + + def __init__(self, args): + super().__init__(args) + + self.spk_emb_proj = None + if args.target_speaker_embed: + self.spk_emb_proj = Linear( + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + ) + + def forward( + self, src_tokens, src_lengths, tgt_speaker=None, return_all_hiddens=False + ): + out = super().forward(src_tokens, src_lengths, return_all_hiddens) + + if self.spk_emb_proj: + x = out["encoder_out"][0] + seq_len, bsz, _ = x.size() + tgt_speaker_emb = tgt_speaker.view(1, bsz, -1).expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, tgt_speaker_emb], dim=2)) + out["encoder_out"][0] = x + + return out + + +@register_model("s2ut_conformer") +class S2UTConformerModel(S2UTTransformerModel): + """ + Direct speech-to-speech translation model with Conformer encoder + Transformer discrete unit decoder + """ + + @staticmethod + def add_args(parser): + S2UTTransformerModel.add_args(parser) + parser.add_argument( + "--depthwise-conv-kernel-size", + type=int, + metavar="N", + help="kernel size of depthwise convolution layers", + ) + parser.add_argument( + "--attn-type", + type=str, + metavar="STR", + help="If not specified uses fairseq MHA. Other valid option is espnet for using conformer", + ) + parser.add_argument( + "--pos-enc-type", + type=str, + metavar="STR", + help="Must be specified in addition to attn-type=espnet for rel_pos and rope", + ) + + @classmethod + def build_encoder(cls, args): + return build_s2s_conformer_encoder(args) + + +@register_model("s2spect_conformer") +class S2SpecTConformerModel(S2SpecTTransformerModel): + """ + Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder + """ + + @staticmethod + def add_args(parser): + S2SpecTTransformerModel.add_args(parser) + parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) + parser.add_argument( + "--attn-type", + type=str, + default=None, + help="If not specified uses fairseq MHA. Other valid option is espnet for using conformer", + ) + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Must be specified in addition to attn-type=espnet for rel_pos and rope", + ) + + @classmethod + def build_encoder(cls, args): + return build_s2s_conformer_encoder(args) + + +@register_model_architecture("s2ut_conformer", "s2ut_conformer") +def s2ut_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.input_channels = getattr(args, "input_channels", 1) + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2ut_architecture_base(args) + + +@register_model_architecture("s2spect_conformer", "s2spect_conformer") +def s2spect_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.input_channels = getattr(args, "input_channels", 1) + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2spect_architecture_base(args) + + +@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher") +def s2spect_architecture_fisher(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + + # decoder + args.prenet_dim = getattr(args, "prenet_dim", 32) + + s2spect_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..8016daee8d62c1857307807cb925d4bb26aac6ec --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -0,0 +1,262 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging + +from fairseq.models import ( + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) +from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel +from fairseq.models.speech_to_speech.s2s_conformer_unity import ( + multitask_text_transformer_decoder_arch, +) +from fairseq.models.speech_to_speech.s2s_transformer import ( + base_multitask_text_transformer_decoder_arch, + s2spect_architecture_base, +) +from fairseq.models.text_to_speech import TTSTransformerDecoder +from fairseq.models.transformer import TransformerDecoder, TransformerModelBase + +logger = logging.getLogger(__name__) + + +@register_model("s2spect2_conformer") +class S2SpecT2ConformerModel(S2SpecTConformerModel): + """ + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder + """ + + @staticmethod + def add_args(parser): + S2SpecTConformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer", + default="transformer", + choices=["transformer"], + help="", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + + @classmethod + def build_multitask_decoder( + cls, + args, + tgt_dict, + in_dim, + is_mt_decoder, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + if is_mt_decoder: + multitask_text_transformer_decoder_arch( + decoder_args, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ) # 4L + else: + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_decoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_embed_dim = args.decoder_embed_dim + + if args.synthesizer == "transformer": + return TTSTransformerDecoder(_args, None, padding_idx=1) + else: + raise NotImplementedError(args.synthesizer) + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args) + base_model = cls(encoder, decoder) + + # set up multitask decoders + base_model.mt_task_name = None + base_model.multitask_decoders = {} + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True + base_model.mt_task_name = task_name + + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, + task_obj.target_dictionary, + in_dim, + task_obj.is_first_pass_decoder, + getattr(args, "translation_decoder_layers", 4), + getattr(args, "decoder_embed_dim", 256), + getattr(args, "decoder_attention_heads", 4), + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_text_encoder(args) + else: + base_model.synthesizer_encoder = None + + return base_model + + @classmethod + def build_text_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + tgt_speaker=None, + incremental_state=None, + target_lengths=None, + speaker=None, + return_all_hiddens=False, + ): + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + + # 1. MT decoder + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if self.synthesizer_encoder is not None: + tts_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + return_all_hiddens=return_all_hiddens, + ) + else: + tts_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. TTS decoder + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=tts_encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model_architecture( + model_name="s2spect2_conformer", arch_name="s2spect2_conformer" +) +def s2spect2_conformer_architecture_base(args): + args.conv_version = getattr(args, "conv_version", "convtransformer") + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2spect_architecture_base(args) + + +# for old naming +@register_model_architecture( + model_name="s2spect2_conformer", arch_name="s2spect_conformer_translatotron2" +) +def s2spect2_conformer_architecture_base_legacy(args): + s2spect2_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py new file mode 100644 index 0000000000000000000000000000000000000000..64388d6d1688c338bc0c11cc23ab985a0c652e36 --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging + +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding +from fairseq.models.speech_to_speech.modules.transformer_decoder_aug import ( + AugTransformerUnitDecoder, +) +from fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) +from fairseq.models.speech_to_speech.s2s_conformer import S2UTConformerModel +from fairseq.models.speech_to_speech.s2s_transformer import ( + TransformerUnitDecoder, + base_multitask_text_transformer_decoder_arch, + s2ut_architecture_base, +) +from fairseq.models.transformer import TransformerDecoder, TransformerModelBase + +logger = logging.getLogger(__name__) + + +def multitask_text_transformer_decoder_arch( + args, decoder_layers, decoder_embed_dim=256, decoder_attention_heads=4 +): + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_attention_heads = decoder_attention_heads + base_multitask_text_transformer_decoder_arch(args) + + +@register_model("unity_conformer") +class UnityConformerModel(S2UTConformerModel): + """ + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder + """ + + @staticmethod + def add_args(parser): + S2UTConformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer", + default="transformer", + choices=["transformer"], + help="", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + parser.add_argument( + "--synthesizer-augmented-cross-attention", + action="store_true", + default=False, + help="augmented cross-attention over speech encoder output", + ) + + @classmethod + def build_multitask_decoder( + cls, + args, + tgt_dict, + in_dim, + is_first_pass_decoder, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + if is_first_pass_decoder: + multitask_text_transformer_decoder_arch( + decoder_args, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ) # 4L + else: + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_decoder(cls, args, tgt_dict, aug_attn=False): + num_embeddings = len(tgt_dict) + padding_idx = tgt_dict.pad() + embed_tokens = StackedEmbedding( + num_embeddings, + args.decoder_embed_dim, + padding_idx, + num_stacked=args.n_frames_per_step, + ) + + _args = copy.deepcopy(args) + _args.encoder_embed_dim = args.decoder_embed_dim + + decoder_cls = AugTransformerUnitDecoder if aug_attn else TransformerUnitDecoder + return decoder_cls( + _args, + tgt_dict, + embed_tokens, + ) + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = cls.build_decoder( + args, + task.target_dictionary, + aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False), + ) + base_model = cls(encoder, decoder) + + base_model.t2u_augmented_cross_attn = getattr( + args, "synthesizer_augmented_cross_attention", False + ) + + # set up multitask decoders + base_model.mt_task_name = None + base_model.multitask_decoders = {} + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True + base_model.mt_task_name = task_name + + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, + task_obj.target_dictionary, + in_dim, + task_obj.is_first_pass_decoder, + getattr(args, "translation_decoder_layers", 4), + getattr(args, "decoder_embed_dim", 256), + getattr(args, "decoder_attention_heads", 4), + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_text_encoder(args) + else: + base_model.synthesizer_encoder = None + + return base_model + + @classmethod + def build_text_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + tgt_speaker=None, + return_all_hiddens=False, + ): + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + + # 1. MT decoder + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. T2U encoder + if self.synthesizer_encoder is not None: + t2u_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + return_all_hiddens=return_all_hiddens, + ) + else: + t2u_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. T2U decoder + if self.t2u_augmented_cross_attn: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out_aug=t2u_encoder_out, + ) + else: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=t2u_encoder_out, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer") +def unity_conformer_architecture_base(args): + args.conv_version = getattr(args, "conv_version", "convtransformer") + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2ut_architecture_base(args) + + +# for old naming +@register_model_architecture( + model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2" +) +def unity_conformer_architecture_base_legacy(args): + unity_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..07393d2598c0a38ec98bc2b66abd7dfa8af18825 --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -0,0 +1,722 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from torch import Tensor + +from fairseq import checkpoint_utils, utils +from fairseq.models import ( + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding +from fairseq.models.speech_to_text import S2TTransformerEncoder +from fairseq.models.text_to_speech import TTSTransformerDecoder +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase + +logger = logging.getLogger(__name__) + + +class S2STransformerEncoder(S2TTransformerEncoder): + """Based on S2T transformer encoder, with support + to incorporate target speaker embedding.""" + + def __init__(self, args): + super().__init__(args) + + self.spk_emb_proj = None + if args.target_speaker_embed: + self.spk_emb_proj = Linear( + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + ) + + def forward( + self, src_tokens, src_lengths, tgt_speaker=None, return_all_hiddens=False + ): + out = super().forward(src_tokens, src_lengths, return_all_hiddens) + + if self.spk_emb_proj: + x = out["encoder_out"][0] + seq_len, bsz, _ = x.size() + tgt_speaker_emb = tgt_speaker.view(1, bsz, -1).expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, tgt_speaker_emb], dim=2)) + out["encoder_out"][0] = x + + return out + + +class TransformerUnitDecoder(TransformerDecoder): + """Based on Transformer decoder, with support to decoding stacked units""" + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn, output_projection + ) + self.n_frames_per_step = args.n_frames_per_step + + self.out_proj_n_frames = ( + Linear( + self.output_embed_dim, + self.output_embed_dim * self.n_frames_per_step, + bias=False, + ) + if self.n_frames_per_step > 1 + else None + ) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + bsz, seq_len, d = x.size() + if self.out_proj_n_frames: + x = self.out_proj_n_frames(x) + x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d)) + x = x.view(bsz, seq_len * self.n_frames_per_step, -1) + if ( + incremental_state is None and self.n_frames_per_step > 1 + ): # teacher-forcing mode in training + x = x[ + :, : -(self.n_frames_per_step - 1), : + ] # remove extra frames after + + return x, extra + + def upgrade_state_dict_named(self, state_dict, name): + if self.n_frames_per_step > 1: + move_keys = [ + ( + f"{name}.project_in_dim.weight", + f"{name}.embed_tokens.project_in_dim.weight", + ) + ] + for from_k, to_k in move_keys: + if from_k in state_dict and to_k not in state_dict: + state_dict[to_k] = state_dict[from_k] + del state_dict[from_k] + + +class S2STransformerMultitaskModelBase(FairseqEncoderDecoderModel): + @classmethod + def build_encoder(cls, args): + encoder = S2STransformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + @classmethod + def build_multitask_decoder(cls, args, tgt_dict, in_dim): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + base_multitask_text_transformer_decoder_arch(decoder_args) + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = ( + cls.build_decoder(args, task.target_dictionary) + if task.args.target_is_code + else cls.build_decoder(args) + ) + base_model = cls(encoder, decoder) + + # set up multitask decoders + base_model.multitask_decoders = {} + for task_name, task_obj in task.multitask_tasks.items(): + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, task_obj.target_dictionary, in_dim + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + return base_model + + def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): + return self.encoder( + src_tokens, src_lengths=src_lengths, tgt_speaker=speaker, **kwargs + ) + + +@register_model("s2ut_transformer") +class S2UTTransformerModel(S2STransformerMultitaskModelBase): + """ + Direct speech-to-speech translation model with Transformer encoder + Transformer discrete unit decoder + https://arxiv.org/abs/2107.05604 + """ + + @staticmethod + def add_args(parser): + # input + parser.add_argument( + "--conv-kernel-sizes", + type=str, + metavar="STR", + help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-channels", + type=int, + metavar="N", + help="# of channels in Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="N", + help="# of channels in Conv2d (convtransformer) subsampling layers", + ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) + # Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--encoder-freezing-updates", + type=int, + metavar="N", + help="freeze encoder for first N updates", + ) + # speaker + parser.add_argument( + "--speaker-embed-dim", + type=int, + metavar="N", + help="speaker embedding dimension", + ) + + @classmethod + def build_decoder(cls, args, tgt_dict): + num_embeddings = len(tgt_dict) + padding_idx = tgt_dict.pad() + embed_tokens = StackedEmbedding( + num_embeddings, + args.decoder_embed_dim, + padding_idx, + num_stacked=args.n_frames_per_step, + ) + + return TransformerUnitDecoder( + args, + tgt_dict, + embed_tokens, + ) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + tgt_speaker=None, + return_all_hiddens=False, + ): + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + return decoder_out + + +@register_model("s2spect_transformer") +class S2SpecTTransformerModel(S2STransformerMultitaskModelBase): + """ + Speech-to-spectrogram model with S2T Transformer encoder + TTS Transformer decoder + """ + + @staticmethod + def add_args(parser): + # input + parser.add_argument( + "--conv-kernel-sizes", + type=str, + metavar="STR", + help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-channels", + type=int, + metavar="N", + help="# of channels in Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) + # Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--encoder-freezing-updates", + type=int, + metavar="N", + help="freeze encoder for first N updates", + ) + # speaker + parser.add_argument( + "--speaker-embed-dim", + type=int, + metavar="N", + help="speaker embedding dimension", + ) + # decoder + parser.add_argument("--output-frame-dim", type=int) + # decoder prenet + parser.add_argument("--prenet-dropout", type=float) + parser.add_argument("--prenet-layers", type=int) + parser.add_argument("--prenet-dim", type=int) + # decoder postnet + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + # decoder transformer layers + parser.add_argument("--decoder-transformer-layers", type=int) + parser.add_argument("--decoder-embed-dim", type=int) + parser.add_argument("--decoder-ffn-embed-dim", type=int) + parser.add_argument("--decoder-normalize-before", action="store_true") + parser.add_argument("--decoder-attention-heads", type=int) + + @classmethod + def build_decoder(cls, args): + return TTSTransformerDecoder(args, None, padding_idx=1) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + tgt_speaker=None, + incremental_state=None, + target_lengths=None, + speaker=None, + return_all_hiddens=False, + ): + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + return decoder_out + + +def base_multitask_text_transformer_decoder_arch(args): + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", True + ) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + + args.adaptive_input = getattr(args, "adaptive_input", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_layers = getattr(args, "decoder_layers", 2) + + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + + # decoder layer + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) + + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + + +def base_s2st_transformer_encoder_architecture(args): + args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) + + # Convolutional subsampler + args.input_channels = getattr(args, "input_channels", 1) + args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") # for Conv1d + args.conv_channels = getattr(args, "conv_channels", 1024) # for Conv1d + args.conv_out_channels = getattr(args, "conv_out_channels", 256) # for Conv2d + args.conv_version = getattr(args, "conv_version", "s2t_transformer") + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") + + args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 256) + + +@register_model_architecture( + model_name="s2ut_transformer", arch_name="s2ut_transformer" +) +def s2ut_architecture_base(args): + base_s2st_transformer_encoder_architecture(args) + + # decoder + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + +@register_model_architecture("s2ut_transformer", "s2ut_transformer_fisher") +def s2ut_architecture_fisher(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + + s2ut_architecture_base(args) + + +@register_model_architecture( + model_name="s2spect_transformer", arch_name="s2spect_transformer" +) +def s2spect_architecture_base(args): + base_s2st_transformer_encoder_architecture(args) + + # decoder + args.output_frame_dim = getattr(args, "output_frame_dim", 80) + # decoder prenet + args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) + args.prenet_layers = getattr(args, "prenet_layers", 2) + args.prenet_dim = getattr(args, "prenet_dim", 256) + # decoder postnet + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) + # decoder transformer layers + args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim + ) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + + +@register_model_architecture("s2spect_transformer", "s2spect_transformer_fisher") +def s2spect_architecture_fisher(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + + # decoder + args.prenet_dim = getattr(args, "prenet_dim", 32) + + s2spect_architecture_base(args) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62ef663efba1b6d745169bddbc5ba98adadbc1da --- /dev/null +++ b/fairseq/models/speech_to_text/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .berard import * # noqa +from .convtransformer import * # noqa +from .multi_modality_model import * # noqa +from .s2t_conformer import * # noqa +from .s2t_transformer import * # noqa +from .s2t_wav_transformer import * # noqa +from .xm_transformer import * # noqa +from .xm_transformer_unity import * # noqa diff --git a/fairseq/models/speech_to_text/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1025353ecd047c660f4bbf665db404aebf188086 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/berard.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/berard.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59924b48d320b837ff00c86bd69733f5e57b8e87 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/berard.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/convtransformer.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/convtransformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88ee5732ffa24587fdf0a3217ad2495453646e01 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/convtransformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7028f079e18690f241ea1249965f4b53b8d4f7f2 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/multi_modality_model.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/multi_modality_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..923d4e2c6b0c24a584eca7b8665ef556d0122748 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/multi_modality_model.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/s2t_conformer.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/s2t_conformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ca24fcb6a9829ceed55715a9ea715140f951999 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/s2t_conformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/s2t_transformer.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/s2t_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bfcdfe63e4dd7bb399217b15eaaee5e745f25e2 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/s2t_transformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/s2t_wav_transformer.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/s2t_wav_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92c1c022046435467337feb762ff6e1175b6b524 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/s2t_wav_transformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/utils.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..460fe312384cdc9e647b75a198d5cc051ddb9767 Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/xm_transformer.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/xm_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0824e0bb913e13f45db08eb47cfc37c26f9b57af Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/xm_transformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/__pycache__/xm_transformer_unity.cpython-310.pyc b/fairseq/models/speech_to_text/__pycache__/xm_transformer_unity.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..facdf17317d42e931d1460e17449f295dc431c9c Binary files /dev/null and b/fairseq/models/speech_to_text/__pycache__/xm_transformer_unity.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/berard.py b/fairseq/models/speech_to_text/berard.py new file mode 100644 index 0000000000000000000000000000000000000000..107ac983c62d721eae0a00b633ee350c9f1673da --- /dev/null +++ b/fairseq/models/speech_to_text/berard.py @@ -0,0 +1,607 @@ +#!/usr/bin/env python3 + +from ast import literal_eval +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) + + +@register_model("s2t_berard") +class BerardModel(FairseqEncoderDecoderModel): + """Implementation of a model similar to https://arxiv.org/abs/1802.04200 + + Paper title: End-to-End Automatic Speech Translation of Audiobooks + An implementation is available in tensorflow at + https://github.com/eske/seq2seq + Relevant files in this implementation are the config + (https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml) + and the model code + (https://github.com/eske/seq2seq/blob/master/translate/models.py). + The encoder and decoder try to be close to the original implementation. + The attention is an MLP as in Bahdanau et al. + (https://arxiv.org/abs/1409.0473). + There is no state initialization by averaging the encoder outputs. + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + parser.add_argument( + "--input-layers", + type=str, + metavar="EXPR", + help="List of linear layer dimensions. These " + "layers are applied to the input features and " + "are followed by tanh and possibly dropout.", + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="Dropout probability to use in the encoder/decoder. " + "Note that this parameters control dropout in various places, " + "there is no fine-grained control for dropout for embeddings " + "vs LSTM layers for example.", + ) + parser.add_argument( + "--in-channels", + type=int, + metavar="N", + help="Number of encoder input channels. " "Typically value is 1.", + ) + parser.add_argument( + "--conv-layers", + type=str, + metavar="EXPR", + help="List of conv layers " "(format: (channels, kernel, stride)).", + ) + parser.add_argument( + "--num-blstm-layers", + type=int, + metavar="N", + help="Number of encoder bi-LSTM layers.", + ) + parser.add_argument( + "--lstm-size", type=int, metavar="N", help="LSTM hidden size." + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="Embedding dimension of the decoder target tokens.", + ) + parser.add_argument( + "--decoder-hidden-dim", + type=int, + metavar="N", + help="Decoder LSTM hidden dimension.", + ) + parser.add_argument( + "--decoder-num-layers", + type=int, + metavar="N", + help="Number of decoder LSTM layers.", + ) + parser.add_argument( + "--attention-dim", + type=int, + metavar="N", + help="Hidden layer dimension in MLP attention.", + ) + parser.add_argument( + "--output-layer-dim", + type=int, + metavar="N", + help="Hidden layer dim for linear layer prior to output projection.", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + + @classmethod + def build_encoder(cls, args, task): + encoder = BerardEncoder( + input_layers=literal_eval(args.input_layers), + conv_layers=literal_eval(args.conv_layers), + in_channels=args.input_channels, + input_feat_per_channel=args.input_feat_per_channel, + num_blstm_layers=args.num_blstm_layers, + lstm_size=args.lstm_size, + dropout=args.dropout, + ) + if getattr(args, "load_pretrained_encoder_from", None) is not None: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task): + decoder = LSTMDecoder( + dictionary=task.target_dictionary, + embed_dim=args.decoder_embed_dim, + num_layers=args.decoder_num_layers, + hidden_size=args.decoder_hidden_dim, + dropout=args.dropout, + encoder_output_dim=2 * args.lstm_size, # bidirectional + attention_dim=args.attention_dim, + output_layer_dim=args.output_layer_dim, + ) + if getattr(args, "load_pretrained_decoder_from", None) is not None: + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + encoder = cls.build_encoder(args, task) + decoder = cls.build_decoder(args, task) + + return cls(encoder, decoder) + + def get_normalized_probs(self, net_output, log_probs, sample=None): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = super().get_normalized_probs(net_output, log_probs, sample) + # lprobs is a (B, T, D) tensor + lprobs.batch_first = True + return lprobs + + +class BerardEncoder(FairseqEncoder): + def __init__( + self, + input_layers: List[int], + conv_layers: List[Tuple[int]], + in_channels: int, + input_feat_per_channel: int, + num_blstm_layers: int, + lstm_size: int, + dropout: float, + ): + """ + Args: + input_layers: list of linear layer dimensions. These layers are + applied to the input features and are followed by tanh and + possibly dropout. + conv_layers: list of conv2d layer configurations. A configuration is + a tuple (out_channels, conv_kernel_size, stride). + in_channels: number of input channels. + input_feat_per_channel: number of input features per channel. These + are speech features, typically 40 or 80. + num_blstm_layers: number of bidirectional LSTM layers. + lstm_size: size of the LSTM hidden (and cell) size. + dropout: dropout probability. Dropout can be applied after the + linear layers and LSTM layers but not to the convolutional + layers. + """ + super().__init__(None) + + self.input_layers = nn.ModuleList() + in_features = input_feat_per_channel + for out_features in input_layers: + if dropout > 0: + self.input_layers.append( + nn.Sequential( + nn.Linear(in_features, out_features), nn.Dropout(p=dropout) + ) + ) + else: + self.input_layers.append(nn.Linear(in_features, out_features)) + in_features = out_features + + self.in_channels = in_channels + self.input_dim = input_feat_per_channel + self.conv_kernel_sizes_and_strides = [] + self.conv_layers = nn.ModuleList() + lstm_input_dim = input_layers[-1] + for conv_layer in conv_layers: + out_channels, conv_kernel_size, conv_stride = conv_layer + self.conv_layers.append( + nn.Conv2d( + in_channels, + out_channels, + conv_kernel_size, + stride=conv_stride, + padding=conv_kernel_size // 2, + ) + ) + self.conv_kernel_sizes_and_strides.append((conv_kernel_size, conv_stride)) + in_channels = out_channels + lstm_input_dim //= conv_stride + + lstm_input_dim *= conv_layers[-1][0] + self.lstm_size = lstm_size + self.num_blstm_layers = num_blstm_layers + self.lstm = nn.LSTM( + input_size=lstm_input_dim, + hidden_size=lstm_size, + num_layers=num_blstm_layers, + dropout=dropout, + bidirectional=True, + ) + self.output_dim = 2 * lstm_size # bidirectional + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = None + + def forward(self, src_tokens, src_lengths=None, **kwargs): + """ + Args + src_tokens: padded tensor (B, T, C * feat) + src_lengths: tensor of original lengths of input utterances (B,) + """ + bsz, max_seq_len, _ = src_tokens.size() + # (B, C, T, feat) + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + + for input_layer in self.input_layers: + x = input_layer(x) + x = torch.tanh(x) + + for conv_layer in self.conv_layers: + x = conv_layer(x) + + bsz, _, output_seq_len, _ = x.size() + + # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> + # (T, B, C * feat) + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + + input_lengths = src_lengths.clone() + for k, s in self.conv_kernel_sizes_and_strides: + p = k // 2 + input_lengths = (input_lengths.float() + 2 * p - k) / s + 1 + input_lengths = input_lengths.floor().long() + + packed_x = nn.utils.rnn.pack_padded_sequence(x, input_lengths) + + h0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_() + c0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_() + packed_outs, _ = self.lstm(packed_x, (h0, c0)) + + # unpack outputs and apply dropout + x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs) + if self.dropout is not None: + x = self.dropout(x) + + encoder_padding_mask = ( + lengths_to_padding_mask(output_lengths).to(src_tokens.device).t() + ) + + return { + "encoder_out": x, # (T, B, C) + "encoder_padding_mask": encoder_padding_mask, # (T, B) + } + + def reorder_encoder_out(self, encoder_out, new_order): + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(1, new_order) + return encoder_out + + +class MLPAttention(nn.Module): + """The original attention from Badhanau et al. (2014) + + https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron. + The attention score between position i in the encoder and position j in the + decoder is: alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a) + """ + + def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim): + super().__init__() + + self.context_dim = context_dim + self.attention_dim = attention_dim + # W_ae and b_a + self.encoder_proj = nn.Linear(context_dim, self.attention_dim, bias=True) + # W_ad + self.decoder_proj = nn.Linear( + decoder_hidden_state_dim, self.attention_dim, bias=False + ) + # V_a + self.to_scores = nn.Linear(self.attention_dim, 1, bias=False) + + def forward(self, decoder_state, source_hids, encoder_padding_mask): + """The expected input dimensions are: + decoder_state: bsz x decoder_hidden_state_dim + source_hids: src_len x bsz x context_dim + encoder_padding_mask: src_len x bsz + """ + src_len, bsz, _ = source_hids.size() + # (src_len*bsz) x context_dim (to feed through linear) + flat_source_hids = source_hids.view(-1, self.context_dim) + # (src_len*bsz) x attention_dim + encoder_component = self.encoder_proj(flat_source_hids) + # src_len x bsz x attention_dim + encoder_component = encoder_component.view(src_len, bsz, self.attention_dim) + # 1 x bsz x attention_dim + decoder_component = self.decoder_proj(decoder_state).unsqueeze(0) + # Sum with broadcasting and apply the non linearity + # src_len x bsz x attention_dim + hidden_att = torch.tanh( + (decoder_component + encoder_component).view(-1, self.attention_dim) + ) + # Project onto the reals to get attentions scores (src_len x bsz) + attn_scores = self.to_scores(hidden_att).view(src_len, bsz) + + # Mask + softmax (src_len x bsz) + if encoder_padding_mask is not None: + attn_scores = ( + attn_scores.float() + .masked_fill_(encoder_padding_mask, float("-inf")) + .type_as(attn_scores) + ) # FP16 support: cast to float and back + # srclen x bsz + normalized_masked_attn_scores = F.softmax(attn_scores, dim=0) + + # Sum weighted sources (bsz x context_dim) + attn_weighted_context = ( + source_hids * normalized_masked_attn_scores.unsqueeze(2) + ).sum(dim=0) + + return attn_weighted_context, normalized_masked_attn_scores + + +class LSTMDecoder(FairseqIncrementalDecoder): + def __init__( + self, + dictionary, + embed_dim, + num_layers, + hidden_size, + dropout, + encoder_output_dim, + attention_dim, + output_layer_dim, + ): + """ + Args: + dictionary: target text dictionary. + embed_dim: embedding dimension for target tokens. + num_layers: number of LSTM layers. + hidden_size: hidden size for LSTM layers. + dropout: dropout probability. Dropout can be applied to the + embeddings, the LSTM layers, and the context vector. + encoder_output_dim: encoder output dimension (hidden size of + encoder LSTM). + attention_dim: attention dimension for MLP attention. + output_layer_dim: size of the linear layer prior to output + projection. + """ + super().__init__(dictionary) + self.num_layers = num_layers + self.hidden_size = hidden_size + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx) + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for layer_id in range(num_layers): + input_size = embed_dim if layer_id == 0 else encoder_output_dim + self.layers.append( + nn.LSTMCell(input_size=input_size, hidden_size=hidden_size) + ) + + self.context_dim = encoder_output_dim + self.attention = MLPAttention( + decoder_hidden_state_dim=hidden_size, + context_dim=encoder_output_dim, + attention_dim=attention_dim, + ) + + self.deep_output_layer = nn.Linear( + hidden_size + encoder_output_dim + embed_dim, output_layer_dim + ) + self.output_projection = nn.Linear(output_layer_dim, num_embeddings) + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs + ): + encoder_padding_mask = encoder_out["encoder_padding_mask"] + encoder_outs = encoder_out["encoder_out"] + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + bsz, seqlen = prev_output_tokens.size() + + srclen = encoder_outs.size(0) + + # embed tokens + embeddings = self.embed_tokens(prev_output_tokens) + x = embeddings + if self.dropout is not None: + x = self.dropout(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # initialize previous states (or get from cache during incremental + # generation) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is not None: + prev_hiddens, prev_cells = cached_state + else: + prev_hiddens = [encoder_out["encoder_out"].mean(dim=0)] * self.num_layers + prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers + + attn_scores = x.new_zeros(bsz, srclen) + attention_outs = [] + outs = [] + for j in range(seqlen): + input = x[j, :, :] + attention_out = None + for i, layer in enumerate(self.layers): + # the previous state is one layer below except for the bottom + # layer where the previous state is the state emitted by the + # top layer + hidden, cell = layer( + input, + ( + prev_hiddens[(i - 1) % self.num_layers], + prev_cells[(i - 1) % self.num_layers], + ), + ) + if self.dropout is not None: + hidden = self.dropout(hidden) + prev_hiddens[i] = hidden + prev_cells[i] = cell + if attention_out is None: + attention_out, attn_scores = self.attention( + hidden, encoder_outs, encoder_padding_mask + ) + if self.dropout is not None: + attention_out = self.dropout(attention_out) + attention_outs.append(attention_out) + input = attention_out + + # collect the output of the top layer + outs.append(hidden) + + # cache previous states (no-op except during incremental generation) + utils.set_incremental_state( + self, incremental_state, "cached_state", (prev_hiddens, prev_cells) + ) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) + attention_outs_concat = torch.cat(attention_outs, dim=0).view( + seqlen, bsz, self.context_dim + ) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + attention_outs_concat = attention_outs_concat.transpose(0, 1) + + # concat LSTM output, attention output and embedding + # before output projection + x = torch.cat((x, attention_outs_concat, embeddings), dim=2) + x = self.deep_output_layer(x) + x = torch.tanh(x) + if self.dropout is not None: + x = self.dropout(x) + # project back to size of vocabulary + x = self.output_projection(x) + + # to return the full attn_scores tensor, we need to fix the decoder + # to account for subsampling input frames + # return x, attn_scores + return x, None + + def reorder_incremental_state(self, incremental_state, new_order): + super().reorder_incremental_state(incremental_state, new_order) + cached_state = utils.get_incremental_state( + self, incremental_state, "cached_state" + ) + if cached_state is None: + return + + def reorder_state(state): + if isinstance(state, list): + return [reorder_state(state_i) for state_i in state] + return state.index_select(0, new_order) + + new_state = tuple(map(reorder_state, cached_state)) + utils.set_incremental_state(self, incremental_state, "cached_state", new_state) + + +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard") +def berard(args): + """The original version: "End-to-End Automatic Speech Translation of + Audiobooks" (https://arxiv.org/abs/1802.04200) + """ + args.input_layers = getattr(args, "input_layers", "[256, 128]") + args.conv_layers = getattr(args, "conv_layers", "[(16, 3, 2), (16, 3, 2)]") + args.num_blstm_layers = getattr(args, "num_blstm_layers", 3) + args.lstm_size = getattr(args, "lstm_size", 256) + args.dropout = getattr(args, "dropout", 0.2) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 2) + args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 512) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 128) + args.load_pretrained_encoder_from = getattr( + args, "load_pretrained_encoder_from", None + ) + args.load_pretrained_decoder_from = getattr( + args, "load_pretrained_decoder_from", None + ) + + +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_256_3_3") +def berard_256_3_3(args): + """Used in + * "Harnessing Indirect Training Data for End-to-End Automatic Speech + Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515) + * "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus" + (https://arxiv.org/pdf/2002.01320.pdf) + * "Self-Supervised Representations Improve End-to-End Speech Translation" + (https://arxiv.org/abs/2006.12124) + """ + args.decoder_num_layers = getattr(args, "decoder_num_layers", 3) + berard(args) + + +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_3_2") +def berard_512_3_2(args): + args.num_blstm_layers = getattr(args, "num_blstm_layers", 3) + args.lstm_size = getattr(args, "lstm_size", 512) + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 2) + args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 256) + berard(args) + + +@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard_512_5_3") +def berard_512_5_3(args): + args.num_blstm_layers = getattr(args, "num_blstm_layers", 5) + args.lstm_size = getattr(args, "lstm_size", 512) + args.dropout = getattr(args, "dropout", 0.3) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_num_layers = getattr(args, "decoder_num_layers", 3) + args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024) + args.attention_dim = getattr(args, "attention_dim", 512) + args.output_layer_dim = getattr(args, "output_layer_dim", 256) + berard(args) diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d0fc02aee908ad50349946d7fbbfec563ec07e8 --- /dev/null +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -0,0 +1,443 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_text.modules.convolution import infer_conv_output_dim +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer + +logger = logging.getLogger(__name__) + + +@register_model("convtransformer") +class ConvTransformerModel(FairseqEncoderDecoderModel): + """ + Transformer-based Speech translation model from ESPNet-ST + https://arxiv.org/abs/2004.10234 + """ + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="encoder input dimension per input channel", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--decoder-output-dim", + type=int, + metavar="N", + help="decoder output dimension (extra linear layer if different from decoder embed dim)", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="INT", + help="the number of output channels of conv layer", + ) + + @classmethod + def build_encoder(cls, args): + encoder = ConvTransformerEncoder(args) + if getattr(args, "load_pretrained_encoder_from", None) is not None: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=args.load_pretrained_encoder_from + ) + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + decoder = TransformerDecoderNoExtra(args, task.target_dictionary, embed_tokens) + if getattr(args, "load_pretrained_decoder_from", None) is not None: + decoder = checkpoint_utils.load_pretrained_component_from_model( + component=decoder, checkpoint=args.load_pretrained_decoder_from + ) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + @staticmethod + @torch.jit.unused + def set_batch_first(lprobs): + lprobs.batch_first = True + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + if self.training: + self.set_batch_first(lprobs) + return lprobs + + def output_layout(self): + return "BTD" + + """ + The forward method inherited from the base class has a **kwargs argument in + its input, which is not supported in torchscript. This method overrites the forward + method definition without **kwargs. + """ + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + return decoder_out + + +class ConvTransformerEncoder(FairseqEncoder): + """Conv + Transformer encoder""" + + def __init__(self, args): + """Construct an Encoder object.""" + super().__init__(None) + + self.dropout = args.dropout + self.embed_scale = ( + 1.0 if args.no_scale_embedding else math.sqrt(args.encoder_embed_dim) + ) + self.padding_idx = 1 + self.in_channels = 1 + self.input_dim = args.input_feat_per_channel + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, args.conv_out_channels, 3, stride=2, padding=3 // 2), + torch.nn.ReLU(), + torch.nn.Conv2d( + args.conv_out_channels, + args.conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = infer_conv_output_dim( + self.in_channels, self.input_dim, args.conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, + args.encoder_embed_dim, + self.padding_idx, + learned=False, + ) + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def pooling_ratio(self): + return 4 + + def forward(self, src_tokens, src_lengths): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5) + input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long() + input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to( + input_len_0.device + ) + input_lengths = torch.min(input_len_0, input_len_1) + + encoder_padding_mask = lengths_to_padding_mask(input_lengths) + + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + + if not encoder_padding_mask.any(): + maybe_encoder_padding_mask = None + else: + maybe_encoder_padding_mask = encoder_padding_mask + + return { + "encoder_out": [x], + "encoder_padding_mask": [maybe_encoder_padding_mask] + if maybe_encoder_padding_mask is not None + else [], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + (encoder_out["encoder_padding_mask"][0]).index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + (encoder_out["encoder_embedding"][0]).index_select(0, new_order) + ] + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, + "encoder_padding_mask": new_encoder_padding_mask, + "encoder_embedding": new_encoder_embedding, + "encoder_states": encoder_states, + "src_tokens": [], + "src_lengths": [], + } + + +class TransformerDecoderNoExtra(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + return x, None + + +@register_model_architecture(model_name="convtransformer", arch_name="convtransformer") +def base_architecture(args): + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.max_source_positions = getattr(args, "max_source_positions", 3000) + args.max_target_positions = getattr(args, "max_target_positions", 1024) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.conv_out_channels = getattr(args, "conv_out_channels", args.encoder_embed_dim) + + +@register_model_architecture("convtransformer", "convtransformer_espnet") +def convtransformer_espnet(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/speech_to_text/hub_interface.py b/fairseq/models/speech_to_text/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d78427f68724c2262c61d48b31194d97e38773da --- /dev/null +++ b/fairseq/models/speech_to_text/hub_interface.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from argparse import Namespace +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import fairseq.data.audio.feature_transforms.utterance_cmvn as utt_cmvn +from fairseq.data import encoders +from fairseq.data.audio.audio_utils import convert_waveform as convert_wav +from fairseq.data.audio.audio_utils import get_fbank +from fairseq.data.audio.audio_utils import get_waveform as get_wav +from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset + +logger = logging.getLogger(__name__) + + +class S2THubInterface(nn.Module): + def __init__(self, cfg, task, model): + super().__init__() + self.cfg = cfg + self.task = task + self.model = model + self.model.eval() + self.generator = self.task.build_generator([self.model], self.cfg.generation) + + @classmethod + def get_model_input(cls, task, audio: Union[str, torch.Tensor]): + input_type = task.data_cfg.hub.get("input_type", "fbank80") + if input_type == "fbank80_w_utt_cmvn": + if isinstance(audio, str): + feat = utt_cmvn.UtteranceCMVN()(get_fbank(audio)) + feat = feat.unsqueeze(0) # T x D -> 1 x T x D + else: + import torchaudio.compliance.kaldi as kaldi + + feat = kaldi.fbank(audio, num_mel_bins=80).numpy() # 1 x T x D + elif input_type in {"waveform", "standardized_waveform"}: + if isinstance(audio, str): + feat, sr = get_wav(audio) # C x T + feat, _ = convert_wav( + feat, sr, to_sample_rate=16_000, to_mono=True + ) # C x T -> 1 x T + else: + feat = audio.numpy() + else: + raise ValueError(f"Unknown value: input_type = {input_type}") + + src_lengths = torch.Tensor([feat.shape[1]]).long() + src_tokens = torch.from_numpy(feat) # 1 x T (x D) + if input_type == "standardized_waveform": + with torch.no_grad(): + src_tokens = F.layer_norm(src_tokens, src_tokens.shape) + + return { + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "prev_output_tokens": None, + }, + "target_lengths": None, + "speaker": None, + } + + @classmethod + def detokenize(cls, task, tokens): + text = task.tgt_dict.string(tokens) + tkn_cfg = task.data_cfg.bpe_tokenizer + tokenizer = encoders.build_bpe(Namespace(**tkn_cfg)) + return text if tokenizer is None else tokenizer.decode(text) + + @classmethod + def get_prefix_token(cls, task, lang): + prefix_size = int(task.data_cfg.prepend_tgt_lang_tag) + prefix_tokens = None + if prefix_size > 0: + assert lang is not None + lang_tag = SpeechToTextDataset.get_lang_tag_idx(lang, task.tgt_dict) + prefix_tokens = torch.Tensor([lang_tag]).long().unsqueeze(0) + return prefix_tokens + + @classmethod + def get_prediction( + cls, task, model, generator, sample, tgt_lang=None, synthesize_speech=False + ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]: + _tgt_lang = tgt_lang or task.data_cfg.hub.get("tgt_lang", None) + prefix = cls.get_prefix_token(task, _tgt_lang) + pred_tokens = generator.generate([model], sample, prefix_tokens=prefix) + pred = cls.detokenize(task, pred_tokens[0][0]["tokens"]) + eos_token = task.data_cfg.config.get("eos_token", None) + if eos_token: + pred = " ".join(pred.split(" ")[:-1]) + + if synthesize_speech: + pfx = f"{_tgt_lang}_" if task.data_cfg.prepend_tgt_lang_tag else "" + tts_model_id = task.data_cfg.hub.get(f"{pfx}tts_model_id", None) + speaker = task.data_cfg.hub.get(f"{pfx}speaker", None) + if tts_model_id is None: + logger.warning("TTS model configuration not found") + else: + _repo, _id = tts_model_id.split(":") + tts_model = torch.hub.load(_repo, _id, verbose=False) + pred = (pred, tts_model.predict(pred, speaker=speaker)) + return pred + + def predict( + self, + audio: Union[str, torch.Tensor], + tgt_lang: Optional[str] = None, + synthesize_speech: bool = False, + ) -> Union[str, Tuple[str, Tuple[torch.Tensor, int]]]: + # `audio` is either a file path or a 1xT Tensor + # return either text or (text, synthetic speech) + sample = self.get_model_input(self.task, audio) + return self.get_prediction( + self.task, + self.model, + self.generator, + sample, + tgt_lang=tgt_lang, + synthesize_speech=synthesize_speech, + ) diff --git a/fairseq/models/speech_to_text/modules/__init__.py b/fairseq/models/speech_to_text/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/models/speech_to_text/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/models/speech_to_text/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..472b2c3dabd15a46102f4fc5007ef0d2e661360d Binary files /dev/null and b/fairseq/models/speech_to_text/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/modules/__pycache__/augmented_memory_attention.cpython-310.pyc b/fairseq/models/speech_to_text/modules/__pycache__/augmented_memory_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91c1ebbe78f9ccadfd5523749b210d5a2da3e4e8 Binary files /dev/null and b/fairseq/models/speech_to_text/modules/__pycache__/augmented_memory_attention.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/modules/__pycache__/convolution.cpython-310.pyc b/fairseq/models/speech_to_text/modules/__pycache__/convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3db54bf7fc5965e27276fde6c3bb275400cc35e Binary files /dev/null and b/fairseq/models/speech_to_text/modules/__pycache__/convolution.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/modules/__pycache__/emformer.cpython-310.pyc b/fairseq/models/speech_to_text/modules/__pycache__/emformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..739f2e619fff441e5603120cda92662b5511f51f Binary files /dev/null and b/fairseq/models/speech_to_text/modules/__pycache__/emformer.cpython-310.pyc differ diff --git a/fairseq/models/speech_to_text/modules/augmented_memory_attention.py b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2d330f96f689864a51a0bd6d952d0ec3070186ac --- /dev/null +++ b/fairseq/models/speech_to_text/modules/augmented_memory_attention.py @@ -0,0 +1,487 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from fairseq.models import FairseqEncoder +from fairseq.models.speech_to_text import ConvTransformerEncoder +from fairseq.models.speech_to_text.utils import ( + attention_suppression, + lengths_to_encoder_padding_mask, + segments_to_sequence, + sequence_to_segments, +) +from fairseq.modules import MultiheadAttention, TransformerEncoderLayer + +# ------------------------------------------------------------------------------ +# AugmentedMemoryConvTransformerEncoder +# ------------------------------------------------------------------------------ + + +class AugmentedMemoryConvTransformerEncoder(ConvTransformerEncoder): + def __init__(self, args): + super().__init__(args) + + args.encoder_stride = self.stride() + + self.left_context = args.left_context // args.encoder_stride + + self.right_context = args.right_context // args.encoder_stride + + self.left_context_after_stride = args.left_context // args.encoder_stride + self.right_context_after_stride = args.right_context // args.encoder_stride + + self.transformer_layers = nn.ModuleList([]) + self.transformer_layers.extend( + [ + AugmentedMemoryTransformerEncoderLayer(args) + for i in range(args.encoder_layers) + ] + ) + + def stride(self): + # Hard coded here. Should infer from convs in future + stride = 4 + return stride + + def forward(self, src_tokens, src_lengths, states=None): + """Encode input sequence. + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + bsz, max_seq_len, _ = src_tokens.size() + x = ( + src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim) + .transpose(1, 2) + .contiguous() + ) + x = self.conv(x) + bsz, _, output_seq_len, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len, bsz, -1) + x = self.out(x) + x = self.embed_scale * x + + subsampling_factor = 1.0 * max_seq_len / output_seq_len + input_lengths = torch.max( + (src_lengths.float() / subsampling_factor).ceil().long(), + x.size(0) * src_lengths.new_ones([src_lengths.size(0)]).long(), + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + input_lengths, batch_first=True + ) + + # TODO: fix positional embedding + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # State to store memory banks etc. + if states is None: + states = [ + {"memory_banks": None, "encoder_states": None} + for i in range(len(self.transformer_layers)) + ] + + for i, layer in enumerate(self.transformer_layers): + # x size: + # (self.left_size + self.segment_size + self.right_size) + # / self.stride, num_heads, dim + # TODO: Consider mask here + x = layer(x, states[i]) + states[i]["encoder_states"] = x[ + self.left_context_after_stride : -self.right_context_after_stride + ] + + lengths = ( + ( + ~encoder_padding_mask[ + :, self.left_context_after_stride : -self.right_context_after_stride + ] + ) + .sum(dim=1, keepdim=True) + .long() + ) + + return states[-1]["encoder_states"], lengths, states + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryTransformerEncoderLayer +# ------------------------------------------------------------------------------ +class AugmentedMemoryTransformerEncoderLayer(TransformerEncoderLayer): + def __init__(self, args): + super().__init__(args) + + self.left_context = args.left_context // args.encoder_stride + self.right_context = args.right_context // args.encoder_stride + + def forward(self, x, state): + + length, batch_size, x_dim = x.size() + + residual = x + + if self.normalize_before: + x = self.self_attn_layer_norm(x) + + # init_state + if state.get("memory_banks", None) is None: + state["memory_banks"] = [] + + # TODO reseach new sum_query method + seg_start = self.left_context + seg_end = length - self.right_context + if seg_start < seg_end: + summarization_query = torch.mean(x[seg_start:seg_end], keepdim=True, dim=0) + else: + summarization_query = x.new_zeros(1, batch_size, x_dim) + + x = torch.cat([x, summarization_query], dim=0) + + x = self.self_attn(input_and_summary=x, state=state) + + x = self.dropout_module(x) + x = residual + x + + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + if not self.normalize_before: + x = self.final_layer_norm(x) + + return x + + def build_self_attention(self, embed_dim, args): + return AugmentedMemoryMultiheadAttention( + embed_dim=embed_dim, + num_heads=args.encoder_attention_heads, + dropout=args.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + tanh_on_mem=True, + max_memory_size=args.max_memory_size, + ) + + +# ------------------------------------------------------------------------------ +# AugmentedMemoryMultiheadAttention +# ------------------------------------------------------------------------------ +class AugmentedMemoryMultiheadAttention(MultiheadAttention): + """ + Augmented Memory Attention from + Streaming Transformer-based Acoustic Models + Using Self-attention with Augmented Memory + https://arxiv.org/abs/2005.08042 + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + tanh_on_mem=False, + memory_dim=None, + std_scale=0.5, # 0.5 based on https://arxiv.org/abs/2005.09137 + max_memory_size=-1, + disable_mem_on_mem_attn=True, + ): + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + q_noise, + qn_block_size, + ) + + self.memory_dim = memory_dim if memory_dim is not None else embed_dim + self.std_scale = std_scale + self.disable_mem_on_mem_attn = disable_mem_on_mem_attn + + # This Operator was used for factorization in PySpeech + self.v2e = lambda x: x + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = lambda x: x + self.nonlinear_squash_mem = False + + self.max_memory_size = max_memory_size + + def forward(self, input_and_summary, state): + """ + input: Encoder states of current segment with left or right context, + plus one summarization query + + """ + + length, batch_size, _ = input_and_summary.shape + length = length - 1 # not include sum_query, last index + + memory = state["memory_banks"] + # TODO: positional embedding on memory + + if self.max_memory_size > -1 and len(memory) > self.max_memory_size: + # TODO: need to fix here + if self.max_memory_size == 0: + memory = memory.new_zeros(1, memory.size(1), self.memory_dim) + else: + memory = memory[-self.max_memory_size :] + + memory_and_input = torch.cat(memory + [input_and_summary[:-1]], dim=0) + input_and_sum_query = input_and_summary + + q = self.q_proj(self.v2e(input_and_sum_query)) + k = self.k_proj(self.v2e(memory_and_input)) + v = self.v_proj(self.v2e(memory_and_input)) + + q = ( + q.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + * self.scaling + ) + k = ( + k.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + v.contiguous() + .view(-1, batch_size * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + + if self.disable_mem_on_mem_attn: + attention_weights = self.suppress_mem_on_mem_attention( + batch_size, self.num_heads, len(memory), attention_weights + ) + + if self.std_scale is not None: + attention_weights = attention_suppression(attention_weights, self.std_scale) + + assert list(attention_weights.shape) == [ + batch_size * self.num_heads, + length + 1, + length + len(memory), + ] + + attention_weights = torch.nn.functional.softmax( + attention_weights.float(), dim=-1 + ).type_as(attention_weights) + + attention_probs = self.dropout_module(attention_weights) + + # [T, T, B, n_head] + [T, B, n_head, d_head] -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + + assert list(attention.shape) == [ + batch_size * self.num_heads, + length + 1, + self.head_dim, + ] + + attention = ( + attention.transpose(0, 1) + .contiguous() + .view(length + 1, batch_size, self.embed_dim) + ) + + output_and_memory = self.out_proj(attention) + + next_m = output_and_memory[-1:] + next_m = self.squash_mem(next_m) + output = output_and_memory[:-1] + + state["memory_banks"].append(next_m) + + return output + + def suppress_mem_on_mem_attention( + self, B: int, num_heads: int, mem_size: int, attention_weight: Tensor + ): + """ + Arguments: + - B: batch size + - num_heads: number of attention heads + - mem_size: size of memory bank + - attention_weight: a [B*num_heads, T + 1, T + mem_size] vector + + Return: + modified attention_weight with [B*num_heads, -1, :mem_size] = -inf + """ + attention_weight[:, -1, :mem_size] = float("-inf") + return attention_weight + + +# ------------------------------------------------------------------------------ +# SequenceEncoder +# ------------------------------------------------------------------------------ +class SequenceEncoder(FairseqEncoder): + """ + SequenceEncoder encodes sequences. + + More specifically, `src_tokens` and `src_lengths` in `forward()` should + describe a batch of "complete" sequences rather than segments. + + Segment-by-segment inference can be triggered by `segment_size`: + 1) `segment_size` is None: + SequenceEncoder treats the input sequence as one single segment. + 2) `segment_size` is not None (some int instead): + SequenceEncoder does the following: + 1. breaks the input sequence into several segments + 2. inference on each segment and collect the outputs + 3. concatanete segment outputs into the output sequence. + Note that `segment_size` here shouldn't include additional left/right + contexts needed, for example if we wish to infer with LC-BLSTM where the + middle chunk size is 100 and right context is 20, `segment_size` should be + 100. + """ + + def __init__(self, args, module): + super().__init__(None) + + self.module = module + self.input_time_axis = 1 + self.output_time_axis = 0 + self.segment_size = args.segment_size + self.left_context = args.left_context + self.right_context = args.right_context + + def forward( + self, + src_tokens: Tensor, + src_lengths: Tensor, + states=None, + ): + + seg_src_tokens_lengths = sequence_to_segments( + sequence=src_tokens, + time_axis=self.input_time_axis, + lengths=src_lengths, + segment_size=self.segment_size, + extra_left_context=self.left_context, + extra_right_context=self.right_context, + ) + + seg_encoder_states_lengths: List[Tuple[Tensor, Tensor]] = [] + + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + + seg_encoder_states_lengths.append((seg_encoder_states, seg_enc_lengths)) + + encoder_out, enc_lengths = segments_to_sequence( + segments=seg_encoder_states_lengths, time_axis=self.output_time_axis + ) + + encoder_padding_mask, _ = lengths_to_encoder_padding_mask( + enc_lengths, batch_first=True + ) + + if not encoder_padding_mask.any(): + encoder_padding_mask = None + + return { + "encoder_out": [encoder_out], + "encoder_padding_mask": [encoder_padding_mask], + "encoder_embedding": [], + "encoder_states": [states], + "src_tokens": [], + "src_lengths": [], + } + + def incremental_encode( + self, + seg_src_tokens: Tensor, + seg_src_lengths: Tensor, + states=None, + ): + """ + Different from forward function, this function takes segmented speech + as input, and append encoder states to previous states + """ + (seg_encoder_states, seg_enc_lengths, states) = self.module( + seg_src_tokens, + seg_src_lengths, + states=states, + ) + return seg_encoder_states, seg_enc_lengths, states + + +# ------------------------------------------------------------------------------ +# Augmented memory model decorator +# ------------------------------------------------------------------------------ +def augmented_memory(klass): + class StreamSeq2SeqModel(klass): + @staticmethod + def add_args(parser): + super(StreamSeq2SeqModel, StreamSeq2SeqModel).add_args(parser) + parser.add_argument( + "--segment-size", type=int, required=True, help="Length of the segment." + ) + parser.add_argument( + "--left-context", + type=int, + default=0, + help="Left context for the segment.", + ) + parser.add_argument( + "--right-context", + type=int, + default=0, + help="Right context for the segment.", + ) + parser.add_argument( + "--max-memory-size", + type=int, + default=-1, + help="Right context for the segment.", + ) + + StreamSeq2SeqModel.__name__ = klass.__name__ + return StreamSeq2SeqModel diff --git a/fairseq/models/speech_to_text/modules/convolution.py b/fairseq/models/speech_to_text/modules/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..526d7540c59bd46a5f49f462e4c52c8b16d76aca --- /dev/null +++ b/fairseq/models/speech_to_text/modules/convolution.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch +import torch.nn as nn + + +class Conv1dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 1D convolution (along temporal + dimension) followed by non-linear activation via gated linear units + (https://arxiv.org/abs/1911.08460) + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + in_channels: int, + mid_channels: int, + out_channels: int, + kernel_sizes: List[int] = (3, 3), + ): + super(Conv1dSubsampler, self).__init__() + self.n_layers = len(kernel_sizes) + self.conv_layers = nn.ModuleList( + nn.Conv1d( + in_channels if i == 0 else mid_channels // 2, + mid_channels if i < self.n_layers - 1 else out_channels * 2, + k, + stride=2, + padding=k // 2, + ) + for i, k in enumerate(kernel_sizes) + ) + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in range(self.n_layers): + out = ((out.float() - 1) / 2 + 1).floor().long() + return out + + def forward(self, src_tokens, src_lengths): + bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D) + x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + x = conv(x) + x = nn.functional.glu(x, dim=1) + _, _, out_seq_len = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D) + return x, self.get_out_seq_lens_tensor(src_lengths) + + +def infer_conv_output_dim(in_channels, input_dim, out_channels): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + x = torch.nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + +class Conv2dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 2D convolution based on ESPnet implementation + (https://github.com/espnet/espnet) + + Args: + input_channels (int): the number of input channels + input_feat_per_channel (int): encoder input dimension per input channel + conv_out_channels (int): the number of output channels of conv layer + encoder_embed_dim (int): encoder dimentions + """ + + def __init__( + self, + input_channels: int, + input_feat_per_channel: int, + conv_out_channels: int, + encoder_embed_dim: int, + ): + super().__init__() + assert input_channels == 1, input_channels + self.conv = torch.nn.Sequential( + torch.nn.Conv2d( + input_channels, conv_out_channels, 3, stride=2, padding=3 // 2 + ), + torch.nn.ReLU(), + torch.nn.Conv2d( + conv_out_channels, + conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = infer_conv_output_dim( + input_channels, input_feat_per_channel, conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, encoder_embed_dim) + + def forward(self, src_tokens, src_lengths): + B, T_i, C = src_tokens.size() + x = src_tokens.view(B, T_i, 1, C).transpose(1, 2).contiguous() + x = self.conv(x) + B, _, T_o, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(T_o, B, -1) + x = self.out(x) + + subsampling_factor = int(T_i * 1.0 / T_o + 0.5) + input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long() + input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to( + input_len_0.device + ) + input_lengths = torch.min(input_len_0, input_len_1) + return x, input_lengths diff --git a/fairseq/models/speech_to_text/modules/emformer.py b/fairseq/models/speech_to_text/modules/emformer.py new file mode 100644 index 0000000000000000000000000000000000000000..935d5930787f3fd3d3dbe227e233c140a10f65b2 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/emformer.py @@ -0,0 +1,1844 @@ +#!/usr/bin/env python3 +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import math +import re +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor +from torch import device as Device + +from fairseq.models import FairseqEncoder +from fairseq.models.speech_to_text.utils import ( + NoOp, + attention_suppression, + layer_norm_backward_hook, + lengths_to_padding_mask, + segments_to_sequence, +) + +try: + import torch.ao.quantization as quantization + from torch.ao.quantization.qconfig import ( + default_dynamic_qconfig, + per_channel_dynamic_qconfig, + ) +except ImportError: + import torch.quantization as quantization + from torch.quantization.qconfig import ( + default_dynamic_qconfig, + per_channel_dynamic_qconfig, + ) + + +class RelativePositionEmbedding(nn.Module): + """ + Implementation according to https://arxiv.org/abs/1803.02155 + """ + + def __init__(self, head_dim, max_position, norm_init=True): + super().__init__() + self.head_dim = head_dim + self.max_position = max_position + self.embeddings = nn.Parameter(torch.Tensor(max_position * 2 + 1, head_dim)) + if norm_init: + nn.init.xavier_normal_(self.embeddings) + else: + nn.init.xavier_uniform_(self.embeddings) + + def forward(self, input: Tensor): + output = nn.functional.embedding(input.long(), self.embeddings) + return output + + +class Fp32LayerNorm(nn.Module): + def __init__( + self, + input_dim, + clamp_grad=True, + max_grad_value=256, + eps=1e-5, + elementwise_affine=True, + ): + super().__init__() + self.torch_module = torch.nn.LayerNorm( + input_dim, eps=eps, elementwise_affine=elementwise_affine + ) + if clamp_grad: + hook = partial(layer_norm_backward_hook, clamp_value=max_grad_value) + self.torch_module.register_backward_hook(hook) + + def forward(self, input): + output = torch.nn.functional.layer_norm( + input.float(), + self.torch_module.normalized_shape, + self.torch_module.weight.float() + if self.torch_module.weight is not None + else None, + self.torch_module.bias.float() + if self.torch_module.bias is not None + else None, + self.torch_module.eps, + ).type_as(input) + return output + + +# ------------------------------------------------------------------------------ +# PositionwiseFF +# ------------------------------------------------------------------------------ + + +class PositionwiseFF(nn.Module): + """ + FFN layer in transformer. + + Args: + input_dim: input embedding dimension + ffn_dim: FFN layer inner dimension + dropout_on_fc1: dropout for first linear layer + dropout_on_fc2: dropout fr second linear layer + activation_fn: activation function used after first linear layer. \ + Only relu or gelu is supported. + + """ + + def __init__( + self, input_dim, ffn_dim, dropout_on_fc1, dropout_on_fc2, activation_fn + ): + super(PositionwiseFF, self).__init__() + + self.input_dim = input_dim + self.ffn_dim = ffn_dim + if activation_fn == "relu": + ac = nn.ReLU() + elif activation_fn == "gelu": + ac = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(activation_fn)) + + # fc1 -> ac -> dropout -> fc2 -> dropout + self.module = nn.Sequential( + nn.Linear(input_dim, ffn_dim), + ac, + nn.Dropout(dropout_on_fc1), + nn.Linear(ffn_dim, input_dim), + nn.Dropout(dropout_on_fc2), + ) + + self.layer_norm = Fp32LayerNorm(input_dim) + + def forward(self, input): + module_out = self.module(self.layer_norm(input)) + output = module_out + input + + return output + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# SummarizationLayer +# ------------------------------------------------------------------------------ + + +class SummarizationLayer(nn.Module): + def __init__(self, method, segment_size, embedding_dim): + super(SummarizationLayer, self).__init__() + self.segment_size = segment_size + self.embedding_dim = embedding_dim + nonlin_match = re.match(r"nonlinear\((?P[a-z]+),(?P[0-9]+)\)", method) + self.method = method + if method == "mean": + self.module = nn.AvgPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "max": + self.module = nn.MaxPool1d( + kernel_size=segment_size, + stride=segment_size, + ceil_mode=True, + ) + elif method == "linear": + self.module = nn.Linear(segment_size, 1) + elif nonlin_match: + nonlin_args = nonlin_match.groupdict() + act_type = nonlin_args["act"] + hid_dim = int(nonlin_args["dim"]) + if act_type == "relu": + act = nn.ReLU() + elif act_type == "gelu": + act = nn.GELU() + else: + raise ValueError("Unsupported activation_fn = ({})".format(act_type)) + self.module = nn.Sequential( + nn.Linear(segment_size, hid_dim), + act, + nn.Linear(hid_dim, 1), + ) + else: + raise ValueError("Unsupported summarization method = ({})".format(method)) + + def forward(self, input): + # T, B, D -> B, D, T + input = input.permute(1, 2, 0) + + if self.method == "mean" or self.method == "max": + output = self.module(input) + output = output.permute(2, 0, 1) + return output + + full_seg_length = input.size(2) // self.segment_size * self.segment_size + if full_seg_length > 0: + # at least one seg is full + B = input.size(0) + D = input.size(1) + input_todo = ( + input[:, :, :full_seg_length] + .contiguous() + .view(B, -1, self.segment_size) + ) + output = self.module(input_todo) + output = output.view(B, D, -1) + else: + output = input.new_zeros(input.size(0), input.size(1), 0) + left = input.size(2) - full_seg_length + if left > 0: + # when last seg is not full, use zeros as last memory placeholder + zeros = input.new_zeros(input.size(0), input.size(1), 1) + output = torch.cat([output, zeros], dim=2) + output = output.permute(2, 0, 1) + return output + + +# ------------------------------------------------------------------------------ +# NoSegAugmentedMemoryMultiheadAttentionBmm +# ------------------------------------------------------------------------------ + + +class NoSegAugmentedMemoryMultiheadAttentionBmm(nn.Module): + """ + Whole utterance augmented memory multihead attention using BMM. + + Different with previous augmented memory multihead attention where + the utterance is chunked into segments. Here we use attention mask + achieve so. The input embedding [right_context, utterance, summary] + is a concatenation of right context, utterance and summary. + + Right context block is the concatenation of all the right context for + each segments. [right_context_0, right_context_1, ..., right_context_n] + For example, if we have utterance = [v0, v1, v2, ...., v20]. segment + size 8, right_context size 4. Then the right context blocks = + [v8, v9, v10, v11, v16, v17, v18, v19, 0, 0, 0, 0], where v8, v9, v10, + and v11 are the right context for first segment. v16, v17, v18 and v19 + are the right context for second segment. 0, 0, 0 and 0 are right context + for the last segment. + + utterance is corresponding to input embedding sequence + + summary is concatenation of average of each segments. [summary_0, + summary_1, ..., ]. + + In augmented memory multihead attention, the query is [right_context, + utterance, summary], key is [memory, right_context, utterance]. Different + with AugmentedMemoryMultiheadAttentionBmm, memory here is passed from + previous attention layer. For the first attention layer, memory is average + of each segment. + + Memory is a concatenation of memory from each segments in previous attention + layer. For example, current layer is i, then memory is [m_0, m_1, ..., m_n]. + Each m_k is the output from seg_k in layer i-1. + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + dropout: attention dropout + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + scaled_init: whether to use scaled init for linear weight + tanh_on_mem: whether to use tanh on memory output + use_mem: whether to use memory or not. When max_memory_size is 0, then + we don't have memory anymore. + layer_index: current self-attention layer index that is used in depth + initialization + max_relative_position: max relative position used in relative position + embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + + """ + + def __init__( + self, + input_dim, + num_heads, + dropout=0.0, + std_scale=None, + scaled_init=False, + tanh_on_mem=False, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + max_relative_position=0, + rpe_old_option=True, + ): + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + super().__init__() + + embed_dim = input_dim + self.e2h_kv = torch.nn.Linear(input_dim, 2 * input_dim, bias=True) + self.e2h_q = torch.nn.Linear(input_dim, input_dim, bias=True) + self.rpe_old_option = rpe_old_option + if max_relative_position > 0: + self.use_rpe = True + self.rpe_k = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + self.rpe_v = RelativePositionEmbedding( + head_dim=input_dim // num_heads, + max_position=max_relative_position, + ) + else: + self.use_rpe = False + self.rpe_k = None + self.rpe_v = None + if scaled_init: + if layer_index == -1: + gain = 1.0 / math.sqrt(2) + else: + # https://arxiv.org/abs/2005.09684 depthwise initialization + # stablize the training greatly. Use depthwise initialization to + # replace incremental loss. + gain = 1.0 / math.sqrt(layer_index + 1) + torch.nn.init.xavier_uniform_(self.e2h_kv.weight, gain=gain) + torch.nn.init.xavier_uniform_(self.e2h_q.weight, gain=gain) + + self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=True) + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + self.scaling = self.head_dim**-0.5 + + self.std_scale = std_scale + self.use_mem = use_mem + self.mini_batches = mini_batches + self.negative_inf = negative_inf + + if tanh_on_mem: + self.squash_mem = torch.tanh + self.nonlinear_squash_mem = True + else: + self.squash_mem = NoOp() + self.nonlinear_squash_mem = False + + def prepare_qkv( + self, + input: Tensor, + mems: Tensor, + lengths: Tensor, + summary_length: int, + lc_length: int, + ): + # T: right_context length + utterance_length + summary_length + T, B, D = input.shape + mem_length = mems.size(0) + utterance_length = torch.max(lengths) + + right_context_blocks_length = T - utterance_length - summary_length + rc_block = input[:right_context_blocks_length, :, :] + utterance_block = input[right_context_blocks_length : T - summary_length, :, :] + + if B == 1: + padding_mask = None + else: + klengths = lengths + mem_length + right_context_blocks_length + lc_length + padding_mask = lengths_to_padding_mask(lengths=klengths) + + mem_rc_input = torch.cat([mems, rc_block, utterance_block], dim=0) + + # In training lc_length = 0 + key_length = mem_rc_input.size(0) + lc_length + rc_input_sum = input + q = self.e2h_q(rc_input_sum) + kv = self.e2h_kv(mem_rc_input) + k, v = kv.chunk(chunks=2, dim=2) + result_qkv = (q, k, v) + input_shape = (T, B, D) + result_lengths_info = ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) + if padding_mask is not None: + assert padding_mask.size(0) == B + assert padding_mask.size(1) == key_length + + return result_qkv, input_shape, result_lengths_info, padding_mask + + def prepare_attention_weights( + self, + q: Tensor, + new_k: Tensor, + new_v: Tensor, + input_shape: Tuple[int, int, int], + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor]: + T, B, D = input_shape + q = ( + q.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1) + * self.scaling + ) + + k = ( + new_k.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + v = ( + new_v.contiguous() + .view(-1, B * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + attention_weights = torch.bmm(q, k.transpose(1, 2)) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_k = self.rpe_k(rpe) + # [q, B*h, d] * [q, k, d] -> [B*h, q, k] + attention_weights_rpe = torch.matmul( + q.transpose(0, 1), r_k.transpose(1, 2) + ).transpose(0, 1) + attention_weights = attention_weights + attention_weights_rpe + attention_weights_float = attention_weights.float() + + return attention_weights, attention_weights_float, v + + def prepare_attention_output( + self, + attention_weights: Tensor, + attention_weights_float: Tensor, + v: Tensor, + input_shape: Tuple[int, int, int], + key_length: int, + padding_mask: Optional[Tensor], + rpe: Optional[Tensor], + ) -> Tensor: + T, B, D = input_shape + if padding_mask is not None: + attention_weights_float = attention_weights_float.view( + B, self.num_heads, T, key_length + ) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attention_weights_float = attention_weights_float.view( + B * self.num_heads, T, key_length + ) + + if self.std_scale is not None: + attention_weights_float = attention_suppression( + attention_weights_float, self.std_scale + ) + + attention_weights_float = torch.nn.functional.softmax( + attention_weights_float, dim=-1 + ) + attention_weights = attention_weights_float.type_as(attention_weights) + + attention_probs = torch.nn.functional.dropout( + attention_weights, p=self.dropout, training=self.training + ) + + # [T, key_length, B, n_head]+ [key_length, B, n_head, d_head] + # -> [T, B, n_head, d_head] + attention = torch.bmm(attention_probs, v) + if self.use_rpe and rpe is not None and self.rpe_v is not None: + r_v = self.rpe_v(rpe) + attention_rpe = torch.matmul( + attention_probs.transpose(0, 1), r_v + ).transpose(0, 1) + + if self.rpe_old_option: + attention += attention + attention_rpe + else: + attention = attention + attention_rpe + + assert list(attention.shape) == [B * self.num_heads, T, self.head_dim] + + attention = attention.transpose(0, 1).contiguous().view(T, B, self.embed_dim) + + rc_output_memory = self.out_proj(attention) + return rc_output_memory + + @torch.jit.unused + def forward( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + attention_mask: Tensor, + pre_mems: Optional[Tensor] = None, + left_context_key: Optional[Tensor] = None, + left_context_val: Optional[Tensor] = None, + rpe: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in training. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + attention_mask: attention mask for query = [right_context, query, summary] + key = [mem, right_context, query]. This is only used for traing. + + """ + if self.use_mem: + mem_length = mems.size(0) + summary_length = mem_length + 1 + if pre_mems is not None: + mems = torch.cat([pre_mems, mems], dim=0) + else: + mem_length = 0 + summary_length = 0 + + # In training, lc_length = 0 + if left_context_key is not None: + lc_length = left_context_key.size(0) + else: + lc_length = 0 + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + if left_context_key is not None: + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + else: + new_k = k + new_v = v + next_k = None + next_v = None + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + + # mask attention + attention_mask = attention_mask.unsqueeze(0) + attention_weights_float = attention_weights_float.masked_fill( + attention_mask, float(self.negative_inf) + ) + + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + if self.use_mem: + # next_m length equals to summary length - 1 + # last memory is ignored + if self.mini_batches: + next_m = rc_output_memory[-summary_length:] + else: + next_m = rc_output_memory[-summary_length:-1] + + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-summary_length] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + next_m = mems + rc_output = rc_output_memory + + return rc_output, next_m, next_k, next_v + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ + forward function for NoSegAugmentedMemoryMultiheadAttentionBmm in decoding. + + args: + input: formed in the following way + [right_context_0, right_contex_1, ..., seg_0, seg_1, + ..., summary_0, summary_1,..] + lengths: the length of query which is [seg_0, seg_1, ....] + mems: [mem_0, mem_1, ...]. + left_context_key: left_context for key part. This is only used for online + decoding. In training, this is empty tensor + left_context_val: left_context for value part. This is only used for online + decoding. In training, this is empty tensor + + """ + lc_length = left_context_key.size(0) + + # In decoding, summary_length = 1 or 0 + if self.use_mem: + summary_length = 1 + else: + summary_length = 0 + + results = self.prepare_qkv( + input=input, + mems=mems, + lengths=lengths, + summary_length=summary_length, + lc_length=lc_length, + ) + result_qkv, input_shape, result_lengths_info, padding_mask = results + q, k, v = result_qkv + ( + mem_length, + utterance_length, + right_context_blocks_length, + key_length, + ) = result_lengths_info + + # add the cache key and value + new_k = torch.cat( + [ + k[: mem_length + right_context_blocks_length, :, :], + left_context_key, + k[-utterance_length:, :, :], + ], + dim=0, + ) + new_v = torch.cat( + [ + v[: mem_length + right_context_blocks_length, :, :], + left_context_val, + v[-utterance_length:, :, :], + ], + dim=0, + ) + next_k = new_k[mem_length + right_context_blocks_length :, :, :] + next_v = new_v[mem_length + right_context_blocks_length :, :, :] + + attention_weights, attention_weights_float, v = self.prepare_attention_weights( + q=q, + new_k=new_k, + new_v=new_v, + input_shape=input_shape, + rpe=rpe, + ) + # In online decoding, we don't have attention mask. But we still need + # to disable the attention from summary query to memory + attention_weights_float[:, -1, :mem_length] = float(self.negative_inf) + rc_output_memory = self.prepare_attention_output( + attention_weights=attention_weights, + attention_weights_float=attention_weights_float, + v=v, + input_shape=input_shape, + key_length=key_length, + padding_mask=padding_mask, + rpe=rpe, + ) + + # In decoding, summary length is 1 + if self.use_mem: + next_m = rc_output_memory[-1:] + next_m = self.squash_mem(next_m) + # rc and output + rc_output = rc_output_memory[:-1] + if not self.nonlinear_squash_mem: + next_m = torch.clamp(next_m, min=-10, max=10) + else: + rc_output = rc_output_memory + # empty tensor as input mems + next_m = mems + + return rc_output, next_m, next_k, next_v + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +class NoSegAugmentedMemoryTransformer(nn.Module): + """ + Whole utterance augmented memory transformer. + + This is not pyspeech nn layer. It is used as a module in a master layer where + multiple transformers is used. + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + dropout_in_attn=0.0, + dropout_on_attn=None, + dropout_on_fc1=None, + dropout_on_fc2=None, + activation_fn="relu", + tanh_on_mem=False, + std_scale=None, + scaled_init=False, + segment_size=128, + use_mem=True, + mini_batches=False, + negative_inf="-inf", + layer_index=-1, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super(NoSegAugmentedMemoryTransformer, self).__init__() + + self.attention = NoSegAugmentedMemoryMultiheadAttentionBmm( + input_dim=input_dim, + num_heads=num_heads, + dropout=dropout_in_attn, + scaled_init=scaled_init, + tanh_on_mem=tanh_on_mem, + std_scale=std_scale, + use_mem=use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + max_relative_position=max_relative_position, + ) + self.dropout = nn.Dropout(dropout_on_attn) + self.pos_ff = PositionwiseFF( + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + activation_fn=activation_fn, + ) + self.layer_norm_pre = Fp32LayerNorm(input_dim) + self.layer_norm = Fp32LayerNorm(input_dim) + self.segment_size = segment_size + self.use_mem = use_mem + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + def set_mini_batches(self, mini_batches): + self.attention.mini_batches = mini_batches + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def pre_attention_ops(self, input, right_context_blocks): + rc_length = right_context_blocks.size(0) + input_length = input.size(0) + + rc_and_input = torch.cat([right_context_blocks, input], dim=0) + residual_input = rc_and_input + rc_and_input = self.layer_norm_pre(rc_and_input) + + query_input = rc_and_input[-input_length:, :, :] + return rc_length, input_length, residual_input, query_input, rc_and_input + + def after_attention_ops(self, attention_output, residual_input): + output = self.dropout(attention_output) + output = output + residual_input + output = self.pos_ff(output) + output = self.layer_norm(output) + return output + + @torch.jit.export + def forward_jit( + self, + input: Tensor, + lengths: Tensor, + mems: Tensor, + left_context_key: Tensor, + left_context_val: Tensor, + right_context_blocks: Tensor, + rpe: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + + # In online decoding, the summary query size is always 1 or 0 + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + summary_query = summary_query[0:1, :, :] + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention.forward_jit( + input=rc_qu_su, + lengths=lengths, + mems=mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + return results + + @torch.jit.unused + def forward( + self, + input, + lengths, + mems, + right_context_blocks, + attention_mask, + pre_mems, + left_context_key, + left_context_val, + rpe, + ): + + results = self.pre_attention_ops(input, right_context_blocks) + rc_length, input_length, residual_input, query_input, rc_and_input = results + if self.use_mem: + summary_query = self.gen_summary_queries(query_input) + rc_qu_su = torch.cat([rc_and_input, summary_query], dim=0) + else: + rc_qu_su = rc_and_input + + rc_output, next_m, next_k, next_v = self.attention( + input=rc_qu_su, + lengths=lengths, + mems=mems, + attention_mask=attention_mask, + pre_mems=pre_mems, + left_context_key=left_context_key, + left_context_val=left_context_val, + rpe=rpe, + ) + + # [TODO] Note memory did not go through pos_ff. What happen if we pass + # memory through the pos_ff as well? + rc_output = self.after_attention_ops(rc_output, residual_input) + results = ( + rc_output[-input_length:, :, :], + next_m, + rc_output[0:rc_length, :, :], + next_k, + next_v, + ) + + return results + + +class NoSegAugmentedMemoryTransformerEncoderLayer(FairseqEncoder): + """ + Whole utterance augmented memory transformer encoder layer. This is a master layer + where we can define multiple augmented memory transformers. There are two reasons + to setup the master layer. + 1. We only need to define once about the attention mask. All the layers in the master + layer share the same mask. + 2. pyspeech nn layer has special input and output format. Defining one master layer is + easier to passing memory between different layes inside the master layer + + args: + input_dim: input embedding dimension + num_heads: number of heads in multihead self-attention + ffn_dim: ffn dimension in FFN layer + num_layers: number of augmented memory transformer layers + dropout_in_attn: dropout used in multi-head self-attention + dropout_on_attn: dropout used for output from te multihead self-attention + dropout_on_fc1: dropout used in FFN layer for the first linear layer + dropout_on_fc2: dropout used in FFN layer for the second linear layer + segment_size: segment size for each segment + context_config: (left_context_size, right_context_size) defines the surround context size + for each segment + max_memory_size: maximum memory size used for each segment + scaled_init: whether use scaled init for weight initialization in attention layer + std_scale: if std_scale is not None. The weak attention suppression is + turned on. For std_scale = 0.5, all the attention smaller than + mean + 0.5 * std will be suppressed. + activation_fn: activation function used in FFN layer. [ReLU, GELU] supported + tanh_on_mem: whether use tanh on memory + mini_batches: use mini-btach training + negative_inf: the negative infinity value used in attention masking. default is "-inf". + For some situation, e.g. LM. it is better to use "-1e8" to avoid nan issue. + summarization_method: method to generate segment summrization embedding + max_relative_position: max relatie position for relative position embedding + rpe_old_option: To be compatible with previous model. The previous model + was trained with attention += attention + rpe. The correct equation + should be attention = attention + rpe + [TODO]: remove the rpe_old_option by the end of 2021 Q1. + + """ + + def __init__( + self, + input_dim, + num_heads, + ffn_dim, + num_layers=1, + dropout_in_attn=0.0, + dropout_on_attn=0.0, + dropout_on_fc1=0.0, + dropout_on_fc2=0.0, + segment_size=128, + context_config=(0, 0), + max_memory_size=0, + scaled_init=True, + std_scale=None, + activation_fn="relu", + tanh_on_mem=False, + mini_batches=False, + negative_inf="-inf", + deep_init=True, + summarization_method="mean", + max_relative_position=0, + rpe_old_option=True, + ): + super().__init__(None) + if input_dim % num_heads: + raise ValueError( + "input_dim ({}) must be divisible by num_heads ({})".format( + input_dim, num_heads + ) + ) + + # we used to support growing memory size. However, it will cause + # cross stream batching failure. Now we need to have exact max memory size + if max_memory_size < 0: + raise ValueError("max_memory_size must be >= 0") + + # Only assign right_context. In decoding, left context will be cached. + # No need to let the online decoder to re-assign the left context + self.left_context, self.right_context = context_config + self.segment_size = segment_size + self.memory_dim = input_dim + self.max_memory_size = max_memory_size + self.mini_batches = mini_batches + if self.max_memory_size != 0: + self.use_mem = True + else: + self.use_mem = False + + self.memory_op = SummarizationLayer( + summarization_method, segment_size, input_dim + ) + + self.layers = torch.nn.ModuleList() + self.num_layers = num_layers + self.max_relative_position = max_relative_position + if self.max_relative_position > 0: + self.use_rpe = True + else: + self.use_rpe = False + for i in range(self.num_layers): + if deep_init: + layer_index = i + else: + layer_index = -1 + + self.layers.append( + NoSegAugmentedMemoryTransformer( + num_heads=num_heads, + input_dim=input_dim, + ffn_dim=ffn_dim, + dropout_in_attn=dropout_in_attn, + dropout_on_attn=dropout_on_attn, + dropout_on_fc1=dropout_on_fc1, + dropout_on_fc2=dropout_on_fc2, + segment_size=segment_size, + std_scale=std_scale, + activation_fn=activation_fn, + tanh_on_mem=tanh_on_mem, + scaled_init=scaled_init, + use_mem=self.use_mem, + mini_batches=mini_batches, + negative_inf=negative_inf, + layer_index=layer_index, + summarization_method=summarization_method, + max_relative_position=max_relative_position, + rpe_old_option=rpe_old_option, + ) + ) + + def set_mini_batches(self, mini_batches): + # handy function only used for unit test + self.mini_batches = mini_batches + for layer in self.layers: + layer.set_mini_batches(mini_batches) + + def _get_relative_position( + self, + input: Tensor, + max_relative_position: int, + left_context_length: int, + past_length: int, + is_decoding: bool, + ): + # For training, we copy the right context to the start of the utterance + # First dimension in distance is corresponding to query. + # [right context, utterance, summary vector] + # Second dimension in distance is corresponding to key. + # [Memory bank, right context, utterance] + # For summary vector in query part, the distance with + # all other position is 2*max_position. For memory bank in key, + # the distance with all other positions is 0. + + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + + # utterance + u_st = past_length * self.segment_size + u_ed = u_st + T + utterance_ranges = torch.arange(u_st, u_ed - self.right_context) + + # left context. Only in minibatch or decoding + left_context_ranges = torch.arange(u_st - left_context_length, u_st) + + # Right context block + # right context + utterance + right_context_blocks = [] + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + u_st + ed = st + self.right_context + assert ed < u_ed + temp = torch.arange(st, ed) + right_context_blocks.append(temp) + right_context_blocks.append(torch.arange(u_ed - self.right_context, u_ed)) + right_context_ranges = torch.cat(right_context_blocks) + + if self.use_mem: + # Memory bank + # The position for memory -n, .., -1 + if is_decoding: + memory_size = min(past_length, self.max_memory_size) + else: + memory_size = num_segs + past_length - 1 + memory_bank_ranges = torch.arange( + -max_relative_position - 1, -max_relative_position - 1 - memory_size, -1 + ) + + # summary vector + # The position for summary vector as the T+max_relative_position+1. + # After the clamping, the relative position is max_relative_position + summary_pos_st = u_ed + max_relative_position + 1 + summary_vector_ranges = torch.arange( + summary_pos_st, summary_pos_st + num_segs + ) + + key_ranges = torch.cat( + [ + memory_bank_ranges, + right_context_ranges, + left_context_ranges, + utterance_ranges, + ] + ) + + query_ranges = torch.cat( + [right_context_ranges, utterance_ranges, summary_vector_ranges] + ) + else: + key_ranges = torch.cat( + [right_context_ranges, left_context_ranges, utterance_ranges] + ) + + query_ranges = torch.cat([right_context_ranges, utterance_ranges]) + + distance = key_ranges[None, :] - query_ranges[:, None] + distance_clamp = ( + torch.clamp(distance, -max_relative_position, max_relative_position) + + max_relative_position + ) + distance_clamp = distance_clamp.to(input.device).long().detach() + return distance_clamp + + def _get_attention_mask(self, input, past_length=0, left_context_cache=0): + # attention mask for each query contains three parts: + # 1. memory part + # 2. left_context + segment + # 3. right_context_block + # so for each segment and its correspoinding right context block, + # the attention matrix is formed by 9 parts: + # [0, m, 0, 0, right_context, 0, 0, seg, 0] + # [before memory, memory, after memory, before right context, right_context, + # after right context, before seg, seg, after seg] + # + # Query is formed in the way as [right_context_blocks, utterance, summary] + # + # Note: put m and right_context before segment is convenient + # for padding_mask operation. + # Key lengths = m_length + right_context_block_length + lengths + utterance_length, batch_size, _ = input.shape + summary_length = math.ceil(utterance_length / self.segment_size) + num_segs = summary_length + rc_length = self.right_context * num_segs + rc = self.right_context + lc = self.left_context + + # using mini-batches, there is left context cache available for current + # sequence. + lcc = left_context_cache + + # max_memory_size is 0 then we don't have memory and summary + # past_length is the memory carry from previous sequence + if self.use_mem: + mem_length = num_segs - 1 + past_length + else: + mem_length = 0 + rc_mask = [] + query_mask = [] + summary_mask = [] + for j in range(0, num_segs): + ssize = min(self.segment_size, utterance_length - j * self.segment_size) + + rc_size = rc + rc_mat = [] + q_mat = [] + s_mat = [] + m_start = max(j + past_length - self.max_memory_size, 0) + + # max_memory_size is 0, then we don't use memory + if self.use_mem: + # part 0: before memory + rc_mat.append(input.new_zeros(rc_size, m_start)) + q_mat.append(input.new_zeros(ssize, m_start)) + s_mat.append(input.new_zeros(1, m_start)) + + # part 1: memory + col_1 = j + past_length - m_start + rc_mat.append(torch.ones(rc_size, col_1, device=input.device)) + q_mat.append(torch.ones(ssize, col_1, device=input.device)) + # based on D22875746, disable summary query attention + # on memeory is better for long form utterance + s_mat.append(input.new_zeros(1, col_1)) + + # part 2: after memory + col_2 = mem_length - (j + past_length) + rc_mat.append(input.new_zeros(rc_size, col_2)) + q_mat.append(input.new_zeros(ssize, col_2)) + s_mat.append(input.new_zeros(1, col_2)) + + # part 3: before right context + rc_start = j * rc + rc_mat.append(input.new_zeros(rc_size, rc_start)) + q_mat.append(input.new_zeros(ssize, rc_start)) + s_mat.append(input.new_zeros(1, rc_start)) + + # part 4: right context + rc_end = rc_start + rc + col_4 = rc + rc_mat.append(torch.ones(rc_size, col_4, device=input.device)) + q_mat.append(torch.ones(ssize, col_4, device=input.device)) + s_mat.append(torch.ones(1, col_4, device=input.device)) + + # part 5: after right context + col_5 = rc_length - rc_end + rc_mat.append(input.new_zeros(rc_size, col_5)) + q_mat.append(input.new_zeros(ssize, col_5)) + s_mat.append(input.new_zeros(1, col_5)) + + # part 6: before query segment + seg_start = max(j * self.segment_size + lcc - lc, 0) + rc_mat.append(input.new_zeros(rc_size, seg_start)) + q_mat.append(input.new_zeros(ssize, seg_start)) + s_mat.append(input.new_zeros(1, seg_start)) + + # part 7: query segment + # note: right context is put in right context block + # here we only need to consider about left context + seg_end = min((j + 1) * self.segment_size + lcc, utterance_length + lcc) + col_7 = seg_end - seg_start + rc_mat.append(torch.ones(rc_size, col_7, device=input.device)) + q_mat.append(torch.ones(ssize, col_7, device=input.device)) + s_mat.append(torch.ones(1, col_7, device=input.device)) + + # part 8: after query segment + col_8 = utterance_length + lcc - seg_end + rc_mat.append(input.new_zeros(rc_size, col_8)) + q_mat.append(input.new_zeros(ssize, col_8)) + s_mat.append(input.new_zeros(1, col_8)) + + rc_mask.append(torch.cat(rc_mat, dim=1)) + query_mask.append(torch.cat(q_mat, dim=1)) + summary_mask.append(torch.cat(s_mat, dim=1)) + + # no memory, then we don't need summary either + if self.use_mem: + attention_mask = ( + 1 + - torch.cat( + [ + torch.cat(rc_mask, dim=0), + torch.cat(query_mask, dim=0), + torch.cat(summary_mask, dim=0), + ], + dim=0, + ) + ).to(torch.bool) + else: + attention_mask = ( + 1 + - torch.cat( + [torch.cat(rc_mask, dim=0), torch.cat(query_mask, dim=0)], dim=0 + ) + ).to(torch.bool) + + return attention_mask + + @torch.jit.export + def init_state( + self, batch_size: int, device: Optional[Device] = None + ) -> List[Tensor]: + empty_memory = torch.zeros( + self.num_layers, + self.max_memory_size, + batch_size, + self.memory_dim, + device=device, + ) + left_context_key = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + left_context_val = torch.zeros( + self.num_layers, + self.left_context, + batch_size, + self.memory_dim, + device=device, + ) + past_length = torch.zeros(1, batch_size, dtype=torch.int32, device=device) + + return [empty_memory, left_context_key, left_context_val, past_length] + + @torch.jit.export + def batch_state(self, states: List[List[Tensor]]) -> List[Tensor]: + if len(states) == 0: + return [] + batched_m = [] + batched_lc_key = [] + batched_lc_val = [] + batched_past_length = [] + for state in states: + if len(state) == 0: + continue + m, lc_key, lc_val, past_length = state + batched_m.append(m) + batched_lc_key.append(lc_key) + batched_lc_val.append(lc_val) + batched_past_length.append(past_length) + + if ( + (len(batched_m) == 0) + or (len(batched_lc_key) == 0) + or (len(batched_lc_val) == 0) + or (len(batched_past_length) == 0) + ): + return [ + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ] + + batched_m = torch.cat(batched_m, dim=2) + batched_lc_key = torch.cat(batched_lc_key, dim=2) + batched_lc_val = torch.cat(batched_lc_val, dim=2) + batched_past_length = torch.cat(batched_past_length, dim=1) + return [batched_m, batched_lc_key, batched_lc_val, batched_past_length] + + @torch.jit.export + def reorder_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + if len(state) == 0: + return [] + m, lc_key, lc_val, past_length = state + indices = indices.to(device=m.device) + reord_m = torch.index_select(m, 2, indices) + reord_lc_key = torch.index_select(lc_key, 2, indices) + reord_lc_val = torch.index_select(lc_val, 2, indices) + reord_past_length = torch.index_select(past_length, 1, indices) + return [reord_m, reord_lc_key, reord_lc_val, reord_past_length] + + @torch.jit.export + def reset_state(self, state: List[Tensor], indices: Tensor) -> List[Tensor]: + m, lc_key, lc_val, past_length = state + m = m.index_fill(dim=2, index=indices, value=0.0) + lc_key = lc_key.index_fill(dim=2, index=indices, value=0.0) + lc_val = lc_val.index_fill(dim=2, index=indices, value=0.0) + past_length = past_length.index_fill(dim=1, index=indices, value=0) + + return [m, lc_key, lc_val, past_length] + + @torch.jit.export + def state_size(self) -> int: + return 4 + + @torch.jit.export + def batch_size_in_state( + self, state: Optional[List[Tensor]], sloppy: bool = True + ) -> Optional[int]: + if state is None: + return None + return state[0].size(2) + + def gen_summary_queries(self, input): + sum_input = self.memory_op(input) + return sum_input + + def _gen_right_context_padded_input(self, input): + # This function deals with input that is already + # padded with right context (e.g. minibatch training) + right_context_blocks = [] + T, B, D = input.shape + num_segs = math.ceil((T - self.right_context) / self.segment_size) + for i in range(0, num_segs - 1): + st = (i + 1) * self.segment_size + ed = st + self.right_context + assert ed < T + temp = input[st:ed, :, :] + right_context_blocks.append(temp) + + # last segment right context is already available + right_context_blocks.append(input[T - self.right_context :, :, :]) + return torch.cat(right_context_blocks, dim=0) + + def _gen_segs_right_context(self, input, lengths): + segments = [] + T, B, D = input.size() + nT = T - self.right_context + + # assume input is right context padded + num_segs = math.ceil(nT / self.segment_size) + # pad zeros to the utterance to make sure each + # segment has the same right context. For the + for i in range(0, num_segs - 1): + st = i * self.segment_size + ed = min(T, st + self.segment_size + self.right_context) + temp = input[st:ed, :, :] + rest_lengths = torch.clamp( + lengths - self.segment_size, min=0, max=nT - (i + 1) * self.segment_size + ) + segments.append((temp, lengths - rest_lengths + self.right_context)) + lengths = rest_lengths + + last_seg = input[st + self.segment_size :, :, :] + segments.append((last_seg, rest_lengths + self.right_context)) + + return segments + + @torch.jit.unused + def forward( + self, input: Tensor, padding_masks: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + # Xutai: originally the second argument is lengths. + lengths = (~padding_masks).sum(dim=1).long() + # mini batch training. + if self.mini_batches: + return self.forward_mini_batches(input, lengths, state) + + # regular full sequence training. Note, assume the right context in provided + # in the input. + T, B, D = input.size() + right_context_blocks = self._gen_right_context_padded_input(input) + + # generate the relative positional embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=0, + past_length=0, + is_decoding=False, + ) + else: + rpe = None + input = input[: T - self.right_context, :, :] + + attention_mask = self._get_attention_mask(input) + + # firt layer use each segment mean as memory + # ignore the last one seg average + if self.use_mem: + mems = self.gen_summary_queries(input)[:-1, :, :] + else: + mems = torch.zeros(0, input.size(1), input.size(2), device=input.device) + mems = mems.type_as(input) + + output = input + all_outputs = [] + + for layer in self.layers: + output, mems, right_context_blocks, _, _ = layer( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=None, + left_context_key=None, + left_context_val=None, + rpe=rpe, + ) + all_outputs.append(output) + return output, padding_masks, [], all_outputs + + def forward_jit_mini_batch_init( + self, + seg: Tensor, + state: Optional[List[Tensor]] = None, + is_decoding: bool = False, + ): + # Prepare state. In whole sequence training, state is ignored. + # For minibatch training, we need to prepare state + if state is None: + state = self.init_state(batch_size=seg.size(1), device=seg.device) + if seg.dtype == torch.half: + state = [state[0].half(), state[1].half(), state[2].half(), state[3]] + + if self.use_mem: + # note input average only on seg, not on right context + # first layer use each segmetn mean as memory. the last + # one segment average is used in state + full_mems = self.gen_summary_queries(seg) + if is_decoding: + mems = full_mems[0:1, :, :] + state_mems = torch.cat([state[0][0], mems], dim=0) + else: + mems = full_mems[:-1, :, :] + state_mems = torch.cat([state[0][0], full_mems], dim=0) + else: + mems = state[0][0] + state_mems = mems + + # track processed segment number or memory number + # the same batch as the same bumber of past length + past_length = state[3][0][0].item() + past_left_context = min(past_length * self.segment_size, self.left_context) + past_length = min(self.max_memory_size, past_length) + + return state, mems, state_mems, past_length, past_left_context + + def state_update_before( + self, layer: int, state: List[Tensor], past_length: int, past_left_context: int + ): + pre_mems = state[0][layer][self.max_memory_size - past_length :, :, :] + lc_key = state[1][layer][self.left_context - past_left_context :, :, :] + lc_val = state[2][layer][self.left_context - past_left_context :, :, :] + return pre_mems, lc_key, lc_val + + def state_update_after( + self, + layer: int, + state: List[Tensor], + mems: Tensor, + next_key: Tensor, + next_val: Tensor, + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + ): + # mems is used for next layer + if layer < self.num_layers - 1: + state_mems = torch.cat([state[0][layer + 1], mems], dim=0) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + + # when mems pass to next sequence, we need the last memory. when mems + # use for the next layer, we can ignore the last memory + mems = mems[:-1, :, :] + + # note state[1][i] and state[2][i] original length equals to self.left_context + new_k = torch.cat([state[1][layer], next_key], dim=0) + new_v = torch.cat([state[2][layer], next_val], dim=0) + lc_key_list.append(new_k[-self.left_context :, :, :]) + lc_val_list.append(new_v[-self.left_context :, :, :]) + return mems_list, lc_key_list, lc_val_list, mems + + def state_update_after_loop( + self, + state: List[Tensor], + mems_list: List[Tensor], + lc_key_list: List[Tensor], + lc_val_list: List[Tensor], + update_length: int, + ): + state[0] = torch.stack(mems_list, dim=0) + state[1] = torch.stack(lc_key_list, dim=0) + state[2] = torch.stack(lc_val_list, dim=0) + state[3] = state[3] + update_length + return state + + @torch.jit.unused + def forward_mini_batches( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor], List[Tensor]]: + T, B, D = input.size() + + # input without right context + seg = input[: T - self.right_context, :, :] + + # get right context blocks + right_context_blocks = self._gen_right_context_padded_input(input) + + mems_list = [] + lc_key_list = [] + lc_val_list = [] + results = self.forward_jit_mini_batch_init(seg, state, False) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=False, + ) + else: + rpe = None + + # get attention mask based on seg (not include right context) and available + # left context + attention_mask = self._get_attention_mask(seg, past_length, past_left_context) + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + all_outputs = [] + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + pre_mems, lc_key, lc_val = self.state_update_before( + i, state, past_length, past_left_context + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward( + input=output, + lengths=lengths, + attention_mask=attention_mask, + mems=mems, + right_context_blocks=right_context_blocks, + pre_mems=pre_mems, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + all_outputs.append(output) + mems_list, lc_key_list, lc_val_list, mems = self.state_update_after( + layer=i, + state=state, + mems=mems, + next_key=next_key, + next_val=next_val, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + + i += 1 + + # update state + update_length = math.ceil((T - self.right_context) / self.segment_size) + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=update_length, + ) + + return output, lengths, state, all_outputs + + def forward_jit_test( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + This one simulate sequence encoder forward jit. This is for unit test purpose. + It is not used in training or decoding. Note, extra_right_context is set in + the model. In unit test, input = [utterance, right_context], lengths = + [utterance_length]. + args: + input: input utterance + lengths: utterance input length + state: None here. input is whole utterance + """ + # [TODO] sequence_to_segment has bug in lengths. + seg_src_tokens_lengths = self._gen_segs_right_context(input, lengths) + + seg_enc_tokens_lengths: List[Tuple[Tensor, Tensor]] = [] + state: Optional[List[Tensor]] = None + for seg_src_tokens, seg_src_lengths in seg_src_tokens_lengths: + seg_enc_tokens, seg_enc_lengths, state = self.forward_jit( + input=seg_src_tokens, lengths=seg_src_lengths, state=state + ) + seg_enc_tokens_lengths.append((seg_enc_tokens, seg_enc_lengths)) + + enc_tokens, enc_lengths = segments_to_sequence( + segments=seg_enc_tokens_lengths, time_axis=0 + ) + + state = [] # returns trivial state + + return enc_tokens, enc_lengths, state + + @torch.jit.export + def forward_jit( + self, input: Tensor, lengths: Tensor, state: Optional[List[Tensor]] = None + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Forward helper for online decoding. + + args: + input: [seg, right_context]. We assume in online we + always padding the right context to the preset right context size. + For the last segment, we may have short segment size, but right + context size is the same as other segments + lengths: utterance input length is the utterance segment length and + right context size + state: [memory, left_context_key, left_context_val]. To improve throughput, + in addition to memory, we also cache key and value for left_context in + multihead self-attention + """ + # In online decoding, input = [segment, right_context] + # Lengths = [segment_length, right_context_length] + # so we need strip right context in output + T, B, D = input.size() + rc_str = T - self.right_context + rc_end = T + right_context_blocks = input[rc_str:rc_end, :, :] + seg = input[:rc_str, :, :] + lengths = torch.clamp(lengths - self.right_context, min=0) + mems_list = [] + lc_key_list = [] + lc_val_list = [] + + results = self.forward_jit_mini_batch_init(seg, state, True) + state, mems, state_mems, past_length, past_left_context = results + + # relative position embedding + if self.use_rpe: + rpe = self._get_relative_position( + input=input, + max_relative_position=self.max_relative_position, + left_context_length=past_left_context, + past_length=past_length, + is_decoding=True, + ) + else: + rpe = None + + # memory for first layer. + mems_list.append(state_mems[-self.max_memory_size :, :, :]) + output = seg + i = 0 + for layer in self.layers: + # In order to make cross stream batching work, mem, left context key + # and left context value in the state should always be the same shape. + # We use the past length to track the processed segment number. In this + # way, we take out the essential memory, left context key and left + # context val from the state. After finish the forward for current segment + # we add the new memory, left context key and left context value into the + # staate and trim out the oldest part to keep the shape consistent. + true_mems, lc_key, lc_val = self.state_update_before( + layer=i, + state=state, + past_length=past_length, + past_left_context=past_left_context, + ) + + output, mems, right_context_blocks, next_key, next_val = layer.forward_jit( + input=output, + lengths=lengths, + mems=true_mems, + right_context_blocks=right_context_blocks, + left_context_key=lc_key, + left_context_val=lc_val, + rpe=rpe, + ) + # mems is used for next layer + mems_list, lc_key_list, lc_val_list, _ = self.state_update_after( + layer=i, + state=state, + mems_list=mems_list, + mems=mems, + next_key=next_key, + next_val=next_val, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + ) + i += 1 + + # update state + state = self.state_update_after_loop( + state=state, + mems_list=mems_list, + lc_key_list=lc_key_list, + lc_val_list=lc_val_list, + update_length=1, + ) + + return output, lengths, state + + def quantize_(self, params=None): + if params and "per_channel" in params and params["per_channel"]: + qconfig = per_channel_dynamic_qconfig + else: + qconfig = default_dynamic_qconfig + quantization.quantize_dynamic( + self, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True + ) + return self + + +# ------------------------------------------------------------------------------ +# Emformer encoder for seq2seq model +# This is a wrapper over the original emformer +# ------------------------------------------------------------------------------ +def emformer_encoder(klass): + class SpeechEncoder(klass): + def __init__(self, args): + super().__init__(args) + stride = SpeechEncoder.conv_layer_stride(args) + trf_left_context = args.segment_left_context // stride + trf_right_context = args.segment_right_context // stride + context_config = [trf_left_context, trf_right_context] + self.transformer_layers = nn.ModuleList( + [ + NoSegAugmentedMemoryTransformerEncoderLayer( + input_dim=args.encoder_embed_dim, + num_heads=args.encoder_attention_heads, + ffn_dim=args.encoder_ffn_embed_dim, + num_layers=args.encoder_layers, + dropout_in_attn=args.dropout, + dropout_on_attn=args.dropout, + dropout_on_fc1=args.dropout, + dropout_on_fc2=args.dropout, + activation_fn=args.activation_fn, + context_config=context_config, + segment_size=args.segment_length, + max_memory_size=args.max_memory_size, + scaled_init=True, # TODO: use constant for now. + tanh_on_mem=args.amtrf_tanh_on_mem, + ) + ] + ) + + def forward(self, src_tokens, src_lengths): + encoder_out = super().forward(src_tokens, src_lengths) + output = encoder_out["encoder_out"][0] + encoder_padding_masks = encoder_out["encoder_padding_mask"][0] + + # This is because that in the original implementation + # the output didn't consider the last segment as right context. + encoder_padding_masks = encoder_padding_masks[:, : output.size(0)] + + return { + "encoder_out": [output], + "encoder_padding_mask": [encoder_padding_masks], + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + @staticmethod + def conv_layer_stride(args): + # TODO: make it configurable from the args + return 4 + + SpeechEncoder.__name__ = klass.__name__ + return SpeechEncoder diff --git a/fairseq/models/speech_to_text/multi_modality_model.py b/fairseq/models/speech_to_text/multi_modality_model.py new file mode 100644 index 0000000000000000000000000000000000000000..046421620ae9e78a6138632c26e8a25ce3b64478 --- /dev/null +++ b/fairseq/models/speech_to_text/multi_modality_model.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import FairseqDecoder, FairseqEncoder + + +# a container for different encoders with training samples from different modality +# each time, only one encoder is selected +class MultiModalityEncoder(FairseqEncoder): + def __init__(self, dictionary): + super().__init__(dictionary) + + def select_encoder(self, mode, **kwargs): + raise NotImplementedError("Model must implement the select_encoder method") + return None, kwargs + + # def post_encoder(self, encoder_out, src_tokens, src_lengths, mode, **kwargs): + # # Default do nothing + # return encoder_out + + # get sample data from JointSpeechTextDataset + def forward(self, src_tokens, src_lengths=None, mode="", **kwargs): + encoder, kwargs = self.select_encoder(mode, **kwargs) + # return self.post_encoder(encoder(src_tokens, src_lengths, **kwargs), src_tokens, src_lengths, mode, **kwargs) + return encoder(src_tokens, src_lengths, **kwargs) + + +# a container for different decoders with training samples from different modality +# each time, only one decoder is selected +class MultiInputDecoder(FairseqDecoder): + def __init__(self, dictionary): + super().__init__(dictionary) + + def select_decoder(self, mode, **kwargs): + raise NotImplementedError("Model must implement the select_decoder method") + return None, kwargs + + def forward( + self, prev_output_tokens, encoder_out, incremental_state=None, mode="", **kwargs + ): + decoder, kwargs = self.select_decoder(mode, **kwargs) + return decoder( + prev_output_tokens, + encoder_out, + incremental_state=incremental_state, + **kwargs + ) diff --git a/fairseq/models/speech_to_text/s2t_conformer.py b/fairseq/models/speech_to_text/s2t_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..79dbbec1b90a05ab101d5b086cd18eec858cdc8c --- /dev/null +++ b/fairseq/models/speech_to_text/s2t_conformer.py @@ -0,0 +1,234 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from pathlib import Path + +import torch + +from fairseq import checkpoint_utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import FairseqEncoder, register_model, register_model_architecture +from fairseq.models.speech_to_text.modules.convolution import ( + Conv1dSubsampler, + Conv2dSubsampler, +) +from fairseq.models.speech_to_text.s2t_transformer import ( + S2TTransformerEncoder, + S2TTransformerModel, +) +from fairseq.models.speech_to_text.s2t_transformer import ( + base_architecture as transformer_base_architecture, +) +from fairseq.modules import PositionalEmbedding, RelPositionalEncoding +from fairseq.modules.conformer_layer import ConformerEncoderLayer + +logger = logging.getLogger(__name__) + + +class S2TConformerEncoder(FairseqEncoder): + """Conformer Encoder for speech translation based on https://arxiv.org/abs/2005.08100""" + + def __init__(self, args): + super().__init__(None) + + self.encoder_freezing_updates = args.encoder_freezing_updates + self.num_updates = 0 + + self.embed_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_embedding: + self.embed_scale = 1.0 + self.padding_idx = 1 + self.conv_version = args.conv_version + if self.conv_version == "s2t_transformer": + self.subsample = Conv1dSubsampler( + args.input_feat_per_channel * args.input_channels, + args.conv_channels, + args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(",")], + ) + elif self.conv_version == "convtransformer": + self.subsample = Conv2dSubsampler( + args.input_channels, + args.input_feat_per_channel, + args.conv_out_channels, + args.encoder_embed_dim, + ) + self.pos_enc_type = args.pos_enc_type + if self.pos_enc_type == "rel_pos": + self.embed_positions = RelPositionalEncoding( + args.max_source_positions, args.encoder_embed_dim + ) + elif self.pos_enc_type == "rope": + self.embed_positions = None + else: # Use absolute positional embedding + self.pos_enc_type = "abs" + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + + self.linear = torch.nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim) + self.dropout = torch.nn.Dropout(args.dropout) + self.conformer_layers = torch.nn.ModuleList( + [ + ConformerEncoderLayer( + embed_dim=args.encoder_embed_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + attn_type=args.attn_type, + pos_enc_type=self.pos_enc_type, + use_fp16=args.fp16, + ) + for _ in range(args.encoder_layers) + ] + ) + + def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): + """ + Args: + src_tokens: Input source tokens Tensor of shape B X T X C + src_lengths: Lengths Tensor corresponding to input source tokens + return_all_hiddens: If true will append the self attention states to the encoder states + Returns: + encoder_out: Tensor of shape B X T X C + encoder_padding_mask: Optional Tensor with mask + encoder_embedding: Optional Tensor. Always empty here + encoder_states: List of Optional Tensors wih self attention states + src_tokens: Optional Tensor. Always empty here + src_lengths: Optional Tensor. Always empty here + """ + x, input_lengths = self.subsample(src_tokens, src_lengths) # returns T X B X C + encoder_padding_mask = lengths_to_padding_mask(input_lengths) + x = self.embed_scale * x + if self.pos_enc_type == "rel_pos": + positions = self.embed_positions(x) + + elif self.pos_enc_type == "rope": + positions = None + + else: + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + positions = None + + x = self.linear(x) + x = self.dropout(x) + encoder_states = [] + + # x is T X B X C + for layer in self.conformer_layers: + x, _ = layer(x, encoder_padding_mask, positions) + if return_all_hiddens: + encoder_states.append(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def forward(self, src_tokens, src_lengths, return_all_hiddens=False): + if self.num_updates < self.encoder_freezing_updates: + with torch.no_grad(): + x = self._forward( + src_tokens, + src_lengths, + return_all_hiddens=return_all_hiddens, + ) + else: + x = self._forward( + src_tokens, + src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return x + + def reorder_encoder_out(self, encoder_out, new_order): + """Required method for a FairseqEncoder. Calls the method from the parent class""" + return S2TTransformerEncoder.reorder_encoder_out(self, encoder_out, new_order) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.num_updates = num_updates + + +@register_model("s2t_conformer") +class S2TConformerModel(S2TTransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + S2TTransformerModel.add_args(parser) + parser.add_argument( + "--input-feat-per-channel", + type=int, + metavar="N", + help="dimension of input features per channel", + ) + parser.add_argument( + "--input-channels", + type=int, + metavar="N", + help="number of chennels of input features", + ) + parser.add_argument( + "--depthwise-conv-kernel-size", + type=int, + metavar="N", + help="kernel size of depthwise convolution layers", + ) + parser.add_argument( + "--attn-type", + type=str, + metavar="STR", + help="If not specified uses fairseq MHA. Other valid option is espnet", + ) + parser.add_argument( + "--pos-enc-type", + type=str, + metavar="STR", + help="Must be specified in addition to attn-type=espnet for rel_pos and rope", + ) + + @classmethod + def build_encoder(cls, args): + encoder = S2TConformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + +@register_model_architecture("s2t_conformer", "s2t_conformer") +def conformer_base_architecture(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.input_channels = getattr(args, "input_channels", 1) + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + transformer_base_architecture(args) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..50fae2ffa276368e9fb1c339fda904f27363c45d --- /dev/null +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -0,0 +1,552 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_text.hub_interface import S2THubInterface +from fairseq.models.speech_to_text.modules.convolution import ( + Conv1dSubsampler, + Conv2dSubsampler, +) +from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.modules import ( + FairseqDropout, + LayerNorm, + PositionalEmbedding, + TransformerEncoderLayer, +) + +logger = logging.getLogger(__name__) + + +@register_model("s2t_transformer") +class S2TTransformerModel(FairseqEncoderDecoderModel): + """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for + speech-to-text tasks. The Transformer encoder/decoder remains the same. + A trainable input subsampler is prepended to the Transformer encoder to + project inputs into the encoder dimension as well as downsample input + sequence for computational efficiency.""" + + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" + model_ids = [ + "s2t_transformer_s-en-asr-librispeech", + "s2t_transformer_m-en-asr-librispeech", + "s2t_transformer_l-en-asr-librispeech", + ] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + config_yaml="config.yaml", + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + config_yaml=config_yaml, + **kwargs, + ) + return S2THubInterface(x["args"], x["task"], x["models"][0]) + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + # input + parser.add_argument( + "--conv-kernel-sizes", + type=str, + metavar="STR", + help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-channels", + type=int, + metavar="N", + help="# of channels in Conv1d (s2t_transformer) subsampling layers", + ) + parser.add_argument( + "--conv-out-channels", + type=int, + metavar="N", + help="# of channels in Conv2d (convtransformer) subsampling layers", + ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) + # Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-encoder-from", + type=str, + metavar="STR", + help="model to take encoder weights from (for initialization)", + ) + parser.add_argument( + "--encoder-freezing-updates", + type=int, + metavar="N", + help="freeze encoder for first N updates", + ) + + @classmethod + def build_encoder(cls, args): + encoder = S2TTransformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + args.tgt_dict_size = len(task.target_dictionary) + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + return cls(encoder, decoder) + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + def get_ctc_target(self, sample: Optional[Dict[str, Tensor]]): + return sample["target"], sample["target_lengths"] + + def get_ctc_output( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + sample: Optional[Dict[str, Tensor]], + ): + encoder_out = net_output[1]["encoder_out"]["encoder_out"][0] + logits = self.encoder.ctc_proj(encoder_out) # T x B x C + out = utils.log_softmax(logits.float(), dim=-1) + padding_mask = net_output[1]["encoder_out"]["encoder_padding_mask"] + lens = out.new_full((out.shape[1],), out.shape[0]).long() + if len(padding_mask) > 0: + lens -= padding_mask[0].sum(dim=-1) + return out, lens + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs. + """ + encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + return decoder_out + + +class S2TTransformerEncoder(FairseqEncoder): + """Speech-to-text Transformer encoder that consists of input subsampler and + Transformer encoder.""" + + def __init__(self, args): + super().__init__(None) + + self.encoder_freezing_updates = args.encoder_freezing_updates + self.num_updates = 0 + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_embedding: + self.embed_scale = 1.0 + self.padding_idx = 1 + + self.conv_version = args.conv_version + if self.conv_version == "s2t_transformer": + self.subsample = Conv1dSubsampler( + args.input_feat_per_channel * args.input_channels, + args.conv_channels, + args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(",")], + ) + elif self.conv_version == "convtransformer": + self.subsample = Conv2dSubsampler( + args.input_channels, + args.input_feat_per_channel, + args.conv_out_channels, + args.encoder_embed_dim, + ) + + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + + self.transformer_layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.0) > 0.0: + self.ctc_proj = nn.Linear(args.encoder_embed_dim, args.tgt_dict_size) + + def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): + x, input_lengths = self.subsample(src_tokens, src_lengths) + x = self.embed_scale * x + + encoder_padding_mask = lengths_to_padding_mask(input_lengths) + positions = self.embed_positions(encoder_padding_mask).transpose(0, 1) + x += positions + x = self.dropout_module(x) + + encoder_states = [] + + for layer in self.transformer_layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def forward(self, src_tokens, src_lengths, return_all_hiddens=False): + if self.num_updates < self.encoder_freezing_updates: + with torch.no_grad(): + x = self._forward( + src_tokens, src_lengths, return_all_hiddens=return_all_hiddens + ) + else: + x = self._forward( + src_tokens, src_lengths, return_all_hiddens=return_all_hiddens + ) + return x + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.num_updates = num_updates + + +class TransformerDecoderScriptable(TransformerDecoder): + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + # call scriptable method from parent class + x, _ = self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + extra = {"encoder_out": encoder_out} if incremental_state is None else None + return x, extra + + +@register_model_architecture(model_name="s2t_transformer", arch_name="s2t_transformer") +def base_architecture(args): + args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) + # Convolutional subsampler + args.input_channels = getattr(args, "input_channels", 1) + args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") # for Conv1d + args.conv_channels = getattr(args, "conv_channels", 1024) # for Conv1d + args.conv_out_channels = getattr(args, "conv_out_channels", 256) # for Conv2d + args.conv_version = getattr(args, "conv_version", "s2t_transformer") + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 12) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_s") +def s2t_transformer_s(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_xs") +def s2t_transformer_xs(args): + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_layers = getattr(args, "decoder_layers", 3) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 4) + args.dropout = getattr(args, "dropout", 0.3) + s2t_transformer_s(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_sp") +def s2t_transformer_sp(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_s(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_m") +def s2t_transformer_m(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.dropout = getattr(args, "dropout", 0.15) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_mp") +def s2t_transformer_mp(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_m(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_l") +def s2t_transformer_l(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.2) + base_architecture(args) + + +@register_model_architecture("s2t_transformer", "s2t_transformer_lp") +def s2t_transformer_lp(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + s2t_transformer_l(args) diff --git a/fairseq/models/speech_to_text/s2t_wav_transformer.py b/fairseq/models/speech_to_text/s2t_wav_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ad21aeeb1a6f2c47a0ed380478853a189e3577a4 --- /dev/null +++ b/fairseq/models/speech_to_text/s2t_wav_transformer.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python3 + +import math + +import torch +import torch.nn as nn + +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models import FairseqEncoder +from fairseq.models.wav2vec import ConvFeatureExtractionModel +from fairseq.modules import GradMultiply, LayerNorm, SamePad, TransformerEncoderLayer + + +# Transformer encoder with wave input, it is adopted from wav2vec 2.0 Encoder. +# use wav input +# use trained position embedding so it is easier to match with text input +class SpeechWavTransformerEncoder(FairseqEncoder): + + # extra parameters for speech encoder besides those defined in transformermodel + @staticmethod + def add_args(parser): + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the unmasked features (after feat extr)", + ) + parser.add_argument( + "--speech-extractor-mode", + type=str, + default="layer_norm", + choices=["default", "layer_norm"], + help="feature extractor norm", + ) + + parser.add_argument( + "--speech-conv-bias", + action="store_true", + help="include bias in speech conv encoder", + ) + + parser.add_argument( + "--conv-feature-layers", + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--speech-mask-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--speech-mask-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--speech-mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--speech-mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--speech-no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--speech-mask-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--speech-mask-channel-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--speech-mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--speech-mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--speech-mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--speech-no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--no-scale-feature", + action="store_true", + help="no scale for the calculated features", + ) + + parser.add_argument( + "--speech-mask-channel-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + help="reset feature grad mult in wav2vec 2.0 to this", + ) + + # positional embeddings + parser.add_argument( + "--conv-pos", + type=int, + default=128, + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + default=16, + help="number of groups for convolutional positional embedding", + ) + # model configures + parser.add_argument( + "--speech-encoder-layers", + type=int, + help="number of speech encoder layers", + ) + parser.add_argument( + "--text-encoder-layers", + type=int, + help="number of text encoder layers", + ) + + def __init__(self, args, alway_mask=False): + super().__init__(args) + self.args = args + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.feat_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_feature: + self.feat_scale = 1.0 + + subsample = ConvFeatureExtractionModel( + conv_layers=eval(args.conv_feature_layers), + dropout=0.0, + mode=args.speech_extractor_mode, # default, layer_norm + conv_bias=args.speech_conv_bias, + ) + self.feature_enc_layers = eval(args.conv_feature_layers) + self.subsample = subsample + self.feat_proj = ( + nn.Linear(self.feature_enc_layers[-1][0], self.embedding_dim) + if self.feature_enc_layers[-1][0] != self.embedding_dim + else None + ) + + self.feat_layer_norm = LayerNorm(self.feature_enc_layers[-1][0]) + + self.embed_positions = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + std = math.sqrt(4 / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.embed_positions.weight, mean=0, std=std) + nn.init.constant_(self.embed_positions.bias, 0) + + self.embed_positions = nn.utils.weight_norm( + self.embed_positions, name="weight", dim=2 + ) + self.embed_positions = nn.Sequential( + self.embed_positions, SamePad(args.conv_pos), nn.GELU() + ) + + self.mask_prob = args.speech_mask_prob + self.mask_selection = args.speech_mask_selection + self.mask_other = args.speech_mask_other + self.mask_length = args.speech_mask_length + self.no_mask_overlap = args.speech_no_mask_overlap + self.mask_min_space = args.speech_mask_min_space + + self.mask_channel_prob = args.speech_mask_channel_prob + self.mask_channel_selection = args.speech_mask_channel_selection + self.mask_channel_other = args.speech_mask_channel_other + self.mask_channel_length = args.speech_mask_channel_length + self.no_mask_channel_overlap = args.speech_no_mask_channel_overlap + self.mask_channel_min_space = args.speech_mask_channel_min_space + + self.dropout_input = nn.Dropout(args.dropout_input) + self.dropout_features = nn.Dropout(args.dropout_features) + + self.feature_grad_mult = args.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + self.layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm = LayerNorm(args.encoder_embed_dim) + self.normalize_before = args.encoder_normalize_before + self.alway_mask = alway_mask + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + for i in range(len(self.feature_enc_layers)): + input_lengths = _conv_out_length( + input_lengths, + self.feature_enc_layers[i][1], + self.feature_enc_layers[i][2], + ) + + return input_lengths.to(torch.long) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens=False, + padding_mask=None, + features_only=True, + ): + mask = self.training or self.alway_mask + if self.feature_grad_mult > 0 and self.training: + features = self.subsample(src_tokens) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.subsample(src_tokens) + features = features.transpose(1, 2) + features = self.feat_layer_norm(features) + if self.feat_proj is not None: + features = self.feat_proj(features) + + if padding_mask is not None: + input_lengths = (1 - padding_mask.long()).sum(-1) + else: + input_lengths = src_lengths + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + + features = self.feat_scale * features if self.feat_scale != 1.0 else features + unmasked_features = features.clone() + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + else: + x = features + mask_indices = None + + def cal_transformer_layers(x, encoder_padding_mask, return_all_hiddens=False): + # x: B x T x C + positions = self.embed_positions(x.transpose(1, 2)).transpose(1, 2) + x = x + positions + if not self.normalize_before: + x = self.layer_norm(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + encoder_states = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + if self.normalize_before: + x = self.layer_norm(x) + return x, encoder_states + + x, encoder_states = cal_transformer_layers(x, padding_mask, return_all_hiddens) + if features_only: + return { + "encoder_out": [x], # [T x B x C] + "encoder_padding_mask": [padding_mask] + if padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": [mask_indices], + } + + x_unmasked = x + if self.mask_prob > 0 or self.mask_channel_prob > 0: + x_unmasked, _ = cal_transformer_layers(unmasked_features, padding_mask) + return { + "encoder_out": [x], # [T x B x C] + "encoder_unmasked_out": [x_unmasked], # [T x B x C] + "encoder_padding_mask": [padding_mask] + if padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": [mask_indices] if mask_indices is not None else [], # B X T + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +class StackedSpeechWavTransformerEncoder(FairseqEncoder): + def __init__(self, speech_enc, text_enc_layers, text_layer_norm): + super().__init__(None) + self.speech_encoder = speech_enc + self.text_encoder_layers = text_enc_layers + self.final_layer_norm = text_layer_norm + + def forward( + self, + src_tokens, + src_lengths=None, + return_all_hiddens=False, + padding_mask=None, + features_only=True, + ): + + out = self.speech_encoder.forward( + src_tokens, + src_lengths, + return_all_hiddens, + padding_mask=padding_mask, + features_only=features_only, + ) + x = out["encoder_out"][0] + encoder_padding_mask = None + if len(out["encoder_padding_mask"]) > 0: + encoder_padding_mask = out["encoder_padding_mask"][0] + + def cal_text_layers(x, padding_mask, return_all_hiddens=False): + encoder_states = [] + for layer in self.text_encoder_layers: + x = layer(x, padding_mask) + if return_all_hiddens: + encoder_states.append(x) + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + return x, encoder_states + + x, encoder_states = cal_text_layers(x, encoder_padding_mask, return_all_hiddens) + if features_only: + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + x_u = out["encoder_unmasked_out"][0] + x_u, _ = cal_text_layers(x_u, encoder_padding_mask) + + return { + "encoder_out": [x], # [T x B x C] + "encoder_unmasked_out": [x_u], # [T x B x C] + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": out["mask_indices"], # B X T + } + + def reorder_encoder_out(self, encoder_out, new_order): + return self.speech_encoder.reorder_encoder_out(encoder_out, new_order) diff --git a/fairseq/models/speech_to_text/utils.py b/fairseq/models/speech_to_text/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33117446a5e2f3b71c64f0a7b6b8122a1ac7c182 --- /dev/null +++ b/fairseq/models/speech_to_text/utils.py @@ -0,0 +1,562 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + + +import logging +from collections.abc import Iterable +from itertools import repeat +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +# ------------------------------------------------------------------------------ +# assert_equal() +# ------------------------------------------------------------------------------ + + +def assert_equal(value1, value2, name1=None, name2=None): + """Asserts two values are equal otherwise raise an error.""" + + str_name1 = "" if name1 is None else "{} ".format(name1) + str_name2 = "" if name2 is None else "{} ".format(name2) + if value1 != value2: + str_value1 = "{}" if name1 is None else "({})" + str_value1 = str_value1.format(value1) + str_value2 = "{}" if name2 is None else "({})" + str_value2 = str_value2.format(value2) + raise ValueError( + "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2) + ) + + +def fill_config(config, key, value): + if value is not None: + if key not in config or config[key] is None: + config[key] = value + assert_equal(value, config[key], "value", f'config["{key}"]') + + +# ------------------------------------------------------------------------------ +# check_and_return_expected() +# ------------------------------------------------------------------------------ + + +def check_and_return_expected(value, undefined_value, expected_value, name=None): + """ + Return the expected value while checking if the given value is undefined or + equal to the expected value. + """ + if (undefined_value is None and value is None) or (undefined_value == value): + return expected_value + if value != expected_value: + str_name = "" if name is None else "{} ".format(name) + str_value = "{}" if name is None else "({})" + str_value = str_value.format(value) + raise ValueError( + "Expected {}{} == {}".format(str_name, str_value, expected_value) + ) + return expected_value + + +# ------------------------------------------------------------------------------ +# get_time_axis() +# ------------------------------------------------------------------------------ + + +def get_time_axis(layout): + """ + Extract the time axis from the layout, for example for breaking sequence into + segments. + """ + if layout in ["TB", "TBD"]: + return 0 + if layout in ["BT", "BTD"]: + return 1 + if layout in ["BCTD"]: + return 2 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# get_batch_axis() +# ------------------------------------------------------------------------------ + + +def get_batch_axis(layout): + """ + Extract the batch axis from the layout + """ + if layout in ["TB", "TBD"]: + return 1 + if layout in ["BT", "BTD", "BCTD"]: + return 0 + raise ValueError("Unsupported layout = {}".format(layout)) + + +# ------------------------------------------------------------------------------ +# monotonically_increasing_and_bounded() +# ------------------------------------------------------------------------------ + + +def monotonically_increasing_and_bounded(iterable, min=None, max=None): + """ + Check if the elements in the given iterable are monotonically increasing and + bounded by upper/lower bounds. + """ + if not isinstance(iterable, Iterable): + raise TypeError( + "Expected iterable to be of type Iterable, got ({})".format( + iterable.__class__.__name__ + ) + ) + for i in range(len(iterable)): + if min is not None and iterable[i] < min: + return False + if max is not None and iterable[i] > max: + return False + if i > 0 and iterable[i] <= iterable[i - 1]: + return False + return True + + +# ------------------------------------------------------------------------------ +# to_pair() +# ------------------------------------------------------------------------------ + + +def to_pair(value, name): + """Make a pair (of type tuple) of given value.""" + if isinstance(value, Iterable): + if len(value) != 2: + raise ValueError( + "Expected `{}` to have exactly 2 elements, got: ({})".format( + name, value + ) + ) + return value + return tuple(repeat(value, 2)) + + +# ------------------------------------------------------------------------------ +# infer_conv_output_attrs() +# ------------------------------------------------------------------------------ + + +# TODO(cfyeh): figure out if we can get `output_dim` without calling the module. +def infer_conv_output_attrs( + module, input_channels, input_dim, batch_size=1, max_length=8 +): + """Get output attributes of a module with input.""" + input = torch.randn(batch_size, input_channels, max_length, input_dim) + output = module(input) + output_channels = output.shape[1] + output_dim = output.shape[-1] + return output_channels, output_dim + + +# ------------------------------------------------------------------------------ +# NoOp +# ------------------------------------------------------------------------------ + + +class NoOp(torch.nn.Module): + """ + NoOp simply passes the input as the output. + """ + + def __init__(self): + super().__init__() + + def forward(self, input: Tensor) -> Tensor: + return input + + +# ------------------------------------------------------------------------------ +# Permute: a torch.nn.Module applies permutation on the input tensor. +# ------------------------------------------------------------------------------ + + +class Permute(torch.nn.Module): + def __init__(self, dims): + super().__init__() + self.dims = dims + + def forward(self, input: Tensor) -> Tensor: + return input.permute(self.dims).contiguous() + + +# ------------------------------------------------------------------------------ +# lengths_to_padding_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_padding_mask(lengths: Tensor) -> Tensor: + """Convert lengths of shape (B, ) to padding mask.""" + batch_size = lengths.shape[0] + max_length = int(torch.max(lengths).item()) + padding_mask = torch.arange( # [0, ..., T-1] + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(batch_size, max_length) >= lengths.unsqueeze(1) + + return padding_mask + + +# ------------------------------------------------------------------------------ +# lengths_to_attention_mask() +# ------------------------------------------------------------------------------ + + +def lengths_to_attention_mask( + lengths: Tensor, + left_context: Optional[int] = None, + right_context: Optional[int] = None, +) -> Optional[Tensor]: + """ + Generate attention mask based on (lengths, left_context, right_context). + left_context is None means unlimited left context. + right_context is None means unlimited right context. + """ + + if left_context is None and right_context is None: + return None + + max_length = int(torch.max(lengths).item()) + + # For example, with `max_length` == 5, + # indices = tensor([ + # [ 0, 1, 2, 3, 4, 5], + # [-1, 0, 1, 2, 3, 4], + # [-2, -1, 0, 1, 2, 3], + # [-3, -2, -1, 0, 1, 2], + # [-4, -3, -2, -1, 0, 1], + # [-5, -4, -3, -2, -1, 0], + # ]) + + # In some cases the second torch.arange is created on cpu which causes a + # failure. Adding the device option to guard against it. + indices = torch.arange( + max_length, device=lengths.device, dtype=lengths.dtype + ).expand(max_length, max_length) - torch.arange( + max_length, device=lengths.device + ).view( + max_length, -1 + ) + + # For example, with `max_length` == 5, + # bool_mask = tensor([ + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + bool_mask = ( + torch.tensor([True]).to(device=lengths.device).expand(max_length, max_length) + ) + + # For example, with `max_length` == 5, left_context == 2 + # left_mask = tensor([ + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [ True, True, True, True, True], + # [False, True, True, True, True], + # [False, False, True, True, True], + # ]) + if left_context is not None: + left_mask = indices >= -left_context + bool_mask = bool_mask & left_mask + + # For example, with `max_length` == 5, right_context == 1 + # right_mask = tensor([ + # [True, True, False, False, False], + # [True, True, True, False, False], + # [True, True, True, True, False], + # [True, True, True, True, True], + # [True, True, True, True, True], + # ]) + if right_context is not None: + right_mask = indices <= right_context + bool_mask = bool_mask & right_mask + + bool_mask = (~bool_mask).to(device=lengths.device) + return bool_mask + + +# ------------------------------------------------------------------------------ +# infer_output_norm() +# ------------------------------------------------------------------------------ + + +def infer_output_norm(module, output_norm=None): + """ + Infer the output norm (string and module) needed on the module gvien desired + output normalization. + """ + if output_norm == module.output_norm(): + # output_norm already matches module.output_norm(). + return (None, NoOp()) + + if output_norm is None and module.output_norm() is not None: + logger = logging.getLogger("infer_output_norm()") + logger.warning( + "trying to set output_norm ({}) ".format(output_norm) + + "but got module.output_norm() ({}), ".format(module.output_norm()) + + "the combined output_norm() will be ({})".format(module.output_norm()) + ) + return (None, NoOp()) + + if output_norm == "log_softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("log_softmax", torch.nn.LogSoftmax(dim=-1)) + + if output_norm == "softmax": + if module.output_norm() is not None: + raise ValueError( + "incompatible output_norm ({}) ".format(output_norm) + + "and module.output_norm() ({})".format(module.output_norm()) + ) + else: + return ("softmax", torch.nn.Softmax(dim=-1)) + + raise ValueError( + "output_norm ({}) not in ".format(output_norm) + + "supported list = [None, softmax, log_softmax]" + ) + + +# ------------------------------------------------------------------------------ +# infer_channels_from_layout() +# ------------------------------------------------------------------------------ + + +def infer_channels_from_layout(layout, channels): + """Extract the number of channels from the layout.""" + if layout in ("TBD", "BTD"): + if channels is not None and channels != 1: + raise ValueError( + "Expected channels ({}) to be 1 for layout = {}".format( + channels, layout + ) + ) + if channels is None: + return 1 + return channels + + +# ------------------------------------------------------------------------------ +# pad_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def pad_sequence( + sequence: Tensor, + time_axis: int, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> Tensor: + """Pad extra left/right contexts to the sequence.""" + + if extra_left_context == 0 and extra_right_context == 0: + return sequence + + tensors_to_concat = [] + + if extra_left_context: + size = (extra_left_context,) + fill_value = 0 + indices = torch.full( + size=size, + fill_value=fill_value, + dtype=torch.long, + device=sequence.device, + ) + left_padding = torch.index_select(sequence, time_axis, indices) + tensors_to_concat.append(left_padding) + + tensors_to_concat.append(sequence) + + # NOTE(cfyeh): for efficiency reason we pad 0 instead of the last frame for + # extra right contexts. + if extra_right_context: + size = list(sequence.shape) + size[time_axis] = extra_right_context + right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device) + tensors_to_concat.append(right_padding) + + padded_sequence = torch.cat(tensors_to_concat, dim=time_axis) + return padded_sequence + + +# ------------------------------------------------------------------------------ +# sequence_to_segments() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def sequence_to_segments( + sequence: Tensor, + time_axis: int, + lengths: Tensor, + segment_size: Optional[int] = None, + extra_left_context: int = 0, + extra_right_context: int = 0, +) -> List[Tuple[Tensor, Tensor]]: + """Breaks sequence into segments.""" + + sequence = pad_sequence( + sequence=sequence, + time_axis=time_axis, + extra_left_context=extra_left_context, + extra_right_context=extra_right_context, + ) + + lengths = lengths + extra_left_context + extra_right_context + + segments: List[Tuple[Tensor, Tensor]] = [] + + if segment_size is None: + segments.append((sequence, lengths)) + return segments + + offset = 0 + end = sequence.shape[time_axis] + step = segment_size + size = extra_left_context + segment_size + extra_right_context + + while offset + extra_left_context + extra_right_context < end: + clamped_size = min(size, end - offset) + segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size) + indices = torch.arange( + start=offset, + end=(offset + clamped_size), + step=1, + dtype=torch.long, + device=sequence.device, + ) + segment_tensor = torch.index_select(sequence, time_axis, indices) + segments.append((segment_tensor, segment_lengths)) + offset = offset + step + + return segments + + +# ------------------------------------------------------------------------------ +# segments_to_sequence() +# ------------------------------------------------------------------------------ + + +@torch.jit.export +def segments_to_sequence( + segments: List[Tuple[Tensor, Tensor]], time_axis: int +) -> Tuple[Tensor, Tensor]: + """Concatenate segments into a full sequence.""" + if len(segments) == 1: + return segments[0] + + tensors_to_concat: List[Tensor] = [] + lengths_to_stack: List[Tensor] = [] + + for tensor, lengths in segments: + tensors_to_concat.append(tensor) + lengths_to_stack.append(lengths) + + sequence = torch.cat(tensors_to_concat, dim=time_axis) + lengths = torch.stack(lengths_to_stack, dim=0) + lengths = torch.sum(lengths, dim=0) + + return sequence, lengths + + +def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False): + """ + convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor + + Args: + lengths: a (B, )-shaped tensor + batch_first: whether to return a (B, T) tensor + + Return: + max_length: maximum length of B sequences + encoder_padding_mask: a (max_length, B) binary mask, where + [t, b] = False for t < lengths[b] and True otherwise + + TODO: + kernelize this function if benchmarking shows this function is slow + """ + max_lengths = torch.max(lengths).item() + bsz = lengths.size(0) + encoder_padding_mask = torch.arange( + max_lengths + ).to( # a (T, ) tensor with [0, ..., T-1] + lengths.device + ).view( # move to the right device + 1, max_lengths + ).expand( # reshape to (1, T)-shaped tensor + bsz, -1 + ) > lengths.view( # expand to (B, T)-shaped tensor + bsz, 1 + ).expand( + -1, max_lengths + ) + if not batch_first: + return encoder_padding_mask.t(), max_lengths + else: + return encoder_padding_mask, max_lengths + + +# ------------------------------------------------------------------------------ +# attention suppression +# ------------------------------------------------------------------------------ + + +def attention_suppression(attention_weights: Tensor, scale: float): + # B, H, qlen, klen -> B, H, qlen, 1 + attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1) + attention_nozeros = attention_prob.to(torch.bool) + nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True) + + # For very sparse situation, we need get round about 0s + key_sum = torch.sum(attention_prob, dim=-1, keepdim=True) + + # nozeros_sum should > 1 + key_mean = key_sum / (nozeros_sum + 1e-8) + + # std calculation + dis = (attention_prob - key_mean) * (attention_prob - key_mean) + + # if attention_prob[i] < threshold, then dis_masked[i] = 0; for all i + dis_masked = torch.where( + attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size()) + ) + + key_var = torch.sum(dis_masked, dim=-1, keepdim=True) + key_var = key_var / (nozeros_sum - 1.0 + 1e-8) + key_std = torch.sqrt(key_var) + key_thread = key_mean - scale * key_std + + # if attention_prob[i] >= key_thread, then attention_prob[i] + # , otherwise "-inf" + inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach() + inf_tensor[:] = float("-inf") + attention_weights_float = torch.where( + attention_prob < key_thread, + inf_tensor, + attention_weights.float(), + ) + + return attention_weights_float.type_as(attention_weights) + + +def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value): + return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input) diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b4b23464182b233ef178be2a3babfb13872259a --- /dev/null +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -0,0 +1,855 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import checkpoint_utils, utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_text.hub_interface import S2THubInterface +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerModelBase, +) +from fairseq.models.wav2vec import Wav2VecEncoder +from fairseq.modules.layer_norm import LayerNorm + +logger = logging.getLogger(__name__) + + +def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + +class Conv1dAdaptor(nn.Module): + def __init__( + self, + in_dim, + out_dim, + n_layers=3, + kernel_size=3, + stride=2, + layerdrop=0.0, + layernorm=False, + proj=False, + ): + super().__init__() + self.proj, self.proj_ln = None, None + self.post_proj, self.post_proj_ln = None, None + if proj: + self.proj = nn.Sequential( + nn.Linear(in_dim, in_dim * 4), nn.ReLU(), nn.Linear(in_dim * 4, in_dim) + ) + self.proj_ln = LayerNorm(in_dim) + self.post_proj = nn.Sequential( + nn.Linear(out_dim, out_dim * 4), + nn.ReLU(), + nn.Linear(out_dim * 4, out_dim), + ) + self.post_proj_ln = LayerNorm(out_dim) + + self.layers = nn.ModuleList( + nn.Conv1d( + in_dim if i == 0 else out_dim, + out_dim * 2, + kernel_size, + stride=stride, + padding=kernel_size // 2, + ) + for i in range(n_layers) + ) + self.stride = stride + self.layerdrop = layerdrop + self.layernorm = LayerNorm(in_dim) if layernorm else None + + @classmethod + def add_args(cls, parser): + parser.add_argument("--adaptor-n-layers", type=int) + parser.add_argument("--adaptor-kernel-size", type=int) + parser.add_argument("--adaptor-stride", type=int) + parser.add_argument("--adaptor-layerdrop", type=float) + parser.add_argument("--adaptor-layernorm", action="store_true") + parser.add_argument("--adaptor-proj", action="store_true") + + def forward(self, x, padding_mask: Optional[torch.Tensor]): + if self.layernorm is not None: + x = self.layernorm(x) + + if self.proj is not None: + x = x + 0.5 * self.proj(x) + x = self.proj_ln(x) + + if padding_mask is not None: + x = utils.index_put(x, padding_mask.T, 0) + + # T x B x C -> B x C x T + x = x.transpose(0, 1).transpose(1, 2) + out_lens = None + if padding_mask is not None: + out_lens = (~padding_mask).sum(1).float() + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + x = nn.functional.glu(layer(x), dim=1) + if padding_mask is not None: + out_lens = ((out_lens - 1) / self.stride + 1).floor() + # B x C x T -> T x B x C + x = x.transpose(1, 2).transpose(0, 1) + + if self.post_proj is not None: + x = x + 0.5 * self.post_proj(x) + x = self.post_proj_ln(x) + + out_padding_mask = None + if padding_mask is not None: + out_padding_mask = lengths_to_padding_mask(out_lens.long()) + x = utils.index_put(x, out_padding_mask.T, 0) + return x, out_padding_mask + + +def add_wav2vec_asr_args(parser): + parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") + parser.add_argument( + "--no-pretrained-weights", + action="store_true", + help="if true, does not load pretrained weights", + ) + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--final-dropout", + type=float, + metavar="D", + help="dropout after transformer and before final projection", + ) + parser.add_argument( + "--apply-mask", action="store_true", help="apply masking during fine-tuning" + ) + parser.add_argument( + "--dropout", + type=float, + metavar="D", + help="dropout probability inside wav2vec 2.0 model", + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights inside wav2vec 2.0 model", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN inside wav2vec 2.0 model", + ) + parser.add_argument( + "--mask-length", type=int, help="repeat the mask indices multiple times" + ) + parser.add_argument( + "--mask-prob", type=float, help="probability of replacing a token with mask" + ) + parser.add_argument( + "--mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + parser.add_argument( + "--mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + parser.add_argument( + "--no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + parser.add_argument( + "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + ) + parser.add_argument( + "--mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + parser.add_argument( + "--mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + parser.add_argument( + "--no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + parser.add_argument( + "--freeze-finetune-updates", + type=int, + metavar="N", + help="dont finetune wav2vec for this many updates", + ) + parser.add_argument( + "--feature-grad-mult", + type=float, + metavar="D", + help="reset feature grad mult in wav2vec 2.0 to this", + ) + parser.add_argument( + "--layerdrop", + type=float, + metavar="D", + help="probability of dropping a layer in wav2vec 2.0", + ) + parser.add_argument( + "--max-positions", + type=int, + metavar="N", + help="Max input positions to be used in the conformer encoder in wav2vec 2.0", + ) + parser.add_argument("--encoder-proj", action="store_true") + parser.add_argument("--w2v-args", default=None) + parser.add_argument( + "--remove-weight-norm", + action="store_true", + help="if set, then the weight-norm (in one pos_conv layer) is removed from the model", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension to be used when w2v_path is None and no encoder_proj is set", + ) + + +def need_finetuning(ft_params, param_name): + if ft_params == "all": + return True + ft_params_list = ft_params.split(",") + for ft_param in ft_params_list: + if ft_param in param_name: + return True + return False + + +class Wav2VecEncoderWithAdaptor(FairseqEncoder): + def build_adaptor(self, args): + adaptor = None + if args.adaptor_n_layers > 0: + adaptor = Conv1dAdaptor( + args.decoder_embed_dim, + args.decoder_embed_dim, + n_layers=args.adaptor_n_layers, + kernel_size=args.adaptor_kernel_size, + stride=args.adaptor_stride, + layerdrop=args.adaptor_layerdrop, + layernorm=args.adaptor_layernorm, + proj=args.adaptor_proj, + ) + return adaptor + + def __init__(self, args): + super().__init__(None) + self.w2v_encoder = Wav2VecEncoder(args) + self.is_v0_arch = not args.adaptor_proj + self.w2v_proj_ln = None + if not self.is_v0_arch and self.w2v_encoder.proj is not None: + self.w2v_proj_ln = LayerNorm(args.decoder_embed_dim) + self.adaptor = self.build_adaptor(args) + + self.num_updates = 0 + self.freezing_updates = args.w2v_freezing_updates + self.finetuning_params = args.finetune_w2v_params + for k, p in self.w2v_encoder.w2v_model.named_parameters(): + p.requires_grad = need_finetuning(self.finetuning_params, k) + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + add_wav2vec_asr_args(parser) + parser.add_argument( + "--normalize", + action="store_true", + help="if set, normalizes input to have 0 mean and unit variance", + ) + parser.add_argument( + "--finetune-w2v-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", + ) + parser.add_argument("--w2v-freezing-updates", type=int) + parser.add_argument("--load-pretrained-encoder-from", type=str, metavar="STR") + Conv1dAdaptor.add_args(parser) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, src_tokens, src_lengths=None, **kwargs): + if ( + self.freezing_updates is not None + and self.num_updates > self.freezing_updates + ): + for p in self.w2v_encoder.w2v_model.parameters(): + p.requires_grad = True + + padding_mask = lengths_to_padding_mask(src_lengths) + out = self.w2v_encoder.forward(src_tokens, padding_mask, tbc=True) + x, padding_mask = out["encoder_out"], out["padding_mask"] + if self.w2v_proj_ln is not None: + x = self.w2v_proj_ln(x) + + if self.adaptor is not None: + x, padding_mask = self.adaptor(x, padding_mask) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [] + if padding_mask is None + else [padding_mask], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +def add_decoder_args(parser): + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--decoder-dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--decoder-attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--decoder-activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--decoder-embed-dim", type=int, metavar="N", help="decoder embedding dimension" + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--layernorm-embedding", action="store_true", help="add layernorm to embedding" + ) + parser.add_argument( + "--decoder-layerdrop", + type=float, + metavar="D", + help="layerdrop probability for decoder", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="learn positional embedding in decoder", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--finetune-decoder-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", + ) + + +def remove_weight_norm_from_model(model): + from functools import reduce + + layers_with_wn = [] + for param_name, _ in model.named_parameters(): + if param_name.endswith("_g"): + # retrieve the module with this param_name + module_names = param_name.split(".")[ + :-1 + ] # exclude the actual parameter name + wn_module = reduce(getattr, module_names, model) + layers_with_wn.append(wn_module) + for wn_module in layers_with_wn: + torch.nn.utils.remove_weight_norm(wn_module) + logger.warning(f"Weight norm removed from module with {wn_module}\n") + + +@register_model("xm_transformer") +class XMTransformerModel(FairseqEncoderDecoderModel): + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" + model_ids = [ + "xm_transformer_600m-es_en-multi_domain", + "xm_transformer_600m-ru_en-multi_domain", + "xm_transformer_600m-fr_en-multi_domain", + "xm_transformer_600m-en_es-multi_domain", + "xm_transformer_600m-en_ru-multi_domain", + "xm_transformer_600m-en_fr-multi_domain", + "xm_transformer_600m-en_zh-multi_domain", + "xm_transformer_600m-en_ar-multi_domain", + "xm_transformer_600m-en_tr-multi_domain", + "xm_transformer_600m-en_vi-multi_domain", + "xm_transformer-21_en-xls_r_300m", + "xm_transformer-en_15-xls_r_300m", + "xm_transformer-21_en-xls_r_1b", + "xm_transformer-en_15-xls_r_1b", + "xm_transformer-21_en-xls_r_2b", + "xm_transformer-en_15-xls_r_2b", + "xm_transformer-22_16-xls_r_2b", + "xm_transformer_s2ut_800m-es-en-st-asr-bt_h1_2022", + "xm_transformer_s2ut_800m-en-es-st_plus_asr", + "xm_transformer_s2ut_800m-hk-en-h1_2022", + "xm_transformer_s2ut_800m-en-hk-h1_2022", + ] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + config_yaml="config.yaml", + task="speech_to_text", + generation_args=None, + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + config_yaml=config_yaml, + task=task, + generation_args=generation_args, + **kwargs, + ) + return S2THubInterface(x["args"], x["task"], x["models"][0]) + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + Wav2VecEncoderWithAdaptor.add_args(parser) + add_decoder_args(parser) + parser.add_argument("--checkpoint-activations", action="store_true") + parser.add_argument("--offload-activations", action="store_true") + parser.add_argument("--min-params-to-wrap", type=int, metavar="N") + + @classmethod + def maybe_load_pretrained(cls, component, checkpoint: Optional[str] = None): + if checkpoint is None: + return component + + _load = checkpoint_utils.load_pretrained_component_from_model + try: + return _load(component, checkpoint) + except RuntimeError as e: + logger.warning(e) + return _load(component, checkpoint, strict=False) + + @classmethod + def build_encoder(cls, args): + _args = copy.deepcopy(args) + if not args.adaptor_proj and not args.encoder_proj: # V0 arch + if args.w2v_path: + state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path) + if state.get("cfg") is not None: + encoder_embed_dim = state["cfg"]._content["model"][ + "encoder_embed_dim" + ] + elif state.get("args") is not None: + encoder_embed_dim = state["args"].encoder_embed_dim + else: + raise ValueError(f"Invalid config in {args.w2v_path}") + _args.decoder_embed_dim = encoder_embed_dim + del state + else: + _args.decoder_embed_dim = args.encoder_embed_dim + + encoder = Wav2VecEncoderWithAdaptor(_args) + encoder = cls.maybe_load_pretrained( + encoder, getattr(args, "load_pretrained_encoder_from", None) + ) + if args.remove_weight_norm: + # remove the wn for EMA usage + logger.warning("Removing weight norm from wav2vec encoder") + remove_weight_norm_from_model(encoder) + + return encoder + + @classmethod + def get_decoder_args_from_checkpoint(cls, ckpt_args): + assert "model" in ckpt_args, "Model args not found in checkpoint cfg!" + decoder_args = {} + for k, v in ckpt_args["model"].__dict__.items(): + if "decoder" in k: + decoder_args[k] = v + + return decoder_args + + @classmethod + def override_decoder_args(cls, cli_args, decoder_args_dict): + for k, v in decoder_args_dict.items(): + if v != getattr(cli_args, k, None): + logger.warning( + f"Overriding decoder arg {k}: from {getattr(cli_args, k, None)} to {v}" + ) + setattr(cli_args, k, v) + + return cli_args + + @classmethod + def build_decoder(cls, args, task, embed_tokens): + _args = copy.deepcopy(args) + if args.adaptor_proj or args.encoder_proj: # not V0 arch + _args.encoder_embed_dim = _args.decoder_embed_dim + _args.dropout = args.decoder_dropout + _args.attention_dropout = args.decoder_attention_dropout + _args.activation_dropout = args.decoder_activation_dropout + _args.layerdrop = _args.decoder_layerdrop + + decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens) + decoder = cls.maybe_load_pretrained( + decoder, getattr(args, "load_pretrained_decoder_from", None) + ) + + for k, p in decoder.named_parameters(): + p.requires_grad = need_finetuning(args.finetune_decoder_params, k) + return decoder + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + if getattr(args, "load_pretrained_decoder_from", None) is not None: + ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None)) + decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"]) + args = cls.override_decoder_args(args, decoder_args_dict) + + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) + + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args, task, decoder_embed_tokens) + base_model = cls(encoder, decoder) + + # set up multitask decoders + base_model.multitask_decoders = {} + for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): + # dummy auxiliary decoder + if task_obj.args.get_loss_weight(0) == 0: + continue + + task_decoder = cls.build_multitask_decoder( + args, task_obj.args, task_obj.target_dictionary, args.decoder_embed_dim + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + return base_model + + @classmethod + def build_multitask_decoder( + cls, + args, + mtl_args, + tgt_dict, + in_dim, + is_first_pass_decoder=False, + ): + decoder_args = mtl_args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if mtl_args.decoder_type == "transformer": + if is_first_pass_decoder: + task_decoder = cls.build_text_decoder(args, tgt_dict) + else: + from fairseq.models.speech_to_speech import ( + base_multitask_text_transformer_decoder_arch, + ) + + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif mtl_args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens=False, + **kwargs, + ): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs. + """ + encoder_out = self.encoder( + src_tokens=src_tokens, src_lengths=src_lengths, **kwargs + ) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_out"] + # NOTE: from the top layer + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + return decoder_out + + def upgrade_state_dict(self, state_dict): + for k, _ in state_dict.items(): + if "adaptor.layers" in state_dict: + new = k.replace("adaptor.layers", "adaptor_layers") + state_dict[new] = state_dict[k] + del state_dict[k] + + +def set_default_w2v_encoder_args(args): + args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) + args.dropout_input = getattr(args, "dropout_input", 0) + args.final_dropout = getattr(args, "final_dropout", 0) + args.apply_mask = getattr(args, "apply_mask", False) + args.dropout = getattr(args, "dropout", 0) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", 0) + args.encoder_proj = getattr(args, "encoder_proj", False) + args.remove_weight_norm = getattr(args, "remove_weight_norm", False) + + args.mask_length = getattr(args, "mask_length", 10) + args.mask_prob = getattr(args, "mask_prob", 0.5) + args.mask_selection = getattr(args, "mask_selection", "static") + args.mask_other = getattr(args, "mask_other", 0) + args.no_mask_overlap = getattr(args, "no_mask_overlap", False) + args.mask_channel_length = getattr(args, "mask_channel_length", 10) + args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) + args.mask_channel_before = getattr(args, "mask_channel_before", False) + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") + args.mask_channel_other = getattr(args, "mask_channel_other", 0) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) + + args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) + args.feature_grad_mult = 0.1 + args.layerdrop = getattr(args, "layerdrop", 0.0) + + args.normalize = getattr(args, "normalize", False) + args.finetune_w2v_params = getattr(args, "finetune_w2v_params", "all") + args.w2v_freezing_updates = getattr(args, "w2v_freezing_updates", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + + +def set_default_adaptor_args(args): + args.adaptor_n_layers = getattr(args, "adaptor_n_layers", 3) + args.adaptor_kernel_size = getattr(args, "adaptor_kernel_size", 3) + args.adaptor_stride = getattr(args, "adaptor_stride", 2) + args.adaptor_layerdrop = getattr(args, "adaptor_layerdrop", 0.0) + args.adaptor_layernorm = getattr(args, "adaptor_layernorm", False) + args.adaptor_proj = getattr(args, "adaptor_proj", False) + + +def set_default_transformer_decoder_args(args): + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0.0) + args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0.0) + args.decoder_dropout = getattr(args, "decoder_dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + + args.finetune_decoder_params = getattr(args, "finetune_decoder_params", "all") + + +def set_default_general_args(args): + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + args.offload_activations = getattr(args, "offload_activations", False) + args.min_params_to_wrap = getattr(args, "min_params_to_wrap", int(1e8)) + args.max_positions = getattr(args, "max_positions", 3000) + + +@register_model_architecture(model_name="xm_transformer", arch_name="xm_transformer") +def base_architecture(args): + set_default_general_args(args) + set_default_w2v_encoder_args(args) + set_default_adaptor_args(args) + set_default_transformer_decoder_args(args) diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py new file mode 100644 index 0000000000000000000000000000000000000000..f77ef4e5707fdce1bd93b46e9bbaad5890f8e7bf --- /dev/null +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -0,0 +1,315 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging + +from fairseq.models import ( + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) +from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel +from fairseq.models.speech_to_text.xm_transformer import ( + base_architecture as xm_t_base_architecture, +) +from fairseq.models.speech_to_text.xm_transformer import ( + build_embedding, + need_finetuning, + set_default_adaptor_args, + set_default_general_args, + set_default_transformer_decoder_args, + set_default_w2v_encoder_args, +) +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase +from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder + +logger = logging.getLogger(__name__) + + +def unit_transformer_decoder_arch_base( + args, decoder_layers=6, decoder_embed_dim=768, decoder_attention_heads=12 +): + args.encoder_layers = decoder_layers + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_ffn_embed_dim = decoder_embed_dim * 4 + args.decoder_attention_heads = decoder_attention_heads + args.encoder_embed_dim = args.decoder_embed_dim + args.decoder_output_dim = decoder_embed_dim + args.decoder_input_dim = decoder_embed_dim + + +def unit_transformer_decoder_arch_large( + args, decoder_layers=12, decoder_embed_dim=1024, decoder_attention_heads=16 +): + args.encoder_layers = decoder_layers + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_ffn_embed_dim = decoder_embed_dim * 4 + args.decoder_attention_heads = decoder_attention_heads + args.encoder_embed_dim = args.decoder_embed_dim + args.decoder_output_dim = decoder_embed_dim + args.decoder_input_dim = decoder_embed_dim + + +@register_model("unity_xm_transformer") +class XMTransformerModelUnitY(XMTransformerModel): + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" + model_ids = [] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + XMTransformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + parser.add_argument( + "--synthesizer-augmented-cross-attention", + action="store_true", + default=False, + help="augmented cross-attention over speech encoder output", + ) + parser.add_argument( + "--load-pretrained-aux-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + + @classmethod + def build_text_decoder(cls, args, tgt_dict): + _args = copy.deepcopy(args) + + if args.adaptor_proj or args.encoder_proj: # not V0 arch + _args.encoder_embed_dim = _args.decoder_embed_dim + _args.dropout = args.decoder_dropout + _args.attention_dropout = args.decoder_attention_dropout + _args.activation_dropout = args.decoder_activation_dropout + _args.layerdrop = _args.decoder_layerdrop + _args.decoder_layers = _args.translation_decoder_layers + + embed_tokens = build_embedding(tgt_dict, _args.decoder_embed_dim) + decoder = TransformerDecoder(_args, tgt_dict, embed_tokens) + + if getattr(args, "load_pretrained_aux_decoder_from", None) is not None: + decoder = cls.maybe_load_pretrained( + decoder, getattr(args, "load_pretrained_aux_decoder_from", None) + ) + + for k, p in decoder.named_parameters(): + p.requires_grad = need_finetuning(args.finetune_decoder_params, k) + return decoder + + @classmethod + def build_decoder(cls, args, task, aug_attn=False): + _args = copy.deepcopy(args) + _args.layerdrop = 0.0 # turn off layerdrop for shallow layers + + _args.encoder_embed_dim = args.decoder_embed_dim + + proj = None + if args.decoder_embed_dim != _args.decoder_embed_dim: + proj = Linear(args.decoder_embed_dim, _args.decoder_embed_dim) + + embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim) + decoder_cls = AugTransformerDecoder if aug_attn else TransformerDecoder + decoder = decoder_cls(_args, task.target_dictionary, embed_tokens) + + if getattr(args, "load_pretrained_decoder_from", None) is not None: + # load all layers first and then discard the bottom layers + embed_tokens = build_embedding( + task.target_dictionary, _args.decoder_embed_dim + ) + decoder_tmp = decoder_cls(_args, task.target_dictionary, embed_tokens) + decoder_tmp = cls.maybe_load_pretrained( + decoder_tmp, getattr(_args, "load_pretrained_decoder_from", None) + ) + state_dict = decoder_tmp.state_dict() + for k, p in decoder.named_parameters(): + p.data = state_dict[k].data + p.requires_grad = need_finetuning(_args.finetune_decoder_params, k) + decoder.layers = decoder.layers[-_args.decoder_layers :] + + return decoder, proj, _args + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + xm_t_base_architecture(args) + + encoder = cls.build_encoder(args) + decoder, proj, unit_args = cls.build_decoder( + args, + task, + aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False), + ) + base_model = cls(encoder, decoder) + setattr(base_model, "proj", proj) + + base_model.t2u_augmented_cross_attn = getattr( + args, "synthesizer_augmented_cross_attention", False + ) + + # set up multitask decoders + base_model.mt_task_name = None + base_model.multitask_decoders = {} + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True + base_model.mt_task_name = task_name + + task_decoder = cls.build_multitask_decoder( + args, + task_obj.args, + task_obj.target_dictionary, + args.decoder_embed_dim, + task_obj.is_first_pass_decoder, + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_t2u_encoder(unit_args) + else: + base_model.synthesizer_encoder = None + + return base_model + + @classmethod + def build_t2u_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = _args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + return_all_hiddens=False, + tgt_speaker=None, + **kwargs, + ): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs. + """ + encoder_out = self.encoder( + src_tokens=src_tokens, src_lengths=src_lengths, **kwargs + ) + + # 1. MT decoder + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + if self.proj is not None: + x = self.proj(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. T2U encoder + if self.synthesizer_encoder is not None: + t2u_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + t2u_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. T2U decoder + if self.t2u_augmented_cross_attn: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out_aug=t2u_encoder_out, + ) + else: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=t2u_encoder_out, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_out"] + # NOTE: from the top layer + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model_architecture( + model_name="unity_xm_transformer", arch_name="unity_xm_transformer" +) +def base_architecture_unity(args): + set_default_general_args(args) + set_default_w2v_encoder_args(args) + set_default_adaptor_args(args) + set_default_transformer_decoder_args(args) + + args.layernorm_embedding = False + args.decoder_learned_pos = False + + +# for old models +@register_model_architecture( + model_name="unity_xm_transformer", arch_name="xm_transformer_t2" +) +def base_architecture_unity_legacy(args): + base_architecture_unity(args) diff --git a/fairseq/models/text_to_speech/__init__.py b/fairseq/models/text_to_speech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0dcd69b07f5e99c8ec60139471a34b140e46b29 --- /dev/null +++ b/fairseq/models/text_to_speech/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .tacotron2 import * # noqa +from .tts_transformer import * # noqa +from .fastspeech2 import * # noqa +from .vocoder import * # noqa diff --git a/fairseq/models/text_to_speech/__pycache__/__init__.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81a58b91234dfc5da2ee347dbc0b009b3b46f7db Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e1eb0c9086fbc69a8808b75aebf66038a631986 Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d88c7a0e04551b1e6f1e91dbdc962d808f557db6 Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38f0df7caef8b82991b29a0e82f9aab356490027 Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12748e540300c6c3fa995174f7e30760a844537c Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/tacotron2.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/tacotron2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9483eadf839a129ba9c214fa206b08dbea43a7 Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/tacotron2.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9c0fed4ca9e93c48cbc8a8986428bde3133e3f Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc b/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ea5c022c42b11137e2d64d5a98b02963d03a9d Binary files /dev/null and b/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc differ diff --git a/fairseq/models/text_to_speech/codehifigan.py b/fairseq/models/text_to_speech/codehifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..d1574dd63fb6ef1fd10022f908790e632c1b3bd6 --- /dev/null +++ b/fairseq/models/text_to_speech/codehifigan.py @@ -0,0 +1,95 @@ +from argparse import Namespace +import torch +import torch.nn as nn + +from fairseq.models.text_to_speech.fastspeech2 import VariancePredictor +from fairseq.models.text_to_speech.hifigan import Generator + + +class CodeGenerator(Generator): + def __init__(self, cfg): + super().__init__(cfg) + self.dict = nn.Embedding(cfg["num_embeddings"], cfg["embedding_dim"]) + self.multispkr = cfg.get("multispkr", None) + self.embedder = cfg.get("embedder_params", None) + + if self.multispkr and not self.embedder: + self.spkr = nn.Embedding(cfg.get("num_speakers", 200), cfg["embedding_dim"]) + elif self.embedder: + self.spkr = nn.Linear(cfg.get("embedder_dim", 256), cfg["embedding_dim"]) + + self.dur_predictor = None + if cfg.get("dur_predictor_params", None): + self.dur_predictor = VariancePredictor( + Namespace(**cfg["dur_predictor_params"]) + ) + + self.f0 = cfg.get("f0", None) + n_f0_bin = cfg.get("f0_quant_num_bin", 0) + self.f0_quant_embed = ( + None if n_f0_bin <= 0 else nn.Embedding(n_f0_bin, cfg["embedding_dim"]) + ) + + @staticmethod + def _upsample(signal, max_frames): + if signal.dim() == 3: + bsz, channels, cond_length = signal.size() + elif signal.dim() == 2: + signal = signal.unsqueeze(2) + bsz, channels, cond_length = signal.size() + else: + signal = signal.view(-1, 1, 1) + bsz, channels, cond_length = signal.size() + + signal = signal.unsqueeze(3).repeat(1, 1, 1, max_frames // cond_length) + + # pad zeros as needed (if signal's shape does not divide completely with max_frames) + reminder = (max_frames - signal.shape[2] * signal.shape[3]) // signal.shape[3] + if reminder > 0: + raise NotImplementedError( + "Padding condition signal - misalignment between condition features." + ) + + signal = signal.view(bsz, channels, max_frames) + return signal + + def forward(self, **kwargs): + x = self.dict(kwargs["code"]).transpose(1, 2) + + if self.dur_predictor and kwargs.get("dur_prediction", False): + assert x.size(0) == 1, "only support single sample" + log_dur_pred = self.dur_predictor(x.transpose(1, 2)) + dur_out = torch.clamp( + torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1 + ) + # B x C x T + x = torch.repeat_interleave(x, dur_out.view(-1), dim=2) + + if self.f0: + if self.f0_quant_embed: + kwargs["f0"] = self.f0_quant_embed(kwargs["f0"].long()).transpose(1, 2) + else: + kwargs["f0"] = kwargs["f0"].unsqueeze(1) + + if x.shape[-1] < kwargs["f0"].shape[-1]: + x = self._upsample(x, kwargs["f0"].shape[-1]) + elif x.shape[-1] > kwargs["f0"].shape[-1]: + kwargs["f0"] = self._upsample(kwargs["f0"], x.shape[-1]) + x = torch.cat([x, kwargs["f0"]], dim=1) + + if self.multispkr: + assert ( + "spkr" in kwargs + ), 'require "spkr" input for multispeaker CodeHiFiGAN vocoder' + spkr = self.spkr(kwargs["spkr"]).transpose(1, 2) + spkr = self._upsample(spkr, x.shape[-1]) + x = torch.cat([x, spkr], dim=1) + + for k, feat in kwargs.items(): + if k in ["spkr", "code", "f0", "dur_prediction"]: + continue + + feat = self._upsample(feat, x.shape[-1]) + x = torch.cat([x, feat], dim=1) + + return super().forward(x) diff --git a/fairseq/models/text_to_speech/fastspeech2.py b/fairseq/models/text_to_speech/fastspeech2.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2d0df37ddec199a7989cf87ecd2386dc84d74e --- /dev/null +++ b/fairseq/models/text_to_speech/fastspeech2.py @@ -0,0 +1,448 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torch import nn + +from fairseq import utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.text_to_speech.hub_interface import TTSHubInterface +from fairseq.models.text_to_speech.tacotron2 import Postnet +from fairseq.modules import ( + FairseqDropout, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, +) + +logger = logging.getLogger(__name__) + + +def model_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + return m + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, in_dim, hidden_dim, kernel_size, dropout): + super().__init__() + self.ffn = nn.Sequential( + nn.Conv1d( + in_dim, + hidden_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.ReLU(), + nn.Conv1d( + hidden_dim, + in_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + ) + self.layer_norm = LayerNorm(in_dim) + self.dropout = self.dropout_module = FairseqDropout( + p=dropout, module_name=self.__class__.__name__ + ) + + def forward(self, x): + # B x T x C + residual = x + x = self.ffn(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout(x) + return self.layer_norm(x + residual) + + +class FFTLayer(torch.nn.Module): + def __init__( + self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout + ): + super().__init__() + self.self_attn = MultiheadAttention( + embed_dim, n_heads, dropout=attention_dropout, self_attention=True + ) + self.layer_norm = LayerNorm(embed_dim) + self.ffn = PositionwiseFeedForward( + embed_dim, hidden_dim, kernel_size, dropout=dropout + ) + + def forward(self, x, padding_mask=None): + # B x T x C + residual = x + x = x.transpose(0, 1) + x, _ = self.self_attn( + query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False + ) + x = x.transpose(0, 1) + x = self.layer_norm(x + residual) + return self.ffn(x) + + +class LengthRegulator(nn.Module): + def forward(self, x, durations): + # x: B x T x C + out_lens = durations.sum(dim=1) + max_len = out_lens.max() + bsz, seq_len, dim = x.size() + out = x.new_zeros((bsz, max_len, dim)) + + for b in range(bsz): + indices = [] + for t in range(seq_len): + indices.extend([t] * utils.item(durations[b, t])) + indices = torch.tensor(indices, dtype=torch.long).to(x.device) + out_len = utils.item(out_lens[b]) + out[b, :out_len] = x[b].index_select(0, indices) + + return out, out_lens + + +class VariancePredictor(nn.Module): + def __init__(self, args): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv1d( + args.encoder_embed_dim, + args.var_pred_hidden_dim, + kernel_size=args.var_pred_kernel_size, + padding=(args.var_pred_kernel_size - 1) // 2, + ), + nn.ReLU(), + ) + self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim) + self.dropout_module = FairseqDropout( + p=args.var_pred_dropout, module_name=self.__class__.__name__ + ) + self.conv2 = nn.Sequential( + nn.Conv1d( + args.var_pred_hidden_dim, + args.var_pred_hidden_dim, + kernel_size=args.var_pred_kernel_size, + padding=1, + ), + nn.ReLU(), + ) + self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim) + self.proj = nn.Linear(args.var_pred_hidden_dim, 1) + + def forward(self, x): + # Input: B x T x C; Output: B x T + x = self.conv1(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout_module(self.ln1(x)) + x = self.conv2(x.transpose(1, 2)).transpose(1, 2) + x = self.dropout_module(self.ln2(x)) + return self.proj(x).squeeze(dim=2) + + +class VarianceAdaptor(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + self.length_regulator = LengthRegulator() + self.duration_predictor = VariancePredictor(args) + self.pitch_predictor = VariancePredictor(args) + self.energy_predictor = VariancePredictor(args) + + n_bins, steps = self.args.var_pred_n_bins, self.args.var_pred_n_bins - 1 + self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, steps) + self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim) + self.energy_bins = torch.linspace(args.energy_min, args.energy_max, steps) + self.embed_energy = Embedding(n_bins, args.encoder_embed_dim) + + def get_pitch_emb(self, x, tgt=None, factor=1.0): + out = self.pitch_predictor(x) + bins = self.pitch_bins.to(x.device) + if tgt is None: + out = out * factor + emb = self.embed_pitch(torch.bucketize(out, bins)) + else: + emb = self.embed_pitch(torch.bucketize(tgt, bins)) + return out, emb + + def get_energy_emb(self, x, tgt=None, factor=1.0): + out = self.energy_predictor(x) + bins = self.energy_bins.to(x.device) + if tgt is None: + out = out * factor + emb = self.embed_energy(torch.bucketize(out, bins)) + else: + emb = self.embed_energy(torch.bucketize(tgt, bins)) + return out, emb + + def forward( + self, + x, + padding_mask, + durations=None, + pitches=None, + energies=None, + d_factor=1.0, + p_factor=1.0, + e_factor=1.0, + ): + # x: B x T x C + log_dur_out = self.duration_predictor(x) + dur_out = torch.clamp( + torch.round((torch.exp(log_dur_out) - 1) * d_factor).long(), min=0 + ) + dur_out.masked_fill_(padding_mask, 0) + + pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor) + x = x + pitch_emb + energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor) + x = x + energy_emb + + x, out_lens = self.length_regulator( + x, dur_out if durations is None else durations + ) + + return x, out_lens, log_dur_out, pitch_out, energy_out + + +class FastSpeech2Encoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.args = args + self.padding_idx = src_dict.pad() + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * args.n_frames_per_step + + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + ) + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_tokens = Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) + + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1)) + + self.encoder_fft_layers = nn.ModuleList( + FFTLayer( + args.encoder_embed_dim, + args.encoder_attention_heads, + args.fft_hidden_dim, + args.fft_kernel_size, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + ) + for _ in range(args.encoder_layers) + ) + + self.var_adaptor = VarianceAdaptor(args) + + self.decoder_fft_layers = nn.ModuleList( + FFTLayer( + args.decoder_embed_dim, + args.decoder_attention_heads, + args.fft_hidden_dim, + args.fft_kernel_size, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + ) + for _ in range(args.decoder_layers) + ) + + self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) + + self.postnet = None + if args.add_postnet: + self.postnet = Postnet( + self.out_dim, + args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, + args.postnet_dropout, + ) + + self.apply(model_init) + + def forward( + self, + src_tokens, + src_lengths=None, + speaker=None, + durations=None, + pitches=None, + energies=None, + **kwargs, + ): + x = self.embed_tokens(src_tokens) + + enc_padding_mask = src_tokens.eq(self.padding_idx) + x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask) + x = self.dropout_module(x) + + for layer in self.encoder_fft_layers: + x = layer(x, enc_padding_mask) + + if self.embed_speaker is not None: + bsz, seq_len, _ = x.size() + emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor( + x, enc_padding_mask, durations, pitches, energies + ) + + dec_padding_mask = lengths_to_padding_mask(out_lens) + x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask) + for layer in self.decoder_fft_layers: + x = layer(x, dec_padding_mask) + + x = self.out_proj(x) + x_post = None + if self.postnet is not None: + x_post = x + self.postnet(x) + return x, x_post, out_lens, log_dur_out, pitch_out, energy_out + + +@register_model("fastspeech2") +class FastSpeech2Model(FairseqEncoderModel): + """ + Implementation for https://arxiv.org/abs/2006.04558 + """ + + NON_AUTOREGRESSIVE = True + + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2" + model_ids = [ + "fastspeech2-en-ljspeech", + "fastspeech2-en-200_speaker-cv4", + ] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + config_yaml="config.yaml", + vocoder: str = "griffin_lim", + fp16: bool = False, + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + config_yaml=config_yaml, + vocoder=vocoder, + fp16=fp16, + **kwargs, + ) + return TTSHubInterface(x["args"], x["task"], x["models"][0]) + + @staticmethod + def add_args(parser): + parser.add_argument("--dropout", type=float) + parser.add_argument("--output-frame-dim", type=int) + parser.add_argument("--speaker-embed-dim", type=int) + # FFT blocks + parser.add_argument("--fft-hidden-dim", type=int) + parser.add_argument("--fft-kernel-size", type=int) + parser.add_argument("--attention-dropout", type=float) + parser.add_argument("--encoder-layers", type=int) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-attention-heads", type=int) + parser.add_argument("--decoder-layers", type=int) + parser.add_argument("--decoder-embed-dim", type=int) + parser.add_argument("--decoder-attention-heads", type=int) + # variance predictor + parser.add_argument("--var-pred-n-bins", type=int) + parser.add_argument("--var-pred-hidden-dim", type=int) + parser.add_argument("--var-pred-kernel-size", type=int) + parser.add_argument("--var-pred-dropout", type=float) + # postnet + parser.add_argument("--add-postnet", action="store_true") + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + + def __init__(self, encoder, args, src_dict): + super().__init__(encoder) + self._num_updates = 0 + + out_dim = args.output_frame_dim * args.n_frames_per_step + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.0) > 0.0: + self.ctc_proj = nn.Linear(out_dim, len(src_dict)) + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker) + return cls(encoder, args, task.src_dict) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + def get_normalized_probs(self, net_output, log_probs, sample=None): + logits = self.ctc_proj(net_output[0]) + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + +@register_model_architecture("fastspeech2", "fastspeech2") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.2) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) + args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64) + # FFT blocks + args.fft_hidden_dim = getattr(args, "fft_hidden_dim", 1024) + args.fft_kernel_size = getattr(args, "fft_kernel_size", 9) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.encoder_layers = getattr(args, "encoder_layers", 4) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) + args.decoder_layers = getattr(args, "decoder_layers", 4) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) + # variance predictor + args.var_pred_n_bins = getattr(args, "var_pred_n_bins", 256) + args.var_pred_hidden_dim = getattr(args, "var_pred_hidden_dim", 256) + args.var_pred_kernel_size = getattr(args, "var_pred_kernel_size", 3) + args.var_pred_dropout = getattr(args, "var_pred_dropout", 0.5) + # postnet + args.add_postnet = getattr(args, "add_postnet", False) + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) diff --git a/fairseq/models/text_to_speech/hifigan.py b/fairseq/models/text_to_speech/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..a852beefecd57c4f62f75be0341943678beb17c2 --- /dev/null +++ b/fairseq/models/text_to_speech/hifigan.py @@ -0,0 +1,179 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + +class ResBlock(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + remove_weight_norm(layer) + for layer in self.convs2: + remove_weight_norm(layer) + + +class Generator(torch.nn.Module): + def __init__(self, cfg): + super(Generator, self).__init__() + self.num_kernels = len(cfg["resblock_kernel_sizes"]) + self.num_upsamples = len(cfg["upsample_rates"]) + self.conv_pre = weight_norm( + Conv1d( + cfg.get("model_in_dim", 80), + cfg["upsample_initial_channel"], + 7, + 1, + padding=3, + ) + ) + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) + ): + self.ups.append( + weight_norm( + ConvTranspose1d( + cfg["upsample_initial_channel"] // (2**i), + cfg["upsample_initial_channel"] // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) + for k, d in zip( + cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] + ): + self.resblocks.append(ResBlock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for layer in self.ups: + remove_weight_norm(layer) + for layer in self.resblocks: + layer.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/fairseq/models/text_to_speech/hub_interface.py b/fairseq/models/text_to_speech/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..e251c65c1d6cc6e60168a2da170618907be8bf60 --- /dev/null +++ b/fairseq/models/text_to_speech/hub_interface.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import random +from pathlib import Path +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +class TTSHubInterface(nn.Module): + def __init__(self, cfg, task, model): + super().__init__() + self.cfg = cfg + self.task = task + self.model = model + self.model.eval() + + self.update_cfg_with_data_cfg(self.cfg, self.task.data_cfg) + self.generator = self.task.build_generator([self.model], self.cfg) + + @classmethod + def phonemize( + cls, + text: str, + lang: Optional[str], + phonemizer: Optional[str] = None, + preserve_punct: bool = False, + to_simplified_zh: bool = False, + ): + if to_simplified_zh: + import hanziconv + + text = hanziconv.HanziConv.toSimplified(text) + + if phonemizer == "g2p": + import g2p_en + + g2p = g2p_en.G2p() + if preserve_punct: + return " ".join("|" if p == " " else p for p in g2p(text)) + else: + res = [{",": "sp", ";": "sp"}.get(p, p) for p in g2p(text)] + return " ".join(p for p in res if p.isalnum()) + if phonemizer == "g2pc": + import g2pc + + g2p = g2pc.G2pC() + return " ".join([w[3] for w in g2p(text)]) + elif phonemizer == "ipa": + assert lang is not None + import phonemizer + from phonemizer.separator import Separator + + lang_map = {"en": "en-us", "fr": "fr-fr"} + return phonemizer.phonemize( + text, + backend="espeak", + language=lang_map.get(lang, lang), + separator=Separator(word="| ", phone=" "), + ) + else: + return text + + @classmethod + def tokenize(cls, text: str, tkn_cfg: Dict[str, str]): + sentencepiece_model = tkn_cfg.get("sentencepiece_model", None) + if sentencepiece_model is not None: + assert Path(sentencepiece_model).exists() + import sentencepiece as sp + + spm = sp.SentencePieceProcessor() + spm.Load(sentencepiece_model) + return " ".join(spm.Encode(text, out_type=str)) + else: + return text + + @classmethod + def update_cfg_with_data_cfg(cls, cfg, data_cfg): + cfg["task"].vocoder = data_cfg.vocoder.get("type", "griffin_lim") + + @classmethod + def get_model_input( + cls, task, text: str, speaker: Optional[int] = None, verbose: bool = False + ): + phonemized = cls.phonemize( + text, + task.data_cfg.hub.get("lang", None), + task.data_cfg.hub.get("phonemizer", None), + task.data_cfg.hub.get("preserve_punct", False), + task.data_cfg.hub.get("to_simplified_zh", False), + ) + tkn_cfg = task.data_cfg.bpe_tokenizer + tokenized = cls.tokenize(phonemized, tkn_cfg) + if verbose: + logger.info(f"text: {text}") + logger.info(f"phonemized: {phonemized}") + logger.info(f"tokenized: {tokenized}") + + spk = task.data_cfg.hub.get("speaker", speaker) + n_speakers = len(task.speaker_to_id or {}) + if spk is None and n_speakers > 0: + spk = random.randint(0, n_speakers - 1) + if spk is not None: + spk = max(0, min(spk, n_speakers - 1)) + if verbose: + logger.info(f"speaker: {spk}") + spk = None if spk is None else torch.Tensor([[spk]]).long() + + src_tokens = task.src_dict.encode_line(tokenized, add_if_not_exist=False).view( + 1, -1 + ) + src_lengths = torch.Tensor([len(tokenized.split())]).long() + return { + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + "prev_output_tokens": None, + }, + "target_lengths": None, + "speaker": spk, + } + + @classmethod + def get_prediction(cls, task, model, generator, sample) -> Tuple[torch.Tensor, int]: + prediction = generator.generate(model, sample) + return prediction[0]["waveform"], task.sr + + def predict( + self, text: str, speaker: Optional[int] = None, verbose: bool = False + ) -> Tuple[torch.Tensor, int]: + sample = self.get_model_input(self.task, text, speaker, verbose=verbose) + return self.get_prediction(self.task, self.model, self.generator, sample) + + +class VocoderHubInterface(nn.Module): + """Vocoder interface to run vocoder models through hub. Currently we only support unit vocoder""" + + def __init__(self, cfg, model): + super().__init__() + self.vocoder = model + self.vocoder.eval() + self.sr = 16000 + self.multispkr = self.vocoder.model.multispkr + if self.multispkr: + logger.info("multi-speaker vocoder") + self.num_speakers = cfg.get( + "num_speakers", + 200, + ) # following the default in codehifigan to set to 200 + + def get_model_input( + self, + text: str, + speaker: Optional[int] = -1, + ): + units = list(map(int, text.strip().split())) + x = { + "code": torch.LongTensor(units).view(1, -1), + } + if not speaker: + speaker = -1 + if self.multispkr: + assert ( + speaker < self.num_speakers + ), f"invalid --speaker-id ({speaker}) with total #speakers = {self.num_speakers}" + spk = random.randint(0, self.num_speakers - 1) if speaker == -1 else speaker + x["spkr"] = torch.LongTensor([spk]).view(1, 1) + return x + + def get_prediction(self, sample, dur_prediction: Optional[bool] = True): + wav = self.vocoder(sample, dur_prediction) + return wav, self.sr + + def predict( + self, + text: str, + speaker: Optional[int] = None, + dur_prediction: Optional[bool] = True, + ): + sample = self.get_model_input(text, speaker) + return self.get_prediction(sample, dur_prediction) diff --git a/fairseq/models/text_to_speech/tacotron2.py b/fairseq/models/text_to_speech/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..4df40756178ca84d06f812efb9f22319f49cf71b --- /dev/null +++ b/fairseq/models/text_to_speech/tacotron2.py @@ -0,0 +1,380 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from torch import nn +from torch.nn import functional as F + +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.modules import LSTMCellWithZoneOut, LocationAttention + + +logger = logging.getLogger(__name__) + + +def encoder_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +class Tacotron2Encoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.padding_idx = src_dict.pad() + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + ) + + self.embed_tokens = nn.Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) + + assert args.encoder_conv_kernel_size % 2 == 1 + self.convolutions = nn.ModuleList( + nn.Sequential( + nn.Conv1d( + args.encoder_embed_dim, + args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2), + ), + nn.BatchNorm1d(args.encoder_embed_dim), + nn.ReLU(), + nn.Dropout(args.encoder_dropout), + ) + for _ in range(args.encoder_conv_layers) + ) + + self.lstm = nn.LSTM( + args.encoder_embed_dim, + args.encoder_embed_dim // 2, + num_layers=args.encoder_lstm_layers, + batch_first=True, + bidirectional=True, + ) + + self.apply(encoder_init) + + def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs): + x = self.embed_tokens(src_tokens) + x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T + for conv in self.convolutions: + x = conv(x) + x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C + + src_lengths = src_lengths.cpu().long() + x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True) + x = self.lstm(x)[0] + x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)[0] + + encoder_padding_mask = src_tokens.eq(self.padding_idx) + + if self.embed_speaker is not None: + seq_len, bsz, _ = x.size() + emb = self.embed_speaker(speaker).expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + return { + "encoder_out": [x], # B x T x C + "encoder_padding_mask": encoder_padding_mask, # B x T + } + + +class Prenet(nn.Module): + def __init__(self, in_dim, n_layers, n_units, dropout): + super().__init__() + self.layers = nn.ModuleList( + nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.ReLU()) + for i in range(n_layers) + ) + self.dropout = dropout + + def forward(self, x): + for layer in self.layers: + x = F.dropout(layer(x), p=self.dropout) # always applies dropout + return x + + +class Postnet(nn.Module): + def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + assert kernel_size % 2 == 1 + for i in range(n_layers): + cur_layers = ( + [ + nn.Conv1d( + in_dim if i == 0 else n_channels, + n_channels if i < n_layers - 1 else in_dim, + kernel_size=kernel_size, + padding=((kernel_size - 1) // 2), + ), + nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim), + ] + + ([nn.Tanh()] if i < n_layers - 1 else []) + + [nn.Dropout(dropout)] + ) + nn.init.xavier_uniform_( + cur_layers[0].weight, + torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"), + ) + self.convolutions.append(nn.Sequential(*cur_layers)) + + def forward(self, x): + x = x.transpose(1, 2) # B x T x C -> B x C x T + for conv in self.convolutions: + x = conv(x) + return x.transpose(1, 2) + + +def decoder_init(m): + if isinstance(m, torch.nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh")) + + +class Tacotron2Decoder(FairseqIncrementalDecoder): + def __init__(self, args, src_dict): + super().__init__(None) + self.args = args + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * args.n_frames_per_step + + self.prenet = Prenet( + self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout + ) + + # take prev_context, prev_frame, (speaker embedding) as input + self.attention_lstm = LSTMCellWithZoneOut( + args.zoneout, + args.prenet_dim + args.encoder_embed_dim, + args.decoder_lstm_dim, + ) + + # take attention_lstm output, attention_state, encoder_out as input + self.attention = LocationAttention( + args.attention_dim, + args.encoder_embed_dim, + args.decoder_lstm_dim, + (1 + int(args.attention_use_cumprob)), + args.attention_conv_dim, + args.attention_conv_kernel_size, + ) + + # take attention_lstm output, context, (gated_latent) as input + self.lstm = nn.ModuleList( + LSTMCellWithZoneOut( + args.zoneout, + args.encoder_embed_dim + args.decoder_lstm_dim, + args.decoder_lstm_dim, + ) + for i in range(args.decoder_lstm_layers) + ) + + proj_in_dim = args.encoder_embed_dim + args.decoder_lstm_dim + self.feat_proj = nn.Linear(proj_in_dim, self.out_dim) + self.eos_proj = nn.Linear(proj_in_dim, 1) + + self.postnet = Postnet( + self.out_dim, + args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, + args.postnet_dropout, + ) + + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.0) > 0.0: + self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) + + self.apply(decoder_init) + + def _get_states(self, incremental_state, enc_out): + bsz, in_len, _ = enc_out.size() + alstm_h = self.get_incremental_state(incremental_state, "alstm_h") + if alstm_h is None: + alstm_h = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + alstm_c = self.get_incremental_state(incremental_state, "alstm_c") + if alstm_c is None: + alstm_c = enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + + lstm_h = self.get_incremental_state(incremental_state, "lstm_h") + if lstm_h is None: + lstm_h = [ + enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers) + ] + lstm_c = self.get_incremental_state(incremental_state, "lstm_c") + if lstm_c is None: + lstm_c = [ + enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers) + ] + + attn_w = self.get_incremental_state(incremental_state, "attn_w") + if attn_w is None: + attn_w = enc_out.new_zeros(bsz, in_len) + attn_w_cum = self.get_incremental_state(incremental_state, "attn_w_cum") + if attn_w_cum is None: + attn_w_cum = enc_out.new_zeros(bsz, in_len) + return alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum + + def _get_init_attn_c(self, enc_out, enc_mask): + bsz = enc_out.size(0) + if self.args.init_attn_c == "zero": + return enc_out.new_zeros(bsz, self.args.encoder_embed_dim) + elif self.args.init_attn_c == "avg": + enc_w = (~enc_mask).type(enc_out.type()) + enc_w = enc_w / enc_w.sum(dim=1, keepdim=True) + return torch.sum(enc_out * enc_w.unsqueeze(2), dim=1) + else: + raise ValueError(f"{self.args.init_attn_c} not supported") + + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + target_lengths=None, + **kwargs, + ): + enc_mask = encoder_out["encoder_padding_mask"] + enc_out = encoder_out["encoder_out"][0] + in_len = enc_out.size(1) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:, :] + bsz, out_len, _ = prev_output_tokens.size() + + prenet_out = self.prenet(prev_output_tokens) + (alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum) = self._get_states( + incremental_state, enc_out + ) + attn_ctx = self._get_init_attn_c(enc_out, enc_mask) + + attn_out = enc_out.new_zeros(bsz, in_len, out_len) + feat_out = enc_out.new_zeros(bsz, out_len, self.out_dim) + eos_out = enc_out.new_zeros(bsz, out_len) + for t in range(out_len): + alstm_in = torch.cat((attn_ctx, prenet_out[:, t, :]), dim=1) + alstm_h, alstm_c = self.attention_lstm(alstm_in, (alstm_h, alstm_c)) + + attn_state = attn_w.unsqueeze(1) + if self.args.attention_use_cumprob: + attn_state = torch.stack((attn_w, attn_w_cum), dim=1) + attn_ctx, attn_w = self.attention(enc_out, enc_mask, alstm_h, attn_state) + attn_w_cum = attn_w_cum + attn_w + attn_out[:, :, t] = attn_w + + for i, cur_lstm in enumerate(self.lstm): + if i == 0: + lstm_in = torch.cat((attn_ctx, alstm_h), dim=1) + else: + lstm_in = torch.cat((attn_ctx, lstm_h[i - 1]), dim=1) + lstm_h[i], lstm_c[i] = cur_lstm(lstm_in, (lstm_h[i], lstm_c[i])) + + proj_in = torch.cat((attn_ctx, lstm_h[-1]), dim=1) + feat_out[:, t, :] = self.feat_proj(proj_in) + eos_out[:, t] = self.eos_proj(proj_in).squeeze(1) + self.attention.clear_cache() + + self.set_incremental_state(incremental_state, "alstm_h", alstm_h) + self.set_incremental_state(incremental_state, "alstm_c", alstm_c) + self.set_incremental_state(incremental_state, "lstm_h", lstm_h) + self.set_incremental_state(incremental_state, "lstm_c", lstm_c) + self.set_incremental_state(incremental_state, "attn_w", attn_w) + self.set_incremental_state(incremental_state, "attn_w_cum", attn_w_cum) + + post_feat_out = feat_out + self.postnet(feat_out) + eos_out = eos_out.view(bsz, out_len, 1) + return post_feat_out, eos_out, {"attn": attn_out, "feature_out": feat_out} + + +@register_model("tacotron_2") +class Tacotron2Model(FairseqEncoderDecoderModel): + """ + Implementation for https://arxiv.org/pdf/1712.05884.pdf + """ + + @staticmethod + def add_args(parser): + # encoder + parser.add_argument("--encoder-dropout", type=float) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-conv-layers", type=int) + parser.add_argument("--encoder-conv-kernel-size", type=int) + parser.add_argument("--encoder-lstm-layers", type=int) + # decoder + parser.add_argument("--attention-dim", type=int) + parser.add_argument("--attention-conv-dim", type=int) + parser.add_argument("--attention-conv-kernel-size", type=int) + parser.add_argument("--prenet-dropout", type=float) + parser.add_argument("--prenet-layers", type=int) + parser.add_argument("--prenet-dim", type=int) + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + parser.add_argument("--init-attn-c", type=str) + parser.add_argument("--attention-use-cumprob", action="store_true") + parser.add_argument("--zoneout", type=float) + parser.add_argument("--decoder-lstm-layers", type=int) + parser.add_argument("--decoder-lstm-dim", type=int) + parser.add_argument("--output-frame-dim", type=int) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_updates = 0 + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = Tacotron2Encoder(args, task.src_dict, embed_speaker) + decoder = Tacotron2Decoder(args, task.src_dict) + return cls(encoder, decoder) + + def forward_encoder(self, src_tokens, src_lengths, **kwargs): + return self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + +@register_model_architecture("tacotron_2", "tacotron_2") +def base_architecture(args): + # encoder + args.encoder_dropout = getattr(args, "encoder_dropout", 0.5) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3) + args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5) + args.encoder_lstm_layers = getattr(args, "encoder_lstm_layers", 1) + # decoder + args.attention_dim = getattr(args, "attention_dim", 128) + args.attention_conv_dim = getattr(args, "attention_conv_dim", 32) + args.attention_conv_kernel_size = getattr(args, "attention_conv_kernel_size", 15) + args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) + args.prenet_layers = getattr(args, "prenet_layers", 2) + args.prenet_dim = getattr(args, "prenet_dim", 256) + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) + args.init_attn_c = getattr(args, "init_attn_c", "zero") + args.attention_use_cumprob = getattr(args, "attention_use_cumprob", True) + args.zoneout = getattr(args, "zoneout", 0.1) + args.decoder_lstm_layers = getattr(args, "decoder_lstm_layers", 2) + args.decoder_lstm_dim = getattr(args, "decoder_lstm_dim", 1024) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) diff --git a/fairseq/models/text_to_speech/tts_transformer.py b/fairseq/models/text_to_speech/tts_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..19afc2b717b3036805e9d9bcfdb9a366b9cc19a5 --- /dev/null +++ b/fairseq/models/text_to_speech/tts_transformer.py @@ -0,0 +1,454 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Optional + +import torch +from torch import nn + +from fairseq import utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) +from fairseq.models.text_to_speech.hub_interface import TTSHubInterface +from fairseq.models.text_to_speech.tacotron2 import Postnet, Prenet +from fairseq.modules import ( + FairseqDropout, + LayerNorm, + PositionalEmbedding, + TransformerDecoderLayer, + TransformerEncoderLayer, +) + +logger = logging.getLogger(__name__) + + +def encoder_init(m): + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +def Embedding(num_embeddings, embedding_dim): + m = nn.Embedding(num_embeddings, embedding_dim) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + return m + + +class TTSTransformerEncoder(FairseqEncoder): + def __init__(self, args, src_dict, embed_speaker): + super().__init__(src_dict) + self.padding_idx = src_dict.pad() + self.embed_speaker = embed_speaker + self.spk_emb_proj = None + if embed_speaker is not None: + self.spk_emb_proj = nn.Linear( + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + ) + + self.dropout_module = FairseqDropout( + p=args.dropout, module_name=self.__class__.__name__ + ) + self.embed_tokens = nn.Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) + assert args.encoder_conv_kernel_size % 2 == 1 + self.prenet = nn.ModuleList( + nn.Sequential( + nn.Conv1d( + args.encoder_embed_dim, + args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2), + ), + nn.BatchNorm1d(args.encoder_embed_dim), + nn.ReLU(), + nn.Dropout(args.encoder_dropout), + ) + for _ in range(args.encoder_conv_layers) + ) + self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim) + self.embed_positions = PositionalEmbedding( + args.max_source_positions, args.encoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + + self.transformer_layers = nn.ModuleList( + TransformerEncoderLayer(args) + for _ in range(args.encoder_transformer_layers) + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + self.apply(encoder_init) + + def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs): + x = self.embed_tokens(src_tokens) + x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T + for conv in self.prenet: + x = conv(x) + x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C + x = self.prenet_proj(x) + + padding_mask = src_tokens.eq(self.padding_idx) + positions = self.embed_positions(padding_mask) + x += self.pos_emb_alpha * positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + for layer in self.transformer_layers: + x = layer(x, padding_mask) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + if self.embed_speaker is not None: + seq_len, bsz, _ = x.size() + emb = self.embed_speaker(speaker).transpose(0, 1) + emb = emb.expand(seq_len, bsz, -1) + x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [padding_mask] + if padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": [], # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + +def decoder_init(m): + if isinstance(m, torch.nn.Conv1d): + nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh")) + + +class TTSTransformerDecoder(FairseqIncrementalDecoder): + def __init__(self, args, src_dict, padding_idx=1): + super().__init__(None) + self._future_mask = torch.empty(0) + + self.args = args + self.padding_idx = src_dict.pad() if src_dict else padding_idx + self.n_frames_per_step = args.n_frames_per_step + self.out_dim = args.output_frame_dim * args.n_frames_per_step + + self.dropout_module = FairseqDropout( + args.dropout, module_name=self.__class__.__name__ + ) + self.embed_positions = PositionalEmbedding( + args.max_target_positions, args.decoder_embed_dim, self.padding_idx + ) + self.pos_emb_alpha = nn.Parameter(torch.ones(1)) + self.prenet = nn.Sequential( + Prenet( + self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout + ), + nn.Linear(args.prenet_dim, args.decoder_embed_dim), + ) + + self.n_transformer_layers = args.decoder_transformer_layers + self.transformer_layers = nn.ModuleList( + TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers) + ) + if args.decoder_normalize_before: + self.layer_norm = LayerNorm(args.decoder_embed_dim) + else: + self.layer_norm = None + + self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) + self.eos_proj = nn.Linear(args.decoder_embed_dim, 1) + + self.postnet = Postnet( + self.out_dim, + args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, + args.postnet_dropout, + ) + + self.ctc_proj = None + if getattr(args, "ctc_weight", 0.0) > 0.0: + self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) + + self.apply(decoder_init) + + def extract_features( + self, + prev_outputs, + encoder_out=None, + incremental_state=None, + target_lengths=None, + speaker=None, + **kwargs, + ): + alignment_layer = self.n_transformer_layers - 1 + self_attn_padding_mask = lengths_to_padding_mask(target_lengths) + positions = self.embed_positions( + self_attn_padding_mask, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_outputs = prev_outputs[:, -1:, :] + self_attn_padding_mask = self_attn_padding_mask[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + x = self.prenet(prev_outputs) + x += self.pos_emb_alpha * positions + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + if not self_attn_padding_mask.any(): + self_attn_padding_mask = None + + attn: Optional[torch.Tensor] = None + inner_states: List[Optional[torch.Tensor]] = [x] + for idx, transformer_layer in enumerate(self.transformer_layers): + if incremental_state is None: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = transformer_layer( + x, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + # average probabilities over heads, transpose to + # (B, src_len, tgt_len) + attn = attn.mean(dim=0).transpose(2, 1) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + target_lengths=None, + speaker=None, + **kwargs, + ): + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + **kwargs, + ) + attn = extra["attn"] + feat_out = self.feat_proj(x) + bsz, seq_len, _ = x.size() + eos_out = self.eos_proj(x) + post_feat_out = feat_out + self.postnet(feat_out) + return ( + post_feat_out, + eos_out, + { + "attn": attn, + "feature_out": feat_out, + "inner_states": extra["inner_states"], + }, + ) + + def get_normalized_probs(self, net_output, log_probs, sample): + logits = self.ctc_proj(net_output[2]["feature_out"]) + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + +@register_model("tts_transformer") +class TTSTransformerModel(FairseqEncoderDecoderModel): + """ + Implementation for https://arxiv.org/pdf/1809.08895.pdf + """ + + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2" + model_ids = [ + "tts_transformer-en-ljspeech", + "tts_transformer-en-200_speaker-cv4", + "tts_transformer-es-css10", + "tts_transformer-fr-cv7_css10", + "tts_transformer-ru-cv7_css10", + "tts_transformer-zh-cv7_css10", + "tts_transformer-ar-cv7_css10", + "tts_transformer-tr-cv7_css10", + "tts_transformer-vi-cv7", + ] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + config_yaml="config.yaml", + vocoder: str = "griffin_lim", + fp16: bool = False, + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + config_yaml=config_yaml, + vocoder=vocoder, + fp16=fp16, + **kwargs, + ) + return TTSHubInterface(x["args"], x["task"], x["models"][0]) + + @staticmethod + def add_args(parser): + parser.add_argument("--dropout", type=float) + parser.add_argument("--output-frame-dim", type=int) + parser.add_argument("--speaker-embed-dim", type=int) + # encoder prenet + parser.add_argument("--encoder-dropout", type=float) + parser.add_argument("--encoder-conv-layers", type=int) + parser.add_argument("--encoder-conv-kernel-size", type=int) + # encoder transformer layers + parser.add_argument("--encoder-transformer-layers", type=int) + parser.add_argument("--encoder-embed-dim", type=int) + parser.add_argument("--encoder-ffn-embed-dim", type=int) + parser.add_argument("--encoder-normalize-before", action="store_true") + parser.add_argument("--encoder-attention-heads", type=int) + parser.add_argument("--attention-dropout", type=float) + parser.add_argument("--activation-dropout", "--relu-dropout", type=float) + parser.add_argument("--activation-fn", type=str, default="relu") + # decoder prenet + parser.add_argument("--prenet-dropout", type=float) + parser.add_argument("--prenet-layers", type=int) + parser.add_argument("--prenet-dim", type=int) + # decoder postnet + parser.add_argument("--postnet-dropout", type=float) + parser.add_argument("--postnet-layers", type=int) + parser.add_argument("--postnet-conv-dim", type=int) + parser.add_argument("--postnet-conv-kernel-size", type=int) + # decoder transformer layers + parser.add_argument("--decoder-transformer-layers", type=int) + parser.add_argument("--decoder-embed-dim", type=int) + parser.add_argument("--decoder-ffn-embed-dim", type=int) + parser.add_argument("--decoder-normalize-before", action="store_true") + parser.add_argument("--decoder-attention-heads", type=int) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._num_updates = 0 + + @classmethod + def build_model(cls, args, task): + embed_speaker = task.get_speaker_embeddings(args) + encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker) + decoder = TTSTransformerDecoder(args, task.src_dict) + return cls(encoder, decoder) + + def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): + return self.encoder( + src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs + ) + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self._num_updates = num_updates + + +@register_model_architecture("tts_transformer", "tts_transformer") +def base_architecture(args): + args.dropout = getattr(args, "dropout", 0.1) + args.output_frame_dim = getattr(args, "output_frame_dim", 80) + args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64) + # encoder prenet + args.encoder_dropout = getattr(args, "encoder_dropout", 0.5) + args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3) + args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5) + # encoder transformer layers + args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr( + args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim + ) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + # decoder prenet + args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) + args.prenet_layers = getattr(args, "prenet_layers", 2) + args.prenet_dim = getattr(args, "prenet_dim", 256) + # decoder postnet + args.postnet_dropout = getattr(args, "postnet_dropout", 0.5) + args.postnet_layers = getattr(args, "postnet_layers", 5) + args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512) + args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5) + # decoder transformer layers + args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim + ) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/text_to_speech/vocoder.py b/fairseq/models/text_to_speech/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc02da368cd337fba7727f75e6092110e6e32c6 --- /dev/null +++ b/fairseq/models/text_to_speech/vocoder.py @@ -0,0 +1,305 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +from typing import Dict + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from fairseq.data.audio.audio_utils import ( + TTSSpectrogram, + get_fourier_basis, + get_mel_filters, + get_window, +) +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.text_to_speech.codehifigan import CodeGenerator as CodeHiFiGANModel +from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel +from fairseq.models.text_to_speech.hub_interface import VocoderHubInterface + +logger = logging.getLogger(__name__) + + +class PseudoInverseMelScale(torch.nn.Module): + def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None: + super(PseudoInverseMelScale, self).__init__() + self.n_mels = n_mels + basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max) + basis = torch.pinverse(basis) # F x F_mel + self.register_buffer("basis", basis) + + def forward(self, melspec: torch.Tensor) -> torch.Tensor: + # pack batch + shape = melspec.shape # B_1 x ... x B_K x F_mel x T + n_mels, time = shape[-2], shape[-1] + melspec = melspec.view(-1, n_mels, time) + + freq, _ = self.basis.size() # F x F_mel + assert self.n_mels == n_mels, (self.n_mels, n_mels) + specgram = self.basis.matmul(melspec).clamp(min=0) + + # unpack batch + specgram = specgram.view(shape[:-2] + (freq, time)) + return specgram + + +class GriffinLim(torch.nn.Module): + def __init__( + self, + n_fft: int, + win_length: int, + hop_length: int, + n_iter: int, + window_fn=torch.hann_window, + ): + super(GriffinLim, self).__init__() + self.transform = TTSSpectrogram( + n_fft, win_length, hop_length, return_phase=True + ) + + basis = get_fourier_basis(n_fft) + basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :] + basis *= get_window(window_fn, n_fft, win_length) + self.register_buffer("basis", basis) + + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_iter = n_iter + + self.tiny = 1.1754944e-38 + + @classmethod + def get_window_sum_square( + cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window + ) -> torch.Tensor: + w_sq = get_window(window_fn, n_fft, win_length) ** 2 + n = n_fft + hop_length * (n_frames - 1) + x = torch.zeros(n, dtype=torch.float32) + for i in range(n_frames): + ofst = i * hop_length + x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))] + return x + + def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor: + x = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + x = F.conv_transpose1d(x, self.basis, stride=self.hop_length) + win_sum_sq = self.get_window_sum_square( + magnitude.shape[-1], + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.n_fft, + ).to(magnitude.device) + # remove modulation effects + approx_nonzero_indices = win_sum_sq > self.tiny + x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices] + x *= self.n_fft / self.hop_length + x = x[:, :, self.n_fft // 2 :] + x = x[:, :, : -self.n_fft // 2 :] + return x + + def forward(self, specgram: torch.Tensor) -> torch.Tensor: + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*specgram.shape))) + angles = torch.from_numpy(angles).to(specgram) + _specgram = specgram.view(-1, specgram.shape[-2], specgram.shape[-1]) + waveform = self.inverse(_specgram, angles).squeeze(1) + for _ in range(self.n_iter): + _, angles = self.transform(waveform) + waveform = self.inverse(_specgram, angles).squeeze(1) + return waveform.squeeze(0) + + +class GriffinLimVocoder(nn.Module): + def __init__( + self, + sample_rate, + win_size, + hop_size, + n_fft, + n_mels, + f_min, + f_max, + window_fn, + spec_bwd_max_iter=32, + fp16=False, + ): + super().__init__() + self.inv_mel_transform = PseudoInverseMelScale( + n_stft=n_fft // 2 + 1, + n_mels=n_mels, + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, + ) + self.gl_transform = GriffinLim( + n_fft=n_fft, + win_length=win_size, + hop_length=hop_size, + window_fn=window_fn, + n_iter=spec_bwd_max_iter, + ) + if fp16: + self.half() + self.inv_mel_transform.half() + self.gl_transform.half() + else: + self.float() + self.inv_mel_transform.float() + self.gl_transform.float() + + def forward(self, x): + # x: (B x) T x D -> (B x) 1 x T + # NOTE: batched forward produces noisier waveform. recommend running + # one utterance at a time + self.eval() + x = x.exp().transpose(-1, -2) + x = self.inv_mel_transform(x) + x = self.gl_transform(x) + return x + + @classmethod + def from_data_cfg(cls, args, data_cfg: S2TDataConfig): + feat_cfg = data_cfg.config["features"] + window_fn = getattr(torch, feat_cfg["window_fn"] + "_window") + return cls( + sample_rate=feat_cfg["sample_rate"], + win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]), + hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]), + n_fft=feat_cfg["n_fft"], + n_mels=feat_cfg["n_mels"], + f_min=feat_cfg["f_min"], + f_max=feat_cfg["f_max"], + window_fn=window_fn, + spec_bwd_max_iter=args.spec_bwd_max_iter, + fp16=args.fp16, + ) + + +class HiFiGANVocoder(nn.Module): + def __init__( + self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False + ) -> None: + super().__init__() + self.model = HiFiGANModel(model_cfg) + state_dict = torch.load(checkpoint_path) + self.model.load_state_dict(state_dict["generator"]) + if fp16: + self.model.half() + logger.info(f"loaded HiFiGAN checkpoint from {checkpoint_path}") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # (B x) T x D -> (B x) 1 x T + model = self.model.eval() + if len(x.shape) == 2: + return model(x.unsqueeze(0).transpose(1, 2)).detach().squeeze(0) + else: + return model(x.transpose(-1, -2)).detach() + + @classmethod + def from_data_cfg(cls, args, data_cfg: S2TDataConfig): + vocoder_cfg = data_cfg.vocoder + assert vocoder_cfg.get("type", "griffin_lim") == "hifigan" + with open(vocoder_cfg["config"]) as f: + model_cfg = json.load(f) + return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16) + + +@register_model("CodeHiFiGANVocoder") +class CodeHiFiGANVocoder(BaseFairseqModel): + def __init__( + self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False + ) -> None: + super().__init__() + self.model = CodeHiFiGANModel(model_cfg) + if torch.cuda.is_available(): + state_dict = torch.load(checkpoint_path) + else: + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.model.load_state_dict(state_dict["generator"]) + self.model.eval() + if fp16: + self.model.half() + self.model.remove_weight_norm() + logger.info(f"loaded CodeHiFiGAN checkpoint from {checkpoint_path}") + + def forward(self, x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor: + assert "code" in x + x["dur_prediction"] = dur_prediction + + # remove invalid code + mask = x["code"] >= 0 + x["code"] = x["code"][mask].unsqueeze(dim=0) + if "f0" in x: + f0_up_ratio = x["f0"].size(1) // x["code"].size(1) + mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1)) + x["f0"] = x["f0"][mask].unsqueeze(dim=0) + + return self.model(**x).detach().squeeze() + + @classmethod + def from_data_cfg(cls, args, data_cfg): + vocoder_cfg = data_cfg.vocoder + assert vocoder_cfg is not None, "vocoder not specified in the data config" + with open(vocoder_cfg["config"]) as f: + model_cfg = json.load(f) + return cls(vocoder_cfg["checkpoint"], model_cfg, fp16=args.fp16) + + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/vocoder" + model_ids = [ + "unit_hifigan_mhubert_vp_en_es_fr_it3_400k_layer11_km1000_lj_dur", + "unit_hifigan_mhubert_vp_en_es_fr_it3_400k_layer11_km1000_es_css10_dur", + "unit_hifigan_HK_layer12.km2500_frame_TAT-TTS", + ] + return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + config="config.json", + fp16: bool = False, + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + config_yaml=config, + fp16=fp16, + is_vocoder=True, + **kwargs, + ) + + with open(f"{x['args']['data']}/{config}") as f: + vocoder_cfg = json.load(f) + assert len(x["args"]["model_path"]) == 1, "Too many vocoder models in the input" + + vocoder = CodeHiFiGANVocoder(x["args"]["model_path"][0], vocoder_cfg) + return VocoderHubInterface(vocoder_cfg, vocoder) + + +def get_vocoder(args, data_cfg: S2TDataConfig): + if args.vocoder == "griffin_lim": + return GriffinLimVocoder.from_data_cfg(args, data_cfg) + elif args.vocoder == "hifigan": + return HiFiGANVocoder.from_data_cfg(args, data_cfg) + elif args.vocoder == "code_hifigan": + return CodeHiFiGANVocoder.from_data_cfg(args, data_cfg) + else: + raise ValueError("Unknown vocoder") diff --git a/fairseq/models/transformer/__init__.py b/fairseq/models/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..681fca3d4553f6832a65f61fc186793bc4ee0679 --- /dev/null +++ b/fairseq/models/transformer/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Facebook Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .transformer_config import ( + TransformerConfig, + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, + DEFAULT_MIN_PARAMS_TO_WRAP, +) +from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear +from .transformer_encoder import TransformerEncoder, TransformerEncoderBase +from .transformer_legacy import ( + TransformerModel, + base_architecture, + tiny_architecture, + transformer_iwslt_de_en, + transformer_wmt_en_de, + transformer_vaswani_wmt_en_de_big, + transformer_vaswani_wmt_en_fr_big, + transformer_wmt_en_de_big, + transformer_wmt_en_de_big_t2t, +) +from .transformer_base import TransformerModelBase, Embedding + + +__all__ = [ + "TransformerModelBase", + "TransformerConfig", + "TransformerDecoder", + "TransformerDecoderBase", + "TransformerEncoder", + "TransformerEncoderBase", + "TransformerModel", + "Embedding", + "Linear", + "base_architecture", + "tiny_architecture", + "transformer_iwslt_de_en", + "transformer_wmt_en_de", + "transformer_vaswani_wmt_en_de_big", + "transformer_vaswani_wmt_en_fr_big", + "transformer_wmt_en_de_big", + "transformer_wmt_en_de_big_t2t", + "DEFAULT_MAX_SOURCE_POSITIONS", + "DEFAULT_MAX_TARGET_POSITIONS", + "DEFAULT_MIN_PARAMS_TO_WRAP", +] diff --git a/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc b/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..076a1f60cd71a3c3f5ede9e6429af40220a10353 Binary files /dev/null and b/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dc2f48b693c46341e2e4fed556c752a93d44cd1 Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f967ba5d5919bcfec2191e939cf2916a9da164b1 Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70c5c81054d8fcb350d6ee87f1f70482fbebaccd Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eca074c97d211c251872acc1ec4751b687cdb816 Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1e0fcf8158d08c4530b660db2bd6941de763ac5 Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc differ diff --git a/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc b/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80dce1fb3ccc7c76a616006ae771ad4171718c9b Binary files /dev/null and b/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc differ diff --git a/fairseq/models/transformer/transformer_base.py b/fairseq/models/transformer/transformer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f097f04b15ff9d87a7d95c2bbfd9e81a58b813 --- /dev/null +++ b/fairseq/models/transformer/transformer_base.py @@ -0,0 +1,193 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor + +import logging + +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqEncoderDecoderModel +from fairseq.models.transformer import ( + TransformerConfig, + TransformerDecoderBase, + TransformerEncoderBase, +) + + +logger = logging.getLogger(__name__) + + +class TransformerModelBase(FairseqEncoderDecoderModel): + """ + Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) + `_. + + Args: + encoder (TransformerEncoder): the encoder + decoder (TransformerDecoder): the decoder + + The Transformer model provides the following named architectures and + command-line arguments: + + .. argparse:: + :ref: fairseq.models.transformer_parser + :prog: + """ + + def __init__(self, cfg, encoder, decoder): + super().__init__(encoder, decoder) + self.cfg = cfg + self.supports_align_args = True + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + # we want to build the args recursively in this case. + gen_parser_from_dataclass( + parser, TransformerConfig(), delete_default=False, with_prefix="" + ) + + @classmethod + def build_model(cls, cfg, task): + """Build a new model instance.""" + + # -- TODO T96535332 + # bug caused by interaction between OmegaConf II and argparsing + cfg.decoder.input_dim = int(cfg.decoder.input_dim) + cfg.decoder.output_dim = int(cfg.decoder.output_dim) + # -- + + if cfg.encoder.layers_to_keep: + cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(",")) + if cfg.decoder.layers_to_keep: + cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(",")) + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + if cfg.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if cfg.encoder.embed_dim != cfg.decoder.embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if cfg.decoder.embed_path and ( + cfg.decoder.embed_path != cfg.encoder.embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + encoder_embed_tokens = cls.build_embedding( + cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path + ) + decoder_embed_tokens = encoder_embed_tokens + cfg.share_decoder_input_output_embed = True + elif cfg.merge_src_tgt_embed: + logger.info(f"source dict size: {len(src_dict)}") + logger.info(f"target dict size: {len(tgt_dict)}") + src_dict.update(tgt_dict) + task.src_dict = src_dict + task.tgt_dict = src_dict + logger.info(f"merged dict size: {len(src_dict)}") + encoder_embed_tokens = cls.build_embedding( + cfg, src_dict, cfg.encoder.embed_dim + ) + decoder_embed_tokens = encoder_embed_tokens + cfg.share_decoder_input_output_embed = True + else: + encoder_embed_tokens = cls.build_embedding( + cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path + ) + decoder_embed_tokens = cls.build_embedding( + cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path + ) + if cfg.offload_activations: + cfg.checkpoint_activations = True # offloading implies checkpointing + encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) + return cls(cfg, encoder, decoder) + + @classmethod + def build_embedding(cls, cfg, dictionary, embed_dim, path=None): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + + emb = Embedding(num_embeddings, embed_dim, padding_idx) + # if provided, load from preloaded dictionaries + if path: + embed_dict = utils.parse_embedding(path) + utils.load_embedding(embed_dict, dictionary, emb) + return emb + + @classmethod + def build_encoder(cls, cfg, src_dict, embed_tokens): + return TransformerEncoderBase(cfg, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, cfg, tgt_dict, embed_tokens): + return TransformerDecoderBase( + cfg, + tgt_dict, + embed_tokens, + no_encoder_attn=cfg.no_cross_attention, + ) + + # TorchScript doesn't support optional arguments with variable length (**kwargs). + # Current workaround is to add union of all arguments in child classes. + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + return_all_hiddens: bool = True, + features_only: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Run the forward pass for an encoder-decoder model. + + Copied from the base class, but without ``**kwargs``, + which are not supported by TorchScript. + """ + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens + ) + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return decoder_out + + # Since get_normalized_probs is in the Fairseq Model which is not scriptable, + # I rewrite the get_normalized_probs from Base Class to call the + # helper function in the Base Class. + @torch.jit.export + def get_normalized_probs( + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + return self.get_normalized_probs_scriptable(net_output, log_probs, sample) + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m diff --git a/fairseq/models/transformer/transformer_config.py b/fairseq/models/transformer/transformer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..4650de2e1710a7bb790431dea855a82ac9b49def --- /dev/null +++ b/fairseq/models/transformer/transformer_config.py @@ -0,0 +1,341 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import re +from dataclasses import dataclass, field, fields +from typing import List, Optional + +from omegaconf import II + +from fairseq import utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.utils import safe_getattr, safe_hasattr + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + +_NAME_PARSER = r"(decoder|encoder|quant_noise)_(.*)" + + +@dataclass +class EncDecBaseConfig(FairseqDataclass): + embed_path: Optional[str] = field( + default=None, metadata={"help": "path to pre-trained embedding"} + ) + embed_dim: Optional[int] = field( + default=512, metadata={"help": "embedding dimension"} + ) + ffn_embed_dim: int = field( + default=2048, metadata={"help": "embedding dimension for FFN"} + ) + layers: int = field(default=6, metadata={"help": "number of layers"}) + attention_heads: int = field( + default=8, metadata={"help": "number of attention heads"} + ) + normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each block"} + ) + learned_pos: bool = field( + default=False, metadata={"help": "use learned positional embeddings"} + ) + # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + layerdrop: float = field(default=0, metadata={"help": "LayerDrop probability"}) + layers_to_keep: Optional[List[int]] = field( + default=None, metadata={"help": "which layers to *keep* when pruning"} + ) + + xformers_att_config: Optional[str] = field( + default=None, + metadata={ + "help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig" + }, + ) + + +@dataclass +class DecoderConfig(EncDecBaseConfig): + input_dim: int = II("model.decoder.embed_dim") + output_dim: int = field( + default=II("model.decoder.embed_dim"), + metadata={ + "help": "decoder output dimension (extra linear layer if different from decoder embed dim)" + }, + ) + + def __post_init__(self): + # II doesn't work if we are just creating the object outside of hydra so fix that + if self.input_dim == II("model.decoder.embed_dim"): + self.input_dim = self.embed_dim + if self.output_dim == II("model.decoder.embed_dim"): + self.output_dim = self.embed_dim + + +@dataclass +class QuantNoiseConfig(FairseqDataclass): + pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + + +@dataclass +class TransformerConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", + metadata={"help": "activation function to use"}, + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN.", + "alias": "--relu-dropout", + }, + ) + adaptive_input: bool = False + encoder: EncDecBaseConfig = EncDecBaseConfig() + # TODO should really be in the encoder config + max_source_positions: int = field( + default=DEFAULT_MAX_SOURCE_POSITIONS, + metadata={"help": "Maximum input length supported by the encoder"}, + ) + decoder: DecoderConfig = DecoderConfig() + # TODO should really be in the decoder config + max_target_positions: int = field( + default=DEFAULT_MAX_TARGET_POSITIONS, + metadata={"help": "Maximum output length supported by the decoder"}, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + share_all_embeddings: bool = field( + default=False, + metadata={ + "help": "share encoder, decoder and output embeddings (requires shared dictionary and embed dim)" + }, + ) + merge_src_tgt_embed: bool = field( + default=False, + metadata={ + "help": "if true then the source and target embedding table is " + "merged into one table. This is going to make the model smaller but " + "it might hurt performance." + }, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if True, disables positional embeddings (outside self attention)" + }, + ) + adaptive_softmax_cutoff: Optional[List[int]] = field( + default=None, + metadata={ + "help": "list of adaptive softmax cutoff points. Must be used with adaptive_loss criterion" + }, + ) + adaptive_softmax_dropout: float = field( + default=0.0, + metadata={"help": "sets adaptive softmax dropout for the tail projections"}, + ) + adaptive_softmax_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + tie_adaptive_weights: bool = field( + default=False, + metadata={ + "help": "if set, ties the weights of adaptive softmax and adaptive input" + }, + ) + tie_adaptive_proj: bool = field( + default=False, + metadata={ + "help": "if set, ties the projection weights of adaptive softmax and adaptive input" + }, + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, dont scale embeddings"} + ) + checkpoint_activations: bool = field( + default=False, + metadata={ + "help": "checkpoint activations at each layer, which saves GPU memory usage at the cost of some additional compute" + }, + ) + offload_activations: bool = field( + default=False, + metadata={ + "help": "checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations." + }, + ) + # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) + no_cross_attention: bool = field( + default=False, metadata={"help": "do not perform cross-attention"} + ) + cross_self_attention: bool = field( + default=False, metadata={"help": "perform cross+self-attention"} + ) + # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig()) + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + }, + ) + # DEPRECATED field, but some old checkpoints might have it + char_inputs: bool = field( + default=False, metadata={"help": "if set, model takes character ids as input"} + ) + relu_dropout: float = 0.0 + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, + metadata={"help": "shuffle tokens between workers before computing assignment"}, + ) + + export: bool = field( + default=False, + metadata={"help": "make the layernorm exportable with torchscript."}, + ) + + # copied from transformer_lm but expected in transformer_decoder: + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + + # We need to make this hierarchical dataclass like the flat namespace + # __getattr__ and __setattr__ here allow backward compatibility + # for subclasses of Transformer(Legacy) that depend on read/write on + # the flat namespace. + + def __getattr__(self, name): + match = re.match(_NAME_PARSER, name) + if match: + sub = safe_getattr(self, match[1]) + return safe_getattr(sub, match[2]) + raise AttributeError(f"invalid argument {name}.") + + def __setattr__(self, name, value): + match = re.match(_NAME_PARSER, name) + if match: + sub = safe_getattr(self, match[1]) + setattr(sub, match[2], value) + else: + super().__setattr__(name, value) + + @staticmethod + def _copy_keys(args, cls, prefix, seen): + """ + copy the prefixed keys (decoder_embed_dim) to the DC fields: decoder.embed_dim + """ + cfg = cls() + for fld in fields(cls): + # for all the fields in the DC, find the fields (e.g. embed_dim) + # in the namespace with the prefix (e.g. decoder) + # and set it on the dc. + args_key = f"{prefix}_{fld.name}" + if safe_hasattr(args, args_key): + seen.add(args_key) + setattr(cfg, fld.name, safe_getattr(args, args_key)) + if safe_hasattr(args, fld.name): + seen.add(fld.name) + setattr(cfg, fld.name, safe_getattr(args, fld.name)) + return cfg + + @classmethod + def from_namespace(cls, args): + if args is None: + return None + if not isinstance(args, cls): + seen = set() + config = cls() + # currently, we can go generically from DC fields to args hierarchically + # but we can't easily deconstruct a flat namespace to a hierarchical + # DC. Mostly because we could have a sub-dc called `decoder-foo` that should not + # go to the sub struct called `decoder`. There are ways to go around this, but let's keep it simple + # for now. + for fld in fields(cls): + # concretelly, the transformer_config know what sub-dc it has, so we go through all the dc fields + # and if it's one that has a sub-dc, we build that sub-dc with `copy_keys()` + if fld.name == "decoder": + if safe_hasattr(args, "decoder"): + # in some cases, the args we receive is already structured (as DictConfigs), so let's just build the correct DC + seen.add("decoder") + config.decoder = DecoderConfig(**args.decoder) + else: + config.decoder = cls._copy_keys( + args, DecoderConfig, "decoder", seen + ) + elif fld.name == "encoder": + # same but for encoder + if safe_hasattr(args, "encoder"): + seen.add("encoder") + config.encoder = EncDecBaseConfig(**args.encoder) + else: + config.encoder = cls._copy_keys( + args, EncDecBaseConfig, "encoder", seen + ) + elif fld.name == "quant_noise": + # same but for quant_noise + if safe_hasattr(args, "quant_noise"): + seen.add("quant_noise") + config.quant_noise = QuantNoiseConfig(**args.quant_noise) + else: + config.quant_noise = cls._copy_keys( + args, QuantNoiseConfig, "quant_noise", seen + ) + elif safe_hasattr(args, fld.name): + # if it's not a structure field, it's just a normal field, copy it over + seen.add(fld.name) + setattr(config, fld.name, safe_getattr(args, fld.name)) + # we got all the fields defined in the dataclass, but + # the argparse namespace might have extra args for two reasons: + # - we are in a legacy class so all the args are not declared in the dataclass. Ideally once everyone has defined a dataclass for their model, we won't need this + # - some places expect args to be there but never define them + args_dict = ( + args._asdict() + if safe_hasattr(args, "_asdict") + else vars(args) + if safe_hasattr(args, "__dict__") + else {} + ) # namedtupled doesn't have __dict__ :-/ + for key, value in args_dict.items(): + if key not in seen: + setattr(config, key, value) + return config + else: + return args diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..744c73f4f8cda6f0e9d6899d146ca1343d514b45 --- /dev/null +++ b/fairseq/models/transformer/transformer_decoder.py @@ -0,0 +1,474 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqIncrementalDecoder +from fairseq.models.transformer import TransformerConfig +from fairseq.modules import ( + AdaptiveSoftmax, + BaseLayer, + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + SinusoidalPositionalEmbedding, + transformer_layer, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ + + +# rewrite name for backward compatibility in `make_generation_fast_` +def module_name_fordropout(module_name: str) -> str: + if module_name == "TransformerDecoderBase": + return "TransformerDecoder" + else: + return module_name + + +class TransformerDecoderBase(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + cfg (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.cfg = cfg + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + self._future_mask = torch.empty(0) + + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) + ) + self.decoder_layerdrop = cfg.decoder.layerdrop + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder.embed_dim + self.embed_dim = embed_dim + self.output_embed_dim = cfg.decoder.output_dim + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) + + if not cfg.adaptive_input and cfg.quant_noise.pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, + ) + else: + self.quant_noise = None + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + self.embed_positions = ( + PositionalEmbedding( + self.max_target_positions, + embed_dim, + self.padding_idx, + learned=cfg.decoder.learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) + else: + self.layernorm_embedding = None + + self.cross_self_attention = cfg.cross_self_attention + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(cfg, no_encoder_attn) + for _ in range(cfg.decoder.layers) + ] + ) + self.num_layers = len(self.layers) + + if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm: + self.layer_norm = LayerNorm(embed_dim, export=cfg.export) + else: + self.layer_norm = None + + self.project_out_dim = ( + Linear(embed_dim, self.output_embed_dim, bias=False) + if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights + else None + ) + + self.adaptive_softmax = None + self.output_projection = output_projection + if self.output_projection is None: + self.build_output_projection(cfg, dictionary, embed_tokens) + + def build_output_projection(self, cfg, dictionary, embed_tokens): + if cfg.adaptive_softmax_cutoff is not None: + self.adaptive_softmax = AdaptiveSoftmax( + len(dictionary), + self.output_embed_dim, + utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int), + dropout=cfg.adaptive_softmax_dropout, + adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None, + factor=cfg.adaptive_softmax_factor, + tie_proj=cfg.tie_adaptive_proj, + ) + elif self.share_input_output_embed: + self.output_projection = nn.Linear( + self.embed_tokens.weight.shape[1], + self.embed_tokens.weight.shape[0], + bias=False, + ) + self.output_projection.weight = self.embed_tokens.weight + else: + self.output_projection = nn.Linear( + self.output_embed_dim, len(dictionary), bias=False + ) + nn.init.normal_( + self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5 + ) + num_base_layers = cfg.base_layers + for i in range(num_base_layers): + self.layers.insert( + ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1), + BaseLayer(cfg), + ) + + def build_decoder_layer(self, cfg, no_encoder_attn=False): + layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # Prevent torchscript exporting issue for dynamic quant embedding + prev_output_tokens = prev_output_tokens.contiguous() + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + enc, + padding_mask, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "inner_states": inner_states} + + def output_layer(self, features): + """Project features to the vocabulary size.""" + if self.adaptive_softmax is None: + # project back to size of vocabulary + return self.output_projection(features) + else: + return features + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround. + if ( + self._future_mask.size(0) == 0 + or (not self._future_mask.device == tensor.device) + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1 + ) + self._future_mask = self._future_mask.to(tensor) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] + if not self.share_input_output_embed: + del state_dict[embed_out_key] + + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +class TransformerDecoder(TransformerDecoderBase): + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + no_encoder_attn=no_encoder_attn, + output_projection=output_projection, + ) + + def build_output_projection(self, args, dictionary, embed_tokens): + super().build_output_projection( + TransformerConfig.from_namespace(args), dictionary, embed_tokens + ) + + def build_decoder_layer(self, args, no_encoder_attn=False): + return super().build_decoder_layer( + TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn + ) diff --git a/fairseq/models/transformer/transformer_decoder_aug.py b/fairseq/models/transformer/transformer_decoder_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c06e02a0a9749120e50ce6e9f65c35a0462dc --- /dev/null +++ b/fairseq/models/transformer/transformer_decoder_aug.py @@ -0,0 +1,384 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models.transformer import TransformerConfig +from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase +from fairseq.modules import ( + LayerDropModuleList, + SinusoidalPositionalEmbedding, + transformer_layer_aug, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper + + +class AugTransformerDecoderBase(TransformerDecoderBase): + """ + Transformer decoder augmented with an additional cross-attention. Each layer + is a :class:`AugTransformerDecoderLayerBase`. + + Args: + cfg (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + encoder_attn_merge_type (str, optional): the way to combine outputs from + two cross-attention modules. If "sequential" is set, two cross-attention + modules are stacked sequentially. If "parallel" is set, they are processed + in parallel and combined before feeding it to FFN (default: sequential). + dropnet_ratio (float, optional): a probability to drop each cross-attention + module during training (default: 0.0). + """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + output_projection=None, + encoder_attn_merge_type="sequential", + dropnet_ratio=0.0, + ): + super().__init__( + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=output_projection, + ) + # assert cfg.cross_self_attention + self.cross_self_attention = cfg.cross_self_attention + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio) + for _ in range(cfg.decoder.layers) + ] + ) + + def build_decoder_layer( + self, + cfg, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + layer = transformer_layer_aug.AugTransformerDecoderLayerBase( + cfg, + no_encoder_attn=False, + encoder_attn_merge_type=encoder_attn_merge_type, + dropnet_ratio=dropnet_ratio, + ) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out_aug=encoder_out_aug, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + encoder_out_aug: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + encoder_out_aug, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. + """ + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + encoder_out_aug: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + enc_aug: Optional[Tensor] = None + padding_mask_aug: Optional[Tensor] = None + if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0: + enc_aug = encoder_out_aug["encoder_out"][0] + if ( + encoder_out_aug is not None + and len(encoder_out_aug["encoder_padding_mask"]) > 0 + ): + padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0] + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # Prevent torchscript exporting issue for dynamic quant embedding + prev_output_tokens = prev_output_tokens.contiguous() + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + attn_aug: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, layer_attn_aug, _ = layer( + x, + enc, + padding_mask, + enc_aug, + padding_mask_aug, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + if layer_attn_aug is not None and idx == alignment_layer: + attn_aug = layer_attn_aug.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if attn_aug is not None: + if alignment_heads is not None: + attn_aug = attn_aug[:alignment_heads] + + # average probabilities over heads + attn_aug = attn_aug.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states} + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] + if not self.share_input_output_embed: + del state_dict[embed_out_key] + + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "encoder_attn_layer_norm2", + "3": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +class AugTransformerDecoder(AugTransformerDecoderBase): + def __init__( + self, + args, + dictionary, + embed_tokens, + output_projection=None, + ): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=output_projection, + encoder_attn_merge_type=getattr( + args, "synthesizer_augmented_cross_attention_merge_type", "sequential" + ), + dropnet_ratio=getattr(args, "dropnet_ratio", 0), + ) + + def build_output_projection(self, args, dictionary, embed_tokens): + super().build_output_projection( + TransformerConfig.from_namespace(args), dictionary, embed_tokens + ) + + def build_decoder_layer( + self, + args, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + return super().build_decoder_layer( + TransformerConfig.from_namespace(args), + no_encoder_attn=False, + encoder_attn_merge_type=encoder_attn_merge_type, + dropnet_ratio=dropnet_ratio, + ) diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a684fcb448e9ac6a7a6440c825bbe3212789cf88 --- /dev/null +++ b/fairseq/models/transformer/transformer_encoder.py @@ -0,0 +1,362 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models import FairseqEncoder +from fairseq.models.transformer import TransformerConfig +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, + PositionalEmbedding, + SinusoidalPositionalEmbedding, + transformer_layer, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ + + +# rewrite name for backward compatibility in `make_generation_fast_` +def module_name_fordropout(module_name: str) -> str: + if module_name == "TransformerEncoderBase": + return "TransformerEncoder" + else: + return module_name + + +class TransformerEncoderBase(FairseqEncoder): + """ + Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer + is a :class:`TransformerEncoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): encoding dictionary + embed_tokens (torch.nn.Embedding): input embedding + """ + + def __init__(self, cfg, dictionary, embed_tokens, return_fc=False): + self.cfg = cfg + super().__init__(dictionary) + self.register_buffer("version", torch.Tensor([3])) + + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) + ) + self.encoder_layerdrop = cfg.encoder.layerdrop + self.return_fc = return_fc + + embed_dim = embed_tokens.embedding_dim + self.padding_idx = embed_tokens.padding_idx + self.max_source_positions = cfg.max_source_positions + + self.embed_tokens = embed_tokens + + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_source_positions, + embed_dim, + self.padding_idx, + learned=cfg.encoder.learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + if cfg.layernorm_embedding: + self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) + else: + self.layernorm_embedding = None + + if not cfg.adaptive_input and cfg.quant_noise.pq > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(embed_dim, embed_dim, bias=False), + cfg.quant_noise.pq, + cfg.quant_noise.pq_block_size, + ) + else: + self.quant_noise = None + + if self.encoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.encoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] + ) + self.num_layers = len(self.layers) + + if cfg.encoder.normalize_before: + self.layer_norm = LayerNorm(embed_dim, export=cfg.export) + else: + self.layer_norm = None + + def build_encoder_layer(self, cfg): + layer = transformer_layer.TransformerEncoderLayerBase( + cfg, return_fc=self.return_fc + ) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward_embedding( + self, src_tokens, token_embedding: Optional[torch.Tensor] = None + ): + # embed tokens and positions + if token_embedding is None: + token_embedding = self.embed_tokens(src_tokens) + x = embed = self.embed_scale * token_embedding + if self.embed_positions is not None: + x = embed + self.embed_positions(src_tokens) + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + x = self.dropout_module(x) + if self.quant_noise is not None: + x = self.quant_noise(x) + return x, embed + + def forward( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable( + src_tokens, src_lengths, return_all_hiddens, token_embeddings + ) + + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = ( + torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any() + ) + # Torchscript doesn't handle bool Tensor correctly, so we need to work around. + if torch.jit.is_scripting(): + has_pads = torch.tensor(1) if has_pads else torch.tensor(0) + + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + + # account for padding while computing the representation + x = x * ( + 1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x) + ) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + encoder_states = [] + fc_results = [] + + if return_all_hiddens: + encoder_states.append(x) + + # encoder layers + for layer in self.layers: + lr = layer( + x, encoder_padding_mask=encoder_padding_mask if has_pads else None + ) + + if isinstance(lr, tuple) and len(lr) == 2: + x, fc_result = lr + else: + x = lr + fc_result = None + + if return_all_hiddens and not torch.jit.is_scripting(): + assert encoder_states is not None + encoder_states.append(x) + fc_results.append(fc_result) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + src_lengths = ( + src_tokens.ne(self.padding_idx) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "fc_results": fc_results, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } + + @torch.jit.export + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] + + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] + + if len(encoder_out["src_lengths"]) == 0: + src_lengths = [] + else: + src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + "src_lengths": src_lengths, # B x 1 + } + + @torch.jit.export + def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): + """Dummy re-order function for beamable enc-dec attention""" + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + if self.embed_positions is None: + return self.max_source_positions + return min(self.max_source_positions, self.embed_positions.max_positions) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + for i in range(self.num_layers): + # update layer norms + self.layers[i].upgrade_state_dict_named( + state_dict, "{}.layers.{}".format(name, i) + ) + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + return state_dict + + +class TransformerEncoder(TransformerEncoderBase): + def __init__(self, args, dictionary, embed_tokens, return_fc=False): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + return_fc=return_fc, + ) + + def build_encoder_layer(self, args): + return super().build_encoder_layer( + TransformerConfig.from_namespace(args), + ) diff --git a/fairseq/models/transformer/transformer_legacy.py b/fairseq/models/transformer/transformer_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..00d14a7dde778a4b3b2faecde50eecdd5115a864 --- /dev/null +++ b/fairseq/models/transformer/transformer_legacy.py @@ -0,0 +1,277 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.models.transformer.transformer_config import ( + TransformerConfig, + DEFAULT_MAX_SOURCE_POSITIONS, + DEFAULT_MAX_TARGET_POSITIONS, + DEFAULT_MIN_PARAMS_TO_WRAP, +) +from fairseq.models.transformer.transformer_base import ( + TransformerModelBase, +) + + +@register_model("transformer") +class TransformerModel(TransformerModelBase): + """ + This is the legacy implementation of the transformer model that + uses argparse for configuration. + """ + + @classmethod + def hub_models(cls): + # fmt: off + + def moses_subword(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'subword_nmt', + } + + def moses_fastbpe(path): + return { + 'path': path, + 'tokenizer': 'moses', + 'bpe': 'fastbpe', + } + + def spm(path): + return { + 'path': path, + 'bpe': 'sentencepiece', + 'tokenizer': 'space', + } + + return { + 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), + 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', + 'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'), + 'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'), + 'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'), + 'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'), + 'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), + 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), + 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), + 'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'), + 'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'), + 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'), + 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), + 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), + 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), + 'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'), + 'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'), + } + # fmt: on + + def __init__(self, args, encoder, decoder): + cfg = TransformerConfig.from_namespace(args) + super().__init__(cfg, encoder, decoder) + self.args = args + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + # we want to build the args recursively in this case. + # do not set defaults so that settings defaults from various architectures still works + gen_parser_from_dataclass( + parser, TransformerConfig(), delete_default=True, with_prefix="" + ) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + base_architecture(args) + + if args.encoder_layers_to_keep: + args.encoder_layers = len(args.encoder_layers_to_keep.split(",")) + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if getattr(args, "max_source_positions", None) is None: + args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + if args.share_all_embeddings: + if src_dict != tgt_dict: + raise ValueError("--share-all-embeddings requires a joined dictionary") + if args.encoder_embed_dim != args.decoder_embed_dim: + raise ValueError( + "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim" + ) + if args.decoder_embed_path and ( + args.decoder_embed_path != args.encoder_embed_path + ): + raise ValueError( + "--share-all-embeddings not compatible with --decoder-embed-path" + ) + args.share_decoder_input_output_embed = True + + if getattr(args, "offload_activations", False): + args.checkpoint_activations = True # offloading implies checkpointing + + if not args.share_all_embeddings: + args.min_params_to_wrap = getattr( + args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP + ) + cfg = TransformerConfig.from_namespace(args) + return super().build_model(cfg, task) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + return super().build_embedding( + TransformerConfig.from_namespace(args), dictionary, embed_dim, path + ) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return super().build_encoder( + TransformerConfig.from_namespace(args), src_dict, embed_tokens + ) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return super().build_decoder( + TransformerConfig.from_namespace(args), tgt_dict, embed_tokens + ) + + +# architectures + + +@register_model_architecture("transformer", "transformer_tiny") +def tiny_architecture(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64) + args.encoder_layers = getattr(args, "encoder_layers", 2) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2) + args.decoder_layers = getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2) + return base_architecture(args) + + +@register_model_architecture("transformer", "transformer") +def base_architecture(args): + args.encoder_embed_path = getattr(args, "encoder_embed_path", None) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + + args.decoder_embed_path = getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") + args.dropout = getattr(args, "dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) + args.merge_src_tgt_embed = getattr(args, "merge_src_tgt_embed", False) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.no_cross_attention = getattr(args, "no_cross_attention", False) + args.cross_self_attention = getattr(args, "cross_self_attention", False) + + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.checkpoint_activations = getattr(args, "checkpoint_activations", False) + args.offload_activations = getattr(args, "offload_activations", False) + if args.offload_activations: + args.checkpoint_activations = True + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + + +@register_model_architecture("transformer", "transformer_iwslt_de_en") +def transformer_iwslt_de_en(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.encoder_layers = getattr(args, "encoder_layers", 6) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) + args.decoder_layers = getattr(args, "decoder_layers", 6) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de") +def transformer_wmt_en_de(args): + base_architecture(args) + + +# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017) +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big") +def transformer_vaswani_wmt_en_de_big(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.dropout = getattr(args, "dropout", 0.3) + base_architecture(args) + + +@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big") +def transformer_vaswani_wmt_en_fr_big(args): + args.dropout = getattr(args, "dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +@register_model_architecture("transformer", "transformer_wmt_en_de_big") +def transformer_wmt_en_de_big(args): + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) + + +# default parameters used in tensor2tensor implementation +@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t") +def transformer_wmt_en_de_big_t2t(args): + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.1) + transformer_vaswani_wmt_en_de_big(args) diff --git a/fairseq/models/transformer_align.py b/fairseq/models/transformer_align.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf585bd10e630ae6cd89920f197cd165f55ad58 --- /dev/null +++ b/fairseq/models/transformer_align.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import ( + TransformerModel, + base_architecture, + transformer_wmt_en_de_big, +) + + +@register_model("transformer_align") +class TransformerAlignModel(TransformerModel): + """ + See "Jointly Learning to Align and Translate with Transformer + Models" (Garg et al., EMNLP 2019). + """ + + def __init__(self, encoder, decoder, args): + super().__init__(args, encoder, decoder) + self.alignment_heads = args.alignment_heads + self.alignment_layer = args.alignment_layer + self.full_context_alignment = args.full_context_alignment + + @staticmethod + def add_args(parser): + # fmt: off + super(TransformerAlignModel, TransformerAlignModel).add_args(parser) + parser.add_argument('--alignment-heads', type=int, metavar='D', + help='Number of cross attention heads per layer to supervised with alignments') + parser.add_argument('--alignment-layer', type=int, metavar='D', + help='Layer number which has to be supervised. 0 corresponding to the bottommost layer.') + parser.add_argument('--full-context-alignment', action='store_true', + help='Whether or not alignment is supervised conditioned on the full target context.') + # fmt: on + + @classmethod + def build_model(cls, args, task): + # set any default arguments + transformer_align(args) + + transformer_model = TransformerModel.build_model(args, task) + return TransformerAlignModel( + transformer_model.encoder, transformer_model.decoder, args + ) + + def forward(self, src_tokens, src_lengths, prev_output_tokens): + encoder_out = self.encoder(src_tokens, src_lengths) + return self.forward_decoder(prev_output_tokens, encoder_out) + + def forward_decoder( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + features_only=False, + **extra_args, + ): + attn_args = { + "alignment_layer": self.alignment_layer, + "alignment_heads": self.alignment_heads, + } + decoder_out = self.decoder(prev_output_tokens, encoder_out, **attn_args) + + if self.full_context_alignment: + attn_args["full_context_alignment"] = self.full_context_alignment + _, alignment_out = self.decoder( + prev_output_tokens, + encoder_out, + features_only=True, + **attn_args, + **extra_args, + ) + decoder_out[1]["attn"] = alignment_out["attn"] + + return decoder_out + + +@register_model_architecture("transformer_align", "transformer_align") +def transformer_align(args): + args.alignment_heads = getattr(args, "alignment_heads", 1) + args.alignment_layer = getattr(args, "alignment_layer", 4) + args.full_context_alignment = getattr(args, "full_context_alignment", False) + base_architecture(args) + + +@register_model_architecture("transformer_align", "transformer_wmt_en_de_big_align") +def transformer_wmt_en_de_big_align(args): + args.alignment_heads = getattr(args, "alignment_heads", 1) + args.alignment_layer = getattr(args, "alignment_layer", 4) + transformer_wmt_en_de_big(args) diff --git a/fairseq/models/transformer_from_pretrained_xlm.py b/fairseq/models/transformer_from_pretrained_xlm.py new file mode 100644 index 0000000000000000000000000000000000000000..236d9942e1fb0238cc92e2b4f160520b5cdd6504 --- /dev/null +++ b/fairseq/models/transformer_from_pretrained_xlm.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Any, Dict + +from fairseq import checkpoint_utils +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.models import register_model, register_model_architecture +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, + base_architecture as transformer_base_architecture, +) + + +@register_model("transformer_from_pretrained_xlm") +class TransformerFromPretrainedXLMModel(TransformerModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + TransformerModel.add_args(parser) + parser.add_argument( + "--pretrained-xlm-checkpoint", + type=str, + metavar="STR", + help="XLM model to use for initializing transformer encoder and/or decoder", + ) + parser.add_argument( + "--init-encoder-only", + action="store_true", + help="if set, don't load the XLM weights and embeddings into decoder", + ) + parser.add_argument( + "--init-decoder-only", + action="store_true", + help="if set, don't load the XLM weights and embeddings into encoder", + ) + + @classmethod + def build_model(self, args, task, cls_dictionary=MaskedLMDictionary): + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "You must specify a path for --pretrained-xlm-checkpoint to use " + "--arch transformer_from_pretrained_xlm" + ) + assert isinstance(task.source_dictionary, cls_dictionary) and isinstance( + task.target_dictionary, cls_dictionary + ), ( + "You should use a MaskedLMDictionary when using --arch " + "transformer_from_pretrained_xlm because the pretrained XLM model " + "was trained using data binarized with MaskedLMDictionary. " + "For translation, you may want to use --task " + "translation_from_pretrained_xlm" + ) + assert not ( + getattr(args, "init_encoder_only", False) + and getattr(args, "init_decoder_only", False) + ), "Only one of --init-encoder-only and --init-decoder-only can be set." + return super().build_model(args, task) + + @classmethod + def build_encoder(cls, args, src_dict, embed_tokens): + return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens) + + @classmethod + def build_decoder(cls, args, tgt_dict, embed_tokens): + return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens) + + +def upgrade_state_dict_with_xlm_weights( + state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str +) -> Dict[str, Any]: + """ + Load XLM weights into a Transformer encoder or decoder model. + + Args: + state_dict: state dict for either TransformerEncoder or + TransformerDecoder + pretrained_xlm_checkpoint: checkpoint to load XLM weights from + + Raises: + AssertionError: If architecture (num layers, attention heads, etc.) + does not match between the current Transformer encoder or + decoder and the pretrained_xlm_checkpoint + """ + if not os.path.exists(pretrained_xlm_checkpoint): + raise IOError("Model file not found: {}".format(pretrained_xlm_checkpoint)) + + state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint) + xlm_state_dict = state["model"] + for key in xlm_state_dict.keys(): + + for search_key in ["embed_tokens", "embed_positions", "layers"]: + if search_key in key: + subkey = key[key.find(search_key) :] + assert subkey in state_dict, ( + "{} Transformer encoder / decoder " + "state_dict does not contain {}. Cannot " + "load {} from pretrained XLM checkpoint " + "{} into Transformer.".format( + str(state_dict.keys()), subkey, key, pretrained_xlm_checkpoint + ) + ) + + state_dict[subkey] = xlm_state_dict[key] + return state_dict + + +class TransformerEncoderFromPretrainedXLM(TransformerEncoder): + def __init__(self, args, dictionary, embed_tokens): + super().__init__(args, dictionary, embed_tokens) + if getattr(args, "init_decoder_only", False): + # Don't load XLM weights for encoder if --init-decoder-only + return + + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "--pretrained-xlm-checkpoint must be specified to load Transformer " + "encoder from pretrained XLM" + ) + xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights( + state_dict=self.state_dict(), + pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint, + ) + self.load_state_dict(xlm_loaded_state_dict, strict=True) + + +class TransformerDecoderFromPretrainedXLM(TransformerDecoder): + def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + super().__init__(args, dictionary, embed_tokens, no_encoder_attn) + if getattr(args, "init_encoder_only", False): + # Don't load XLM weights for decoder if --init-encoder-only + return + assert hasattr(args, "pretrained_xlm_checkpoint"), ( + "--pretrained-xlm-checkpoint must be specified to load Transformer " + "decoder from pretrained XLM" + ) + + xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights( + state_dict=self.state_dict(), + pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint, + ) + self.load_state_dict(xlm_loaded_state_dict, strict=True) + + +@register_model_architecture( + "transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm" +) +def base_architecture(args): + transformer_base_architecture(args) diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3aa72d38a9970916f3677bb3b7871aceaeafce --- /dev/null +++ b/fairseq/models/transformer_lm.py @@ -0,0 +1,607 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from typing import Optional + +from omegaconf import II + +from fairseq import options, utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.transformer import ( + DEFAULT_MIN_PARAMS_TO_WRAP, + Embedding, + TransformerDecoder, +) +from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder +from fairseq.utils import safe_getattr, safe_hasattr + +DEFAULT_MAX_TARGET_POSITIONS = 1024 + + +@dataclass +class TransformerLanguageModelConfig(FairseqDataclass): + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="relu", metadata={"help": "activation function to use"} + ) + dropout: float = field(default=0.1, metadata={"help": "dropout probability"}) + attention_dropout: float = field( + default=0.0, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + relu_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN."} + ) + decoder_embed_dim: int = field( + default=512, metadata={"help": "decoder embedding dimension"} + ) + decoder_output_dim: int = field( + default=512, metadata={"help": "decoder output dimension"} + ) + decoder_input_dim: int = field( + default=512, metadata={"help": "decoder input dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=2048, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"}) + decoder_attention_heads: int = field( + default=8, metadata={"help": "num decoder attention heads"} + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_decoder_final_norm: bool = field( + default=False, + metadata={"help": "don't add an extra layernorm after the last decoder block"}, + ) + adaptive_softmax_cutoff: Optional[str] = field( + default=None, + metadata={ + "help": "comma separated list of adaptive softmax cutoff points. " + "Must be used with adaptive_loss criterion" + }, + ) + adaptive_softmax_dropout: float = field( + default=0, + metadata={"help": "sets adaptive softmax dropout for the tail projections"}, + ) + adaptive_softmax_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + character_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, uses character embedding convolutions to produce token embeddings" + }, + ) + character_filters: str = field( + default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]", + metadata={"help": "size of character embeddings"}, + ) + character_embedding_dim: int = field( + default=4, metadata={"help": "size of character embeddings"} + ) + char_embedder_highway_layers: int = field( + default=2, + metadata={"help": "number of highway layers for character token embeddder"}, + ) + adaptive_input: bool = field( + default=False, metadata={"help": "if set, uses adaptive input"} + ) + adaptive_input_factor: float = field( + default=4, metadata={"help": "adaptive input factor"} + ) + adaptive_input_cutoff: Optional[str] = field( + default=None, + metadata={"help": "comma separated list of adaptive input cutoff points."}, + ) + tie_adaptive_weights: bool = field( + default=False, + metadata={ + "help": "if set, ties the weights of adaptive softmax and adaptive input" + }, + ) + tie_adaptive_proj: bool = field( + default=False, + metadata={ + "help": "if set, ties the projection weights of adaptive softmax and adaptive input" + }, + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + layernorm_embedding: bool = field( + default=False, metadata={"help": "add layernorm to embedding"} + ) + no_scale_embedding: bool = field( + default=False, metadata={"help": "if True, dont scale embeddings"} + ) + checkpoint_activations: bool = field( + default=False, metadata={"help": "checkpoint activations at each layer"} + ) + offload_activations: bool = field( + default=False, + metadata={"help": "move checkpointed activations to CPU after they are used."}, + ) + # config for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "LayerDrop probability for decoder"} + ) + decoder_layers_to_keep: Optional[str] = field( + default=None, + metadata={ + "help": "which layers to *keep* when pruning as a comma-separated list" + }, + ) + # config for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020) + quant_noise_pq: float = field( + default=0.0, + metadata={"help": "iterative PQ quantization noise at training time"}, + ) + quant_noise_pq_block_size: int = field( + default=8, + metadata={"help": "block size of quantization noise at training time"}, + ) + quant_noise_scalar: float = field( + default=0.0, + metadata={ + "help": "scalar quantization noise and scalar quantization at training time" + }, + ) + # config for Fully Sharded Data Parallel (FSDP) training + min_params_to_wrap: int = field( + default=DEFAULT_MIN_PARAMS_TO_WRAP, + metadata={ + "help": ( + "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + ) + }, + ) + # config for "BASE Layers: Simplifying Training of Large, Sparse Models" + base_layers: Optional[int] = field( + default=0, metadata={"help": "number of BASE layers in total"} + ) + base_sublayers: Optional[int] = field( + default=1, metadata={"help": "number of sublayers in each BASE layer"} + ) + base_shuffle: Optional[int] = field( + default=1, + metadata={"help": "shuffle tokens between workers before computing assignment"}, + ) + # NormFormer + scale_fc: Optional[bool] = field( + default=False, + metadata={"help": "Insert LayerNorm between fully connected layers"}, + ) + scale_attn: Optional[bool] = field( + default=False, metadata={"help": "Insert LayerNorm after attention"} + ) + scale_heads: Optional[bool] = field( + default=False, + metadata={"help": "Learn a scale coefficient for each attention head"}, + ) + scale_resids: Optional[bool] = field( + default=False, + metadata={"help": "Learn a scale coefficient for each residual connection"}, + ) + + # xFormers arguments + decoder_xformers_att_config: Optional[str] = field( + default=None, + metadata={ + "help": "config for xFormers library attention, defined in xformers.components.attention.AttentionConfig", + }, + ) + + # options from other parts of the config + add_bos_token: bool = II("task.add_bos_token") + tokens_per_sample: int = II("task.tokens_per_sample") + max_target_positions: Optional[int] = II("task.max_target_positions") + tpu: bool = II("common.tpu") + + +@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) +class TransformerLanguageModel(FairseqLanguageModel): + @classmethod + def hub_models(cls): + def moses_fastbpe(path): + return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} + + def spm(path): + return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"} + + return { + "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2", + "transformer_lm.wiki103.adaptive": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2", + "transformer_lm.wmt19.en": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2" + ), + "transformer_lm.wmt19.de": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2" + ), + "transformer_lm.wmt19.ru": moses_fastbpe( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2" + ), + "transformer_lm.wmt20.en": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.en.tar.gz" + ), + "transformer_lm.wmt20.ta": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.ta.tar.gz" + ), + "transformer_lm.wmt20.iu.news": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.news.tar.gz" + ), + "transformer_lm.wmt20.iu.nh": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.nh.tar.gz" + ), + } + + def __init__(self, decoder): + super().__init__(decoder) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + if args.decoder_layers_to_keep: + args.decoder_layers = len(args.decoder_layers_to_keep.split(",")) + + if safe_getattr(args, "max_target_positions", None) is None: + args.max_target_positions = safe_getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + if args.character_embeddings: + embed_tokens = CharacterTokenEmbedder( + task.source_dictionary, + eval(args.character_filters), + args.character_embedding_dim, + args.decoder_embed_dim, + args.char_embedder_highway_layers, + ) + elif args.adaptive_input: + embed_tokens = AdaptiveInput( + len(task.source_dictionary), + task.source_dictionary.pad(), + args.decoder_input_dim, + args.adaptive_input_factor, + args.decoder_embed_dim, + options.eval_str_list(args.adaptive_input_cutoff, type=int), + args.quant_noise_pq, + args.quant_noise_pq_block_size, + ) + else: + embed_tokens = cls.build_embedding( + args, task.source_dictionary, args.decoder_input_dim + ) + + if args.tie_adaptive_weights: + assert args.adaptive_input + assert args.adaptive_input_factor == args.adaptive_softmax_factor + assert ( + args.adaptive_softmax_cutoff == args.adaptive_input_cutoff + ), "{} != {}".format( + args.adaptive_softmax_cutoff, args.adaptive_input_cutoff + ) + assert args.decoder_input_dim == args.decoder_output_dim + + decoder = TransformerDecoder( + args, task.target_dictionary, embed_tokens, no_encoder_attn=True + ) + return cls(decoder) + + @classmethod + def build_embedding(cls, args, dictionary, embed_dim, path=None): + embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad()) + return embed_tokens + + +def base_lm_architecture(args): + # backward compatibility for older model checkpoints + if safe_hasattr(args, "no_tie_adaptive_proj"): + # previous models defined --no-tie-adaptive-proj, so use the existence of + # that option to determine if this is an "old" model checkpoint + args.no_decoder_final_norm = True # old models always set this to True + if args.no_tie_adaptive_proj is False: + args.tie_adaptive_proj = True + if safe_hasattr(args, "decoder_final_norm"): + args.no_decoder_final_norm = not args.decoder_final_norm + + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0) + + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 2048) + args.decoder_layers = safe_getattr(args, "decoder_layers", 6) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8) + args.adaptive_softmax_cutoff = safe_getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0) + args.adaptive_softmax_factor = safe_getattr(args, "adaptive_softmax_factor", 4) + args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", False) + args.activation_fn = safe_getattr(args, "activation_fn", "relu") + + args.decoder_layerdrop = safe_getattr(args, "decoder_layerdrop", 0) + args.decoder_layers_to_keep = safe_getattr(args, "decoder_layers_to_keep", None) + args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0) + + args.base_layers = safe_getattr(args, "base_layers", 0) + args.base_sublayers = safe_getattr(args, "base_sublayers", 1) + args.base_shuffle = safe_getattr(args, "base_shuffle", False) + + args.add_bos_token = safe_getattr(args, "add_bos_token", False) + args.no_token_positional_embeddings = safe_getattr( + args, "no_token_positional_embeddings", False + ) + args.share_decoder_input_output_embed = safe_getattr( + args, "share_decoder_input_output_embed", False + ) + args.character_embeddings = safe_getattr(args, "character_embeddings", False) + + args.decoder_output_dim = safe_getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = safe_getattr( + args, "decoder_input_dim", args.decoder_embed_dim + ) + + # Model training is not stable without this + args.decoder_normalize_before = True + args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", False) + + args.adaptive_input = safe_getattr(args, "adaptive_input", False) + args.adaptive_input_factor = safe_getattr(args, "adaptive_input_factor", 4) + args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", None) + + args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", False) + args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", False) + + args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", False) + args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False) + args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False) + args.offload_activations = safe_getattr(args, "offload_activations", False) + args.scale_fc = safe_getattr(args, "scale_fc", False) + args.scale_attn = safe_getattr(args, "scale_attn", False) + args.scale_heads = safe_getattr(args, "scale_heads", False) + args.scale_resids = safe_getattr(args, "scale_resids", False) + if args.offload_activations: + args.checkpoint_activations = True + + +@register_model_architecture("transformer_lm", "transformer_lm_big") +def transformer_lm_big(args): + args.decoder_layers = safe_getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16) + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_wiki103") +@register_model_architecture("transformer_lm", "transformer_lm_baevski_wiki103") +def transformer_lm_baevski_wiki103(args): + args.decoder_layers = safe_getattr(args, "decoder_layers", 16) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8) + args.dropout = safe_getattr(args, "dropout", 0.3) + args.adaptive_input = safe_getattr(args, "adaptive_input", True) + args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True) + args.adaptive_input_cutoff = safe_getattr( + args, "adaptive_input_cutoff", "20000,60000" + ) + args.adaptive_softmax_cutoff = safe_getattr( + args, "adaptive_softmax_cutoff", "20000,60000" + ) + args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0.2) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_dropout = safe_getattr(args, "activation_dropout", 0.1) + args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", True) + args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", True) + transformer_lm_big(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gbw") +@register_model_architecture("transformer_lm", "transformer_lm_baevski_gbw") +def transformer_lm_baevski_gbw(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", True) + transformer_lm_big(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt") +def transformer_lm_gpt(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 768) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 3072) + args.decoder_layers = safe_getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 12) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_small") +def transformer_lm_gpt2_small(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 4096) + args.decoder_layers = safe_getattr(args, "decoder_layers", 24) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny") +def transformer_lm_gpt2_tiny(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 64) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 64) + args.decoder_layers = safe_getattr(args, "decoder_layers", 2) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 1) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium") +def transformer_lm_gpt2_medium(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1280) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 5120) + args.decoder_layers = safe_getattr(args, "decoder_layers", 36) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 20) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big") +def transformer_lm_gpt2_big(args): + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1600) + args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 6400) + args.decoder_layers = safe_getattr(args, "decoder_layers", 48) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 25) + args.dropout = safe_getattr(args, "dropout", 0.1) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big_wide") +def transformer_lm_gpt2_big_wide(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192) + args.decoder_layers = getattr(args, "decoder_layers", 24) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt2_bigger") +def transformer_lm_gpt2_bigger(args): + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 2048) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 8192) + args.decoder_layers = getattr(args, "decoder_layers", 48) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_fn = getattr(args, "activation_fn", "gelu") + base_lm_architecture(args) + + +def base_gpt3_architecture(args): + args.decoder_input_dim = args.decoder_embed_dim + args.decoder_output_dim = args.decoder_embed_dim + args.decoder_ffn_embed_dim = safe_getattr( + args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4 + ) + # GPT-3 used learned positional embeddings, rather than sinusoidal + args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True) + args.dropout = safe_getattr(args, "dropout", 0.0) + args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0) + args.activation_fn = safe_getattr(args, "activation_fn", "gelu") + args.share_decoder_input_output_embed = True + base_lm_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small") +def transformer_lm_gpt3_small(args): + # 125M params + args.decoder_layers = safe_getattr(args, "decoder_layers", 12) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 768) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 12) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium") +def transformer_lm_gpt3_medium(args): + # 350M params + args.decoder_layers = safe_getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large") +def transformer_lm_gpt3_large(args): + # 760M params + args.decoder_layers = safe_getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1536) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl") +def transformer_lm_gpt3_xl(args): + # 1.3B params + args.decoder_layers = safe_getattr(args, "decoder_layers", 24) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2048) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7") +def transformer_lm_gpt3_2_7(args): + # 2.7B params + args.decoder_layers = safe_getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2560) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7") +def transformer_lm_gpt3_6_7(args): + # 6.7B params + args.decoder_layers = safe_getattr(args, "decoder_layers", 32) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 4096) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13") +def transformer_lm_gpt3_13(args): + # 13B params + args.decoder_layers = safe_getattr(args, "decoder_layers", 40) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 5120) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 40) + base_gpt3_architecture(args) + + +@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175") +def transformer_lm_gpt3_175(args): + # 175B params + args.decoder_layers = safe_getattr(args, "decoder_layers", 96) + args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 12288) + args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 96) + base_gpt3_architecture(args) diff --git a/fairseq/models/transformer_ulm.py b/fairseq/models/transformer_ulm.py new file mode 100644 index 0000000000000000000000000000000000000000..0fc9ae434873295b3904ea91a94b7332879d06a2 --- /dev/null +++ b/fairseq/models/transformer_ulm.py @@ -0,0 +1,408 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from fairseq.models.fairseq_decoder import FairseqDecoder +import numpy as np +from typing import Optional, Dict, Any, List +import torch +from torch import nn +from fairseq.data.data_utils import compute_mask_indices +from fairseq.dataclass import ChoiceEnum +from fairseq.models import ( + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.tasks.speech_ulm_task import SpeechUnitLanguageModelingTask +from fairseq.models.transformer import Embedding, TransformerDecoder, Linear +from fairseq.models.transformer_lm import TransformerLanguageModelConfig +from torch import Tensor + + +DEFAULT_MAX_TARGET_POSITIONS = 1024 +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) + + +@dataclass +class SpeechUnitLanguageModelConfig(TransformerLanguageModelConfig): + mask_unit_seg_prob: float = field( + default=0.0, metadata={"help": "probability to mask a segment of unit sequence"} + ) + mask_unit_seg_leng: int = field( + default=5, metadata={"help": "length of unit segment mask"} + ) + mask_unit_seg_type: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose unit mask length"} + ) + + mask_dur_prob: float = field( + default=0.0, metadata={"help": "probability to mask entire duration sequence"} + ) + mask_dur_seg_prob: float = field( + default=0.0, + metadata={"help": "probability to mask a segment of duration sequence"}, + ) + mask_dur_seg_leng: int = field( + default=5, metadata={"help": "length of duration segment mask"} + ) + mask_dur_seg_type: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose duration mask length"} + ) + + mask_f0_prob: float = field( + default=0.0, metadata={"help": "probability to mask entire duration sequence"} + ) + mask_f0_seg_prob: float = field( + default=0.0, metadata={"help": "probability to mask a segment of f0 sequence"} + ) + mask_f0_seg_leng: int = field( + default=5, metadata={"help": "length of f0 segment mask"} + ) + mask_f0_seg_type: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose f0 mask length"} + ) + + +@register_model("transformer_ulm", dataclass=SpeechUnitLanguageModelConfig) +class TransformerUnitLanguageModel(FairseqLanguageModel): + def __init__( + self, + cfg: SpeechUnitLanguageModelConfig, + task: SpeechUnitLanguageModelingTask, + decoder: FairseqDecoder, + ): + super().__init__(decoder) + self.cfg = cfg + + self.channel_names = task.channel_names + self.channel_sizes = task.channel_sizes + + self.unit_mask_val = task.source_dictionary.unk() + self.dur_mask_val = ( + task.source_duration_dictionary.unk() if task.cfg.discrete_duration else 0 + ) + self.f0_mask_val = ( + task.source_f0_dictionary.unk() if task.cfg.discrete_f0 else 0 + ) + + self.ignore_duration_input = task.cfg.ignore_duration_input + self.ignore_f0_input = task.cfg.ignore_f0_input + + @classmethod + def build_model(cls, args, task): + base_ulm_architecture(args) + + if getattr(args, "max_target_positions", None) is None: + args.max_target_positions = getattr( + args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS + ) + + embed_tokens = Embedding( + len(task.source_dictionary), + args.decoder_input_dim, + padding_idx=task.source_dictionary.pad(), + ) + embed_duration = None + if task.cfg.discrete_duration: + embed_duration = Embedding( + len(task.source_duration_dictionary), + args.decoder_input_dim, + padding_idx=0, # duration uses 0 for padding + ) + embed_f0 = None + if task.cfg.discrete_f0: + embed_f0 = Embedding( + len(task.source_f0_dictionary), + args.decoder_input_dim, + padding_idx=task.source_f0_dictionary.pad(), + ) + + decoder = MultiStreamTransformerDecoder( + args, + task.target_dictionary, + embed_tokens, + [embed_duration, embed_f0], + no_encoder_attn=True, + channel_sizes=task.channel_sizes, + ) + + return cls(args, task, decoder) + + def apply_seg_dropout(self, inp, mask_prob, mask_leng, mask_type, mask_val): + B, T = inp.size() + if mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), None, mask_prob, mask_leng, mask_type # may mask padding + ) + mask_indices = torch.from_numpy(mask_indices).to(inp.device) + inp[mask_indices] = mask_val + else: + mask_indices = torch.zeros_like(inp).bool() + return inp, mask_indices + + def apply_seq_dropout(self, inp, mask_prob, mask_val): + B, T = inp.size() + if mask_prob > 0: + mask_indices = np.random.uniform(0, 1, (B,)) < mask_prob + mask_indices = ( + torch.from_numpy(mask_indices).to(inp.device).unsqueeze(1).expand(-1, T) + ) + inp[mask_indices] = mask_val + else: + mask_indices = torch.zeros_like(inp).bool() + return inp, mask_indices + + def apply_dropout(self, src_tokens, dur_src, f0_src): + src_tokens, unit_mask = self.apply_seg_dropout( + src_tokens, + self.cfg.mask_unit_seg_prob, + self.cfg.mask_unit_seg_leng, + self.cfg.mask_unit_seg_type, + self.unit_mask_val, + ) + + dur_src, dur_mask = self.apply_seq_dropout( + dur_src, self.cfg.mask_dur_prob, self.dur_mask_val + ) + dur_src, _dur_mask = self.apply_seg_dropout( + dur_src, + self.cfg.mask_dur_seg_prob, + self.cfg.mask_dur_seg_leng, + self.cfg.mask_dur_seg_type, + self.dur_mask_val, + ) + dur_mask = dur_mask.logical_or(_dur_mask) + + f0_src, f0_mask = self.apply_seq_dropout( + f0_src, self.cfg.mask_f0_prob, self.f0_mask_val + ) + f0_src, _f0_mask = self.apply_seg_dropout( + f0_src, + self.cfg.mask_f0_seg_prob, + self.cfg.mask_f0_seg_leng, + self.cfg.mask_f0_seg_type, + self.f0_mask_val, + ) + f0_mask = f0_mask.logical_or(_f0_mask) + + return src_tokens, unit_mask, dur_src, dur_mask, f0_src, f0_mask + + def forward( + self, + src_tokens: torch.Tensor, + dur_src: torch.Tensor, + f0_src: torch.Tensor, + src_lengths: Optional[Any] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + if self.ignore_duration_input: + dur_src = torch.zeros_like(dur_src) + + if self.ignore_f0_input: + f0_src = torch.zeros_like(f0_src) + + if self.training: + ( + src_tokens, + unit_mask, + dur_src, + dur_mask, + f0_src, + f0_mask, + ) = self.apply_dropout(src_tokens, dur_src, f0_src) + else: + unit_masks = dur_mask = f0_mask = None + + prediction, _ = self.decoder( + prev_output_tokens=(src_tokens, dur_src, f0_src), + incremental_state=incremental_state, + src_lengths=src_lengths, + features_only=True, + ) + + result = dict(zip(self.channel_names, prediction)) + + return result + + +def base_ulm_architecture(args): + from .transformer_lm import base_lm_architecture + + base_lm_architecture(args) + + +@register_model_architecture("transformer_ulm", "transformer_ulm_big") +def transformer_ulm_big(args): + from .transformer_lm import transformer_lm_big + + transformer_lm_big(args) + base_ulm_architecture(args) + + +@register_model_architecture("transformer_ulm", "transformer_ulm_tiny") +def transformer_ulm_tiny(args): + from .transformer_lm import transformer_lm_gpt2_tiny + + transformer_lm_gpt2_tiny(args) + base_ulm_architecture(args) + + +class MultiStreamTransformerDecoder(TransformerDecoder): + def __init__( + self, + args, + dictionary, + embed_tokens, + embed_other_list, + no_encoder_attn, + channel_sizes, + ): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn + ) + + # embed each channel and project if dimensions do not match + self.embed_other_list = torch.nn.ModuleList(embed_other_list) + self.proj_other_list = torch.nn.ModuleList() + dim = embed_tokens.embedding_dim + for embed_other in embed_other_list: + other_dim = 1 if embed_other is None else embed_other.embedding_dim + self.proj_other_list.append( + nn.Linear(other_dim, dim) if other_dim != dim else None + ) + + # tranformer output to prediction + self.channel_sizes = channel_sizes + self.project_out_dim = Linear( + embed_tokens.embedding_dim, sum(channel_sizes), bias=False + ) + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + # XXX: first multi-channel change start + prev_output_tokens, *other_channels = prev_output_tokens + # XXX: first multi-channel change end + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + other_channels = [o[:, -1:] for o in other_channels] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + # XXX: second multi-channel change start + other_channels = [ + o.unsqueeze(-1).to(dtype=x.dtype) if emb is None else emb(o) + for o, emb in zip(other_channels, self.embed_other_list) + ] + other_channels = [ + o if proj_other is None else proj_other(o) + for o, proj_other in zip(other_channels, self.proj_other_list) + ] + for o in other_channels: + x = x + o + # XXX: second multi-channel change end + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, _ = layer( + x, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + else: + assert False + + # XXX: the last change start + result = [] + start = 0 + for channel_size in self.channel_sizes: + end = start + channel_size + result.append(x[:, :, start:end]) + start = end + assert end == x.size(-1) + # XXX: the last change end + + return result, {"attn": [attn], "inner_states": inner_states} diff --git a/fairseq/models/wav2vec/__init__.py b/fairseq/models/wav2vec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b756e4580b82b42c735d9e097c62f63cb7798f5e --- /dev/null +++ b/fairseq/models/wav2vec/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .wav2vec import * # noqa +from .wav2vec2 import * # noqa +from .wav2vec2_asr import * # noqa +from .wav2vec2_laser import * # noqa +from .wav2vec2_classification import * # noqa diff --git a/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11a4a76be5983c8e23c1a783d63f9e619736674a Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ea91e2abfa45b34eb65b9f9fa00a248fc0abaa4 Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c20cb7730e12fb392c73cac6b873d1fcbcb0f99 Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e47a4d62a80fdee86a56ec4773aed36150c999e Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..041f021931d3d224cc232dea5748e47d859721e9 Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b3f492c5c515ca1f2187b13132421ae2fd2d3f6 Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc b/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07782f5eaf5dc666fe45584d3f22136260ba466f Binary files /dev/null and b/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc differ diff --git a/fairseq/models/wav2vec/utils.py b/fairseq/models/wav2vec/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd52d86242ec5289b480cb085b50a68a713cc306 --- /dev/null +++ b/fairseq/models/wav2vec/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch.nn.functional as F + + +def pad_to_multiple(x, multiple, dim=-1, value=0): + # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 + if x is None: + return None, 0 + tsz = x.size(dim) + m = tsz / multiple + remainder = math.ceil(m) * multiple - tsz + if m.is_integer(): + return x, 0 + pad_offset = (0,) * (-1 - dim) * 2 + + return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py new file mode 100644 index 0000000000000000000000000000000000000000..af6604da10f504baabff50bf14a6eb2214bffef3 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec.py @@ -0,0 +1,630 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import logging +import math +from typing import Optional, Tuple +from omegaconf import II +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GumbelVectorQuantizer, + KmeansVectorQuantizer, + TransposeLast, +) +from fairseq.tasks import FairseqTask +from fairseq.utils import buffered_arange + + +logger = logging.getLogger(__name__) + + +AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) +PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) +ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) +VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) + + +@dataclass +class Wav2VecConfig(FairseqDataclass): + prediction_steps: int = field( + default=12, metadata={"help": "number of steps ahead to predict"} + ) + sample_distance: Optional[int] = field( + default=None, + metadata={ + "help": "sample distance from target. does not work properly with cross-sampling" + }, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "num of cross sampled negatives"} + ) + num_negatives: int = field( + default=10, metadata={"help": "num of sampled negatives"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", + metadata={ + "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" + }, + ) + conv_aggregator_layers: str = field( + default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", + metadata={ + "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" + }, + ) + dropout: float = field( + default=0.0, metadata={"help": "dropout to apply within the model"} + ) + dropout_features: float = field( + default=0.0, metadata={"help": "dropout to apply to the features"} + ) + dropout_agg: float = field( + default=0.0, metadata={"help": "dropout to apply after aggregation step"} + ) + aggregator: AGGREGATOR_CHOICES = field( + default="cnn", metadata={"help": "type of aggregator to use"} + ) + gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) + no_conv_bias: bool = field( + default=False, metadata={"help": "if set, does not learn bias for conv layers"} + ) + agg_zero_pad: bool = field( + default=False, + metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, + ) + skip_connections_feat: bool = field( + default=False, + metadata={"help": "if set, adds skip connections to the feature extractor"}, + ) + skip_connections_agg: bool = field( + default=True, + metadata={"help": "if set, adds skip connections to the aggregator"}, + ) + residual_scale: float = field( + default=0.5, metadata={"help": "scales residual by sqrt(value)"} + ) + log_compression: bool = field( + default=True, + metadata={"help": "if set, adds a log compression to feature extractor"}, + ) + balanced_classes: bool = field( + default=False, + metadata={"help": "if set, loss is scaled to balance for number of negatives"}, + ) + project_features: PROJECT_FEATURES_CHOICES = field( + default="none", + metadata={ + "help": "if not none, features are projected using the (same or new) aggregator" + }, + ) + non_affine_group_norm: bool = field( + default=False, metadata={"help": "if set, group norm is not affine"} + ) + offset: str = field( + default="auto", + metadata={ + "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" + }, + ) + activation: ACTIVATION_CHOICES = field( + default="relu", + metadata={ + "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" + }, + ) + vq_type: VQ_TYPE_CHOICES = field( + default="none", metadata={"help": "which type of quantizer to use"} + ) + vq_vars: int = field( + default=320, + metadata={"help": "project to this many vector quantized variables per group"}, + ) + vq_groups: int = field( + default=2, metadata={"help": "number of groups of latent variables"} + ) + vq_dim: int = field( + default=0, + metadata={ + "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" + }, + ) + vq_depth: int = field( + default=1, metadata={"help": "number of layers for vq weight projection"} + ) + combine_groups: bool = field( + default=False, metadata={"help": "if set, variables are shared among groups"} + ) + vq_temp: Tuple[float, float, float] = field( + default=(2.0, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)" + }, + ) + vq_gamma: float = field( + default=0.25, + metadata={"help": "gamma parameter for kmeans style vector quantization"}, + ) + infonce: bool = II("criterion.infonce") + + +@register_model("wav2vec", dataclass=Wav2VecConfig) +class Wav2VecModel(BaseFairseqModel): + @classmethod + def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): + """Build a new model instance.""" + + model = Wav2VecModel(cfg) + logger.info(model) + return model + + def __init__(self, cfg: Wav2VecConfig): + super().__init__() + + self.prediction_steps = cfg.prediction_steps + offset = cfg.offset + + if cfg.activation == "relu": + activation = nn.ReLU() + elif cfg.activation == "gelu": + activation = nn.GELU() + else: + raise Exception("unknown activation " + cfg.activation) + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + log_compression=cfg.log_compression, + skip_connections=cfg.skip_connections_feat, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + activation=activation, + ) + embed = feature_enc_layers[-1][0] + + self.vector_quantizer = None + if cfg.vq_type == "gumbel": + self.vector_quantizer = GumbelVectorQuantizer( + dim=embed, + num_vars=cfg.vq_vars, + temp=cfg.vq_temp, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, + time_first=False, + activation=activation, + weight_proj_depth=cfg.vq_depth, + weight_proj_factor=2, + ) + elif cfg.vq_type == "kmeans": + self.vector_quantizer = KmeansVectorQuantizer( + dim=embed, + num_vars=cfg.vq_vars, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, + time_first=False, + gamma=cfg.vq_gamma, + ) + else: + assert ( + cfg.vq_type == "none" or cfg.vq_type is None + ), "Unknown quantizer type" + + if cfg.offset == "auto": + jin = 0 + rin = 0 + for _, k, stride in feature_enc_layers: + if rin == 0: + rin = k + rin = rin + (k - 1) * jin + if jin == 0: + jin = stride + else: + jin *= stride + offset = math.ceil(rin / jin) + + offset = int(offset) + + def make_aggregator(): + if cfg.aggregator == "cnn": + agg_layers = eval(cfg.conv_aggregator_layers) + agg_dim = agg_layers[-1][0] + feature_aggregator = ConvAggegator( + conv_layers=agg_layers, + embed=embed, + dropout=cfg.dropout, + skip_connections=cfg.skip_connections_agg, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + conv_bias=not cfg.no_conv_bias, + zero_pad=cfg.agg_zero_pad, + activation=activation, + ) + elif cfg.aggregator == "gru": + agg_dim = cfg.gru_dim + feature_aggregator = nn.Sequential( + TransposeLast(), + nn.GRU( + input_size=embed, + hidden_size=agg_dim, + num_layers=1, + dropout=cfg.dropout, + ), + TransposeLast(deconstruct_idx=0), + ) + else: + raise Exception("unknown aggregator type " + cfg.aggregator) + + return feature_aggregator, agg_dim + + self.feature_aggregator, agg_dim = make_aggregator() + + self.wav2vec_predictions = Wav2VecPredictionsModel( + in_dim=agg_dim, + out_dim=embed, + prediction_steps=cfg.prediction_steps, + n_negatives=cfg.num_negatives, + cross_sample_negatives=cfg.cross_sample_negatives, + sample_distance=cfg.sample_distance, + dropout=cfg.dropout, + offset=offset, + balanced_classes=cfg.balanced_classes, + infonce=cfg.infonce, + ) + + self.dropout_feats = nn.Dropout(p=cfg.dropout_features) + self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) + + if cfg.project_features == "none": + self.project_features = None + elif cfg.project_features == "same": + self.project_features = self.feature_aggregator + elif cfg.project_features == "new": + self.project_features, _ = make_aggregator() + + def forward(self, source): + result = {} + + features = self.feature_extractor(source) + if self.vector_quantizer: + q_res = self.vector_quantizer(features) + features = q_res["x"] + for k in q_res.keys(): + if k != "x": + result[k] = q_res[k] + + x = self.dropout_feats(features) + x = self.feature_aggregator(x) + x = self.dropout_agg(x) + + if self.project_features is not None: + features = self.project_features(features) + x, targets = self.wav2vec_predictions(x, features) + result["cpc_logits"] = x + result["cpc_targets"] = targets + + return result + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + + def max_positions(self): + """Maximum length supported by the model.""" + return sys.maxsize + + def get_logits(self, net_output): + logits = net_output["cpc_logits"] + return logits + + def get_targets(self, sample, net_output): + t = net_output["cpc_targets"] + if isinstance(t, tuple): + t = t[0] + return t.contiguous() + + def get_target_weights(self, targets, net_output): + targets = net_output["cpc_targets"] + if isinstance(targets, tuple) and targets[-1] is not None: + return targets[-1] + return None + + def get_extra_losses(self, net_output): + loss = None + if "prob_perplexity" in net_output: + loss = net_output["num_vars"] - net_output["prob_perplexity"] + elif "kmeans_loss" in net_output: + loss = net_output["kmeans_loss"] + + return loss + + +def norm_block(is_layer_norm, dim, affine=True): + if is_layer_norm: + mod = nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=affine), + TransposeLast(), + ) + else: + mod = Fp32GroupNorm(1, dim, affine=affine) + + return mod + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers, + dropout, + log_compression, + skip_connections, + residual_scale, + non_affine_group_norm, + activation, + ): + super().__init__() + + def block(n_in, n_out, k, stride): + return nn.Sequential( + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), + nn.Dropout(p=dropout), + norm_block( + is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm + ), + activation, + ) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for dim, k, stride in conv_layers: + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + + self.log_compression = log_compression + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + residual = x + x = conv(x) + if self.skip_connections and x.size(1) == residual.size(1): + tsz = x.size(2) + r_tsz = residual.size(2) + residual = residual[..., :: r_tsz // tsz][..., :tsz] + x = (x + residual) * self.residual_scale + + if self.log_compression: + x = x.abs() + x = x + 1 + x = x.log() + + return x + + +class ZeroPad1d(nn.Module): + def __init__(self, pad_left, pad_right): + super().__init__() + self.pad_left = pad_left + self.pad_right = pad_right + + def forward(self, x): + return F.pad(x, (self.pad_left, self.pad_right)) + + +class ConvAggegator(nn.Module): + def __init__( + self, + conv_layers, + embed, + dropout, + skip_connections, + residual_scale, + non_affine_group_norm, + conv_bias, + zero_pad, + activation, + ): + super().__init__() + + def block(n_in, n_out, k, stride): + # padding dims only really make sense for stride = 1 + ka = k // 2 + kb = ka - 1 if k % 2 == 0 else ka + + pad = ( + ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) + ) + + return nn.Sequential( + pad, + nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), + nn.Dropout(p=dropout), + norm_block(False, n_out, affine=not non_affine_group_norm), + activation, + ) + + in_d = embed + self.conv_layers = nn.ModuleList() + self.residual_proj = nn.ModuleList() + for dim, k, stride in conv_layers: + if in_d != dim and skip_connections: + self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) + else: + self.residual_proj.append(None) + + self.conv_layers.append(block(in_d, dim, k, stride)) + in_d = dim + self.conv_layers = nn.Sequential(*self.conv_layers) + self.skip_connections = skip_connections + self.residual_scale = math.sqrt(residual_scale) + + def forward(self, x): + for rproj, conv in zip(self.residual_proj, self.conv_layers): + residual = x + x = conv(x) + if self.skip_connections: + if rproj is not None: + residual = rproj(residual) + x = (x + residual) * self.residual_scale + return x + + +class Wav2VecPredictionsModel(nn.Module): + def __init__( + self, + in_dim, + out_dim, + prediction_steps, + n_negatives, + cross_sample_negatives, + sample_distance, + dropout, + offset, + balanced_classes, + infonce, + ): + super().__init__() + + self.n_negatives = n_negatives + self.cross_sample_negatives = cross_sample_negatives + self.sample_distance = sample_distance + self.project_to_steps = nn.ConvTranspose2d( + in_dim, out_dim, (1, prediction_steps) + ) + self.dropout = nn.Dropout(p=dropout) + self.offset = offset + self.balanced_classes = balanced_classes + self.infonce = infonce + + def sample_negatives(self, y): + bsz, fsz, tsz = y.shape + + y = y.transpose(0, 1) # BCT -> CBT + y = y.contiguous().view(fsz, -1) # CBT => C(BxT) + + cross_high = tsz * bsz + high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) + assert high > 1 + + neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) + + with torch.no_grad(): + if self.n_negatives > 0: + tszs = ( + buffered_arange(tsz) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * tsz) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(tsz) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * tsz), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[..., neg_idxs.view(-1)] + negs = negs.view( + fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz + ).permute( + 2, 1, 0, 3 + ) # to NxBxCxT + + return negs + + def forward(self, x, y): + + x = x.unsqueeze(-1) + x = self.project_to_steps(x) # BxCxTxS + x = self.dropout(x) + + negatives = self.sample_negatives(y) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) # Copies x B x C x T + + copies = targets.size(0) + bsz, dim, tsz, steps = x.shape + steps = min(steps, tsz - self.offset) + + predictions = x.new( + bsz * copies * (tsz - self.offset + 1) * steps + - ((steps + 1) * steps // 2) * copies * bsz + ) + if self.infonce: + labels = predictions.new_full( + (predictions.shape[0] // copies,), 0, dtype=torch.long + ) + else: + labels = torch.zeros_like(predictions) + weights = ( + torch.full_like(labels, 1 / self.n_negatives) + if self.balanced_classes and not self.infonce + else None + ) + + start = end = 0 + for i in range(steps): + offset = i + self.offset + end = start + (tsz - offset) * bsz * copies + if self.infonce: + predictions[start:end] = torch.einsum( + "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:] + ).flatten() + else: + pos_num = (end - start) // copies + predictions[start:end] = torch.einsum( + "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:] + ).flatten() + labels[start : start + pos_num] = 1.0 + if weights is not None: + weights[start : start + pos_num] = 1.0 + start = end + assert end == predictions.numel(), "{} != {}".format(end, predictions.numel()) + + if self.infonce: + predictions = predictions.view(-1, copies) + else: + if weights is not None: + labels = (labels, weights) + + return predictions, labels diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..f8dc6a8f5db3e12ad0c74c18644e36a7f179aa65 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -0,0 +1,1492 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.distributed import fsdp_wrap +from fairseq.models import BaseFairseqModel, register_model +from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel +from fairseq.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + GumbelVectorQuantizer, + LayerNorm, + MultiheadAttention, + RelPositionalEncoding, + SamePad, + TransposeLast, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import buffered_arange, index_put, is_xla_tensor + +from .utils import pad_to_multiple + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) +LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"]) + + +@dataclass +class Wav2Vec2Config(FairseqDataclass): + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group norm with d " + "groups in the first conv block, whereas layer_norm has layer norms in " + "every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + # dropouts + dropout: float = field( + default=0.1, metadata={"help": "dropout probability for the transformer"} + ) + attention_dropout: float = field( + default=0.1, metadata={"help": "dropout probability for attention weights"} + ) + activation_dropout: float = field( + default=0.0, metadata={"help": "dropout probability after activation in FFN"} + ) + encoder_layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many dimensions." + "set to encoder_embed_dim is <= 0" + }, + ) + layer_norm_first: bool = field( + default=False, metadata={"help": "apply layernorm first in the transformer"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + metadata={ + "help": "string describing convolutional feature extraction layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + quantize_targets: bool = field( + default=False, metadata={"help": "use quantized targets"} + ) + quantize_input: bool = field( + default=False, metadata={"help": "use quantized inputs"} + ) + same_quantizer: bool = field( + default=False, metadata={"help": "use same quantizer for inputs and targets"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, metadata={"help": "multiply feature extractor var grads by this"} + ) + quantizer_depth: int = field( + default=1, + metadata={"help": "number of quantizer layers"}, + ) + quantizer_factor: int = field( + default=3, + metadata={ + "help": "dimensionality increase for inner quantizer layers (if depth > 1)" + }, + ) + latent_vars: int = field( + default=320, + metadata={"help": "number of latent variables V in each group of the codebook"}, + ) + latent_groups: int = field( + default=2, + metadata={"help": "number of groups G of latent variables in the codebook"}, + ) + latent_dim: int = field( + default=0, + metadata={ + "help": "if > 0, uses this dimensionality for latent variables. " + "otherwise uses final_dim / latent_groups" + }, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, metadata={"help": "probability of replacing a token with mask"} + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + require_same_masks: bool = field( + default=True, + metadata={ + "help": "whether to number of masked timesteps must be the same across all " + "examples in a batch" + }, + ) + mask_dropout: float = field( + default=0.0, + metadata={"help": "percent of masks to unmask for each sample"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_before: bool = False + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # negative selection + num_negatives: int = field( + default=100, + metadata={"help": "number of negative examples from the same sample"}, + ) + negatives_from_everywhere: bool = field( + default=False, + metadata={"help": "sample negatives from everywhere, not just masked states"}, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "number of negative examples from the any sample"} + ) + codebook_negatives: int = field( + default=0, metadata={"help": "number of negative examples codebook"} + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + pos_conv_depth: int = field( + default=1, + metadata={"help": "depth of positional encoder network"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling. " + "can be tuple of 3 values (start, end, decay)" + }, + ) + max_positions: int = field(default=100000, metadata={"help": "Max positions"}) + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + crop_seq_to_multiple: int = field( + default=1, + metadata={ + "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + # Adapter num + adp_num: int = field( + default=-1 + ) + adp_dim: int = field( + default=64 + ) + adp_act_fn: str = field( + default="relu" + ) + adp_trf_idx: str = field( + default="all", + ) + + +@register_model("wav2vec2", dataclass=Wav2Vec2Config) +class Wav2Vec2Model(BaseFairseqModel): + def __init__(self, cfg: Wav2Vec2Config): + super().__init__() + self.cfg = cfg + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input + else None + ) + + self.crop_seq_to_multiple = cfg.crop_seq_to_multiple + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_before = cfg.mask_channel_before + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.quantizer = None + self.input_quantizer = None + + self.n_negatives = cfg.num_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere + + self.logit_temp = cfg.logit_temp + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + if cfg.quantize_targets: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim + self.quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_q = nn.Linear(vq_dim, final_dim) + else: + self.project_q = nn.Linear(self.embed, final_dim) + + if cfg.quantize_input: + if cfg.same_quantizer and self.quantizer is not None: + vq_dim = final_dim + self.input_quantizer = self.quantizer + else: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim + self.input_quantizer = GumbelVectorQuantizer( + dim=self.embed, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, + combine_groups=False, + vq_dim=vq_dim, + time_first=True, + weight_proj_depth=cfg.quantizer_depth, + weight_proj_factor=cfg.quantizer_factor, + ) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + encoder_cls = TransformerEncoder + if cfg.layer_type == "conformer" and cfg.pos_enc_type in ["rel_pos", "rope"]: + encoder_cls = ConformerEncoder + + self.encoder = encoder_cls(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + @classmethod + def build_model(cls, cfg: Wav2Vec2Config, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def apply_mask( + self, + x, + padding_mask, + mask_indices=None, + mask_channel_indices=None, + ): + B, T, C = x.shape + + if self.mask_channel_prob > 0 and self.mask_channel_before: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + if self.mask_prob > 0: + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + require_same_masks=self.cfg.require_same_masks, + mask_dropout=self.cfg.mask_dropout, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) + else: + mask_indices = None + + if self.mask_channel_prob > 0 and not self.mask_channel_before: + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) + + return x, mask_indices + + def sample_negatives(self, y, num, padding_count=None): + + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + return y.new(0) + + bsz, tsz, fsz = y.shape + y = y.view(-1, fsz) # BTC => (BxT)C + + # FIXME: what happens if padding_count is specified? + cross_high = tsz * bsz + high = tsz - (padding_count or 0) + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + neg_idxs = torch.randint( + low=0, high=high - 1, size=(bsz, self.n_negatives * num) + ) + neg_idxs[neg_idxs >= tszs] += 1 + + if self.cross_sample_negatives > 0: + tszs = ( + buffered_arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high) + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + + def compute_preds(self, x, y, negatives): + + neg_is_pos = (y == negatives).all(-1) + y = y.unsqueeze(0) + targets = torch.cat([y, negatives], dim=0) + + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) + logits = logits / self.logit_temp + logits = logits.type_as(x) + + if is_xla_tensor(logits) or neg_is_pos.any(): + if not hasattr(self, "_inftensor"): + fillval = -float(2**30) + self._inftensor = ( + torch.tensor(fillval).to(x.device) + if is_xla_tensor(logits) + else float("-inf") + ) + logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) + + return logits + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + return torch.floor((input_length - kernel_size) / stride + 1) + + conv_cfg_list = eval(self.cfg.conv_feature_layers) + + for i in range(len(conv_cfg_list)): + input_lengths = _conv_out_length( + input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] + ) + + return input_lengths.to(torch.long) + + def forward( + self, + source, + padding_mask=None, + mask=True, + features_only=False, + layer=None, + mask_indices=None, + mask_channel_indices=None, + padding_count=None, + corpus_key=None, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None and padding_mask.any(): + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + else: + padding_mask = None + + time_steps_to_drop = features.size(1) % self.crop_seq_to_multiple + if time_steps_to_drop != 0: + features = features[:, :-time_steps_to_drop] + unmasked_features = unmasked_features[:, :-time_steps_to_drop] + if padding_mask is not None: + padding_mask = padding_mask[:, :-time_steps_to_drop] + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + num_vars = None + code_ppl = None + prob_ppl = None + curr_temp = None + + if self.input_quantizer: + q = self.input_quantizer(features, produce_targets=False) + features = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + features = self.project_inp(features) + + if mask: + x, mask_indices = self.apply_mask( + features, + padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) + if not is_xla_tensor(x) and mask_indices is not None: + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) + else: + y = unmasked_features + else: + x = features + y = unmasked_features + mask_indices = None + + x, layer_results = self.encoder( + x, padding_mask=padding_mask, layer=layer, corpus_key=corpus_key + ) + + if features_only: + return { + "x": x, + "padding_mask": padding_mask, + "features": unmasked_features, + "layer_results": layer_results, + } + + if self.quantizer: + if self.negatives_from_everywhere: + q = self.quantizer(unmasked_features, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + mask_indices[0].sum(), + padding_count=padding_count, + ) + y = y[mask_indices].view(y.size(0), -1, y.size(-1)) + + else: + q = self.quantizer(y, produce_targets=False) + y = q["x"] + num_vars = q["num_vars"] + code_ppl = q["code_perplexity"] + prob_ppl = q["prob_perplexity"] + curr_temp = q["temp"] + + y = self.project_q(y) + + negs, _ = self.sample_negatives( + y, + y.size(1), + padding_count=padding_count, + ) + + if self.codebook_negatives > 0: + cb_negs = self.quantizer.sample_from_codebook( + y.size(0) * y.size(1), self.codebook_negatives + ) + cb_negs = cb_negs.view( + self.codebook_negatives, y.size(0), y.size(1), -1 + ) # order doesnt matter + cb_negs = self.project_q(cb_negs) + negs = torch.cat([negs, cb_negs], dim=0) + else: + y = self.project_q(y) + + if self.negatives_from_everywhere: + negs, _ = self.sample_negatives( + unmasked_features, + y.size(1), + padding_count=padding_count, + ) + negs = self.project_q(negs) + else: + negs, _ = self.sample_negatives( + y, + y.size(1), + padding_count=padding_count, + ) + + if not is_xla_tensor(x): + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + + x = self.final_proj(x) + x = self.compute_preds(x, y, negs) + + result = { + "x": x, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + + if prob_ppl is not None: + result["prob_perplexity"] = prob_ppl + result["code_perplexity"] = code_ppl + result["num_vars"] = num_vars + result["temp"] = curr_temp + + return result + + def quantize(self, x): + assert self.quantizer is not None + x = self.feature_extractor(x) + x = x.transpose(1, 2) + x = self.layer_norm(x) + return self.quantizer.forward_idx(x) + + def extract_features( + self, source, padding_mask, mask=False, layer=None, corpus_key=None + ): + res = self.forward( + source, + padding_mask, + mask=mask, + features_only=True, + layer=layer, + corpus_key=corpus_key, + ) + return res + + def get_logits(self, net_output): + logits = net_output["x"] + logits = logits.transpose(0, 2) + logits = logits.reshape(-1, logits.size(-1)) + return logits + + def get_targets(self, sample, net_output, expand_steps=True): + x = net_output["x"] + return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) + + def get_extra_losses(self, net_output): + pen = [] + + if "prob_perplexity" in net_output: + pen.append( + (net_output["num_vars"] - net_output["prob_perplexity"]) + / net_output["num_vars"] + ) + + if "features_pen" in net_output: + pen.append(net_output["features_pen"]) + + return pen + + def remove_pretraining_modules(self, last_layer=None): + self.quantizer = None + self.project_q = None + self.target_glu = None + self.final_proj = None + + if last_layer is not None: + self.encoder.layers = nn.ModuleList( + l for i, l in enumerate(self.encoder.layers) if i <= last_layer + ) + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + + def forward(self, x): + + # BxT -> BxCxT + x = x.unsqueeze(1) + + for conv in self.conv_layers: + x = conv(x) + + return x + + +def make_conv_pos(e, k, g, is_batch_norm=False): + pos_conv = nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) + nn.init.normal_(pos_conv.weight, mean=0, std=std) + nn.init.constant_(pos_conv.bias, 0) + + if not is_batch_norm: + pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2) + pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) + else: + batch_norm = nn.BatchNorm1d(e) + pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU()) + + return pos_conv + + +class TransformerEncoder(nn.Module): + def build_encoder_layer(self, args: Wav2Vec2Config, **kwargs): + if args.layer_type == "transformer": + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + elif args.layer_type == "conformer": + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + use_fp16=args.fp16, + pos_enc_type="abs", + ) + elif args.layer_type == "trf_adp": + use_adp = False + if args.adp_trf_idx == "all": + use_adp = True + else: + adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")])) + if kwargs.get("layer_idx", None) in adp_trf_idx: + use_adp = True + if use_adp: + layer = TransformerSentenceEncoderWithAdapterLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + adapter_num=args.adp_num, + adapter_dim=args.adp_dim, + adapter_act_fn=args.adp_act_fn, + ) + else: + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + + layer = fsdp_wrap(layer) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + return layer + + def __init__(self, args: Wav2Vec2Config): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.required_seq_len_multiple = args.required_seq_len_multiple + + pos_conv_depth = getattr(args, "pos_conv_depth", 1) + if pos_conv_depth > 1: + num_layers = args.pos_conv_depth + k = max(3, args.conv_pos // num_layers) + + def make_conv_block(e, k, g, l): + return nn.Sequential( + *[ + nn.Sequential( + nn.Conv1d( + e, + e, + kernel_size=k, + padding=k // 2, + groups=g, + ), + SamePad(k), + TransposeLast(), + LayerNorm(e, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ) + for _ in range(l) + ] + ) + + self.pos_conv = make_conv_block( + self.embedding_dim, k, args.conv_pos_groups, num_layers + ) + + else: + self.pos_conv = make_conv_pos( + self.embedding_dim, + args.conv_pos, + args.conv_pos_groups, + is_batch_norm=args.conv_pos_batch_norm + if hasattr(args, "conv_pos_batch_norm") + else False, + ) + + self.layers = nn.ModuleList( + [self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, layer=None, corpus_key=None): + x, layer_results = self.extract_features( + x, padding_mask, layer, corpus_key=corpus_key + ) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features( + self, + x, + padding_mask=None, + tgt_layer=None, + min_layer=0, + corpus_key=None, + ): + + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + # pad to the sequence length dimension + x, pad_length = pad_to_multiple( + x, self.required_seq_len_multiple, dim=-2, value=0 + ) + if pad_length > 0 and padding_mask is None: + padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) + padding_mask[:, -pad_length:] = True + else: + padding_mask, _ = pad_to_multiple( + padding_mask, self.required_seq_len_multiple, dim=-1, value=True + ) + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + r = None + + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() if self.layerdrop > 0 else 1 + if not self.training or (dropout_probability > self.layerdrop): + layer_check = layer + if isinstance(layer, FullyShardedDataParallel): + layer_check = layer.unwrapped_module + if (corpus_key is None) or ( + not isinstance(layer_check, ( + TransformerSentenceEncoderWithAdapterLayer, + ) + ) + ): + x, (z, lr) = layer( + x, self_attn_padding_mask=padding_mask, need_weights=False + ) + else: + x, (z, lr) = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + corpus_key=corpus_key, + ) + if i >= min_layer: + layer_results.append((x, z, lr)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + # undo paddding + if pad_length > 0: + x = x[:, :-pad_length] + + def undo_pad(a, b, c): + return ( + a[:-pad_length], + b[:-pad_length] if b is not None else b, + c[:-pad_length], + ) + + layer_results = [undo_pad(*u) for u in layer_results] + + return x, layer_results + + def max_positions(self): + """Maximum output length supported by the encoder.""" + return self.args.max_positions + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + return state_dict + + +class ConformerEncoder(TransformerEncoder): + def build_encoder_layer(self, args): + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + pos_enc_type=args.pos_enc_type, + use_fp16=args.fp16, # only used for rope + ) + layer = fsdp_wrap(layer) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + return layer + + def __init__(self, args): + super().__init__(args) + self.args = args + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.pos_enc_type = args.pos_enc_type + max_source_positions = self.max_positions() + + if self.pos_enc_type == "rel_pos": + self.embed_positions = RelPositionalEncoding( + max_source_positions, self.embedding_dim + ) + elif self.pos_enc_type == "rope": + self.embed_positions = None + else: + raise Exception("Unsupported positional encoding type") + + self.layers = nn.ModuleList( + [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def extract_features(self, x, padding_mask=None, tgt_layer=None): + if padding_mask is not None: + x = index_put(x, padding_mask, 0) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + # B X T X C here + position_emb = None + if self.pos_enc_type == "rel_pos": + position_emb = self.embed_positions(x) + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + layer_results = [] + r = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + position_emb=position_emb, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, (attn, layer_result) + + +class AdapterFast(nn.Module): + def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): + """ + Implements adapter modules directly with 3D tensor weight as parameters + and without using ModuleList orto speed up training throughput. + """ + super().__init__() + + self.adapter_num = adapter_num + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim)) + self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim)) + self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim)) + self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + + self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) + self.act_fn = nn.Identity() + if act_fn == "relu": + self.act_fn = nn.ReLU() + elif act_fn == "gelu": + self.act_fn = nn.GELU() + elif act_fn == "selu": + self.act_fn = nn.SELU() + else: + raise ValueError(f"unsupported {act_fn}") + + + self.input_dim = input_dim + self.reset_parameters() + + def reset_parameters(self): + for ii in range(self.adapter_num): + nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_a[ii], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii]) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.b_b[ii], -bound, bound) + + nn.init.ones_(self.ln_W) + nn.init.zeros_(self.ln_b) + + def forward(self, x, adapter_id): + ii = adapter_id + h = x + h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]) + h = F.linear(h, self.W_a[ii], self.b_a[ii]) + h = self.act_fn(h) + h = F.linear(h, self.W_b[ii], self.b_b[ii]) + outputs = h + return outputs + + def extra_repr(self): + return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim)) + + + +class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer): + """ + Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained + models. An adapter module is added along with vanilla Transformer module. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + adapter_num=201, + adapter_dim=64, + adapter_act_fn="relu", + ) -> None: + + super().__init__( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + layer_norm_first=layer_norm_first, + + ) + + self.adapter_num = adapter_num + self.adapter_dim = adapter_dim + self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + corpus_key=None, + ): + + x, (attn, layer_result) = super().forward( + x=x, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + att_args=att_args, + ) + assert corpus_key is not None + assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}" + y = self.adapter_layer(x, corpus_key[0]) + x = x + y + return x, (attn, layer_result) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..0403efebb9b4b48c4d8f518e127e7663395329e8 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -0,0 +1,878 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +import logging +import math +import re +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import II, MISSING, open_dict + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, +) +from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, LAYER_TYPE_CHOICES, AdapterFast +from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer +from fairseq.tasks import FairseqTask + +logger = logging.getLogger(__name__) + + +@dataclass +class Wav2Vec2AsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to wav2vec 2.0 model"} + ) + no_pretrained_weights: bool = field( + default=False, metadata={"help": "if true, does not load pretrained weights"} + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, + ) + dropout: float = field( + default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"} + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside wav2vec 2.0 model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask (normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + require_same_masks: bool = field( + default=True, + metadata={ + "help": "whether to number of masked timesteps must be the same across all " + "examples in a batch" + }, + ) + mask_dropout: float = field( + default=0.0, + metadata={"help": "percent of masks to unmask for each sample"}, + ) + + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} + ) + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} + ) + freeze_finetune_updates: int = field( + default=0, metadata={"help": "dont finetune wav2vec for this many updates"} + ) + feature_grad_mult: float = field( + default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"} + ) + layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} + ) + drop_path: float = 0 + mask_channel_min_space: Optional[int] = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + mask_channel_before: bool = False + normalize: bool = II("task.normalize") + update_alibi: bool = True + data: str = II("task.data") + # this holds the loaded wav2vec args + w2v_args: Any = None + offload_activations: bool = field( + default=False, metadata={"help": "offload_activations"} + ) + min_params_to_wrap: int = field( + default=int(1e8), + metadata={ + "help": "minimum number of params for a layer to be wrapped with FSDP() when " + "training with --ddp-backend=fully_sharded. Smaller values will " + "improve memory efficiency, but may make torch.distributed " + "communication less efficient due to smaller input sizes. This option " + "is set to 0 (i.e., always wrap) when --checkpoint-activations or " + "--offload-activations are passed." + }, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + ddp_backend: str = II("distributed_training.ddp_backend") + + zero_mask: bool = False + load_ema: bool = False + + layer_decay: float = 1 + + + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + # Adapter num + adp_num: int = field( + default=-1 + ) + adp_dim: int = field( + default=64 + ) + adp_act_fn: str = field( + default="relu" + ) + adp_trf_idx: str = field( + default="all", + ) + + freeze_regex: Optional[str] = field( + default=None, + ) + +@dataclass +class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): + blank_weight: float = 0 + blank_mode: str = "add" + + +@register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) +class Wav2VecCtc(BaseFairseqModel): + def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + self.blank_weight = cfg.blank_weight + self.blank_mode = cfg.blank_mode + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = Wav2VecEncoder(cfg, len(task.target_dictionary)) + return cls(cfg, w2v_encoder) + + def get_logits(self, net_output, normalize=False): + logits = net_output["encoder_out"] + if self.blank_weight != 0: + if self.blank_mode == "add": + logits[..., 0] += self.blank_weight + elif self.blank_mode == "set": + logits[..., 0] = self.blank_weight + else: + raise Exception(f"invalid blank mode {self.blank_mode}") + + if net_output["padding_mask"] is not None and net_output["padding_mask"].any(): + number_of_classes = logits.size(-1) + masking_tensor = torch.ones( + number_of_classes, device=logits.device + ) * float("-inf") + masking_tensor[0] = 0 + + if logits.size(0) > net_output["padding_mask"].size(1): + net_output["padding_mask"] = F.pad( + net_output["padding_mask"], (1, 0), value=False + ) + + logits[net_output["padding_mask"].T] = masking_tensor.type_as(logits) + + if normalize: + logits = utils.log_softmax(logits.float(), dim=-1) + + return logits + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = self.get_logits(net_output) + + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + +@dataclass +class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) + autoregressive: bool = II("task.autoregressive") + + +@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) +class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): + """Build a new model instance.""" + + assert ( + cfg.autoregressive + ), "Please set task.autoregressive=true for seq2seq asr models" + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + + encoder = cls.build_encoder(cfg) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) + + return Wav2Vec2Seq2SeqModel(encoder, decoder) + + @classmethod + def build_encoder(cls, cfg: Wav2Vec2AsrConfig): + return Wav2VecEncoder(cfg) + + @classmethod + def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens): + return TransformerDecoder(cfg, tgt_dict, embed_tokens) + + def forward(self, **kwargs): + encoder_out = self.encoder(**kwargs) + decoder_out = self.decoder(encoder_out=encoder_out, **kwargs) + return decoder_out + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + +class Wav2VecEncoder(FairseqEncoder): + def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "require_same_masks": getattr(cfg, "require_same_masks", True), + "pct_holes": getattr(cfg, "mask_dropout", 0), + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_before": cfg.mask_channel_before, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + "checkpoint_activations": cfg.checkpoint_activations, + "offload_activations": cfg.offload_activations, + "min_params_to_wrap": cfg.min_params_to_wrap, + # d2v multi args + "encoder_dropout": cfg.dropout, + "drop_path": getattr(cfg, "drop_path", 0), + "mask_dropout": getattr(cfg, "mask_dropout", 0), + "zero_mask": getattr(cfg, "zero_mask", False), + "local_grad_mult": cfg.feature_grad_mult, + "layerdrop": cfg.layerdrop, + "prenet_layerdrop": cfg.layerdrop, + "prenet_dropout": cfg.dropout, + "post_mlp_drop": cfg.dropout, + "encoder_zero_mask": getattr(cfg, "zero_mask", False), + "inverse_mask": False, + "learned_alibi_scale": getattr(cfg, "update_alibi", True), + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + w2v_args.criterion = None + w2v_args.lr_scheduler = None + + cfg.w2v_args = w2v_args + + logger.info(w2v_args) + + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) + + self.is_d2v_multi = "data2vec_multi" in w2v_args.model.get("_name", None) + + if not self.is_d2v_multi: + model_normalized = w2v_args.task.get( + "normalize", w2v_args.model.get("normalize", False) + ) + assert cfg.normalize == model_normalized, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for both pre-training and here" + ) + + with open_dict(w2v_args): + args_replacement = ["checkpoint_activations", "layer_type", + "adp_num", "adp_dim", + "adp_act_fn", "adp_trf_idx"] + for _args in args_replacement: + if hasattr(cfg, _args) and getattr(cfg, _args, None) is not None: + w2v_args.model[_args] = getattr(cfg, _args, None) + + if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations: + with open_dict(w2v_args): + w2v_args.model.checkpoint_activations = cfg.checkpoint_activations + + w2v_args.task.data = cfg.data + task = tasks.setup_task(w2v_args.task, from_checkpoint=True) + model = task.build_model(w2v_args.model, from_checkpoint=True) + model.remove_pretraining_modules() + d = w2v_args.model.encoder_embed_dim + else: + assert cfg.normalize + + if hasattr(w2v_args.task, "audio"): + w2v_args.task.audio.data = cfg.data + else: + w2v_args.task.data = cfg.data + task = tasks.setup_task(w2v_args.task, from_checkpoint=True) + + model = task.build_model(w2v_args.model, from_checkpoint=True) + + model.remove_pretraining_modules(modality="audio") + d = w2v_args.model.embed_dim + + if state is not None and not cfg.no_pretrained_weights: + if cfg.load_ema: + assert "_ema" in state["model"] + for k in state["model"]["_ema"]: + mk = "encoder." + k + assert mk in state["model"], mk + state["model"][mk] = state["model"]["_ema"][k] + self.load_model_weights(state, model, cfg) + + super().__init__(task.source_dictionary) + + self.w2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + targ_d = None + self.proj = None + + if output_size is not None: + targ_d = output_size + elif getattr(cfg, "decoder_embed_dim", d) != d: + targ_d = cfg.decoder_embed_dim + + if targ_d is not None: + self.proj = Linear(d, targ_d) + + if cfg.freeze_regex is not None: + self.freeze_regex(cfg.freeze_regex) + + layer_decay = getattr(cfg, "layer_decay", 1) + if layer_decay < 1: + mod_encs = list(model.modality_encoders.values()) + assert len(mod_encs) == 1, len(mod_encs) + blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks) + num_layers = len(blocks) + 1 + layer_scales = list( + layer_decay ** (num_layers - i) for i in range(num_layers + 1) + ) + + for i, b in enumerate(blocks): + lid = i + 1 + if layer_scales[lid] == 1.0: + continue + + for n, p in b.named_parameters(): + optim_override = getattr(p, "optim_overrides", {}) + if "optimizer" not in optim_override: + optim_override["optimizer"] = {} + + optim_override["optimizer"]["lr_scale"] = layer_scales[lid] + p.optim_overrides = optim_override + + def freeze_regex(self, pattern): + unfrozen_names = [] + for name, param in self.named_parameters(): + if re.fullmatch(pattern, name) is not None: + param.requires_grad_(False) + else: + unfrozen_names.append(name) + + def load_model_weights(self, state, model, cfg): + if cfg.ddp_backend == "fully_sharded": + from fairseq.distributed import FullyShardedDataParallel + + for name, module in model.named_modules(): + if "encoder.layers" in name and len(name.split(".")) == 3: + # Only for layers, we do a special handling and load the weights one by one + # We dont load all weights together as that wont be memory efficient and may + # cause oom + new_dict = { + k.replace(name + ".", ""): v + for (k, v) in state["model"].items() + if name + "." in k + } + assert isinstance(module, FullyShardedDataParallel) + with module.summon_full_params(): + module.load_state_dict(new_dict, strict=True) + module._reset_lazy_init() + + # Once layers are loaded, filter them out and load everything else. + r = re.compile("encoder.layers.\d.") + filtered_list = list(filter(r.match, state["model"].keys())) + + new_big_dict = { + k: v for (k, v) in state["model"].items() if k not in filtered_list + } + + model.load_state_dict(new_big_dict, strict=False) + else: + to_delete = {"_ema", "target_proj", "decoder"} + for k in to_delete: + if k in state["model"]: + del state["model"][k] + + if hasattr(model, "modality_encoders"): + if "modality_encoders.AUDIO.encoder_mask" not in state["model"]: + model.modality_encoders["AUDIO"].encoder_mask = None + elif not cfg.zero_mask: + model.modality_encoders["AUDIO"].encoder_mask = None + del state["model"]["modality_encoders.AUDIO.encoder_mask"] + + for k in list(state["model"].keys()): + if k.startswith("modality_encoders.") and not k.startswith( + "modality_encoders.AUDIO" + ): + del state["model"][k] + + print(model) + model.load_state_dict(state["model"], strict=True) + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + if "corpus_key" in kwargs: + w2v_args["corpus_key"] = kwargs["corpus_key"] + + if self.is_d2v_multi: + w2v_args["mode"] = "AUDIO" + + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + res = self.w2v_model.extract_features(**w2v_args) + + x = res["x"] + padding_mask = res["padding_mask"] + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "padding_mask": padding_mask, # B x T, + "layer_results": res["layer_results"], + } + + def forward_torchscript(self, net_input): + if torch.jit.is_scripting(): + return self.forward(net_input["source"], net_input["padding_mask"]) + else: + return self.forward_non_torchscript(net_input) + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select( + 0, new_order + ) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg: Wav2Vec2Seq2SeqConfig, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): + super().__init__(dictionary) + + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim + + self.layerdrop = cfg.decoder_layerdrop + + self.padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_target_positions, + embed_dim, + self.padding_idx, + learned=cfg.decoder_learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5) + + if transformer_cfg.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + if type(prev_output_tokens) == list: + max_len = max((len(x) for x in prev_output_tokens)) + tmp = torch.zeros( + [len(prev_output_tokens), max_len], device=prev_output_tokens[0].device + ) + for (i, p) in enumerate(prev_output_tokens): + tmp[i, : len(p)] = p + prev_output_tokens = tmp + + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + self_attn_padding_mask = None + if prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["padding_mask"] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + self_attn_padding_mask=self_attn_padding_mask, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + if self.share_input_output_embed: + return F.linear(features, self.embed_tokens.weight) + else: + return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m diff --git a/fairseq/models/wav2vec/wav2vec2_classification.py b/fairseq/models/wav2vec/wav2vec2_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bbaab28e16e2a8eab153b16d7f003979934d03 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_classification.py @@ -0,0 +1,348 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import II, MISSING, open_dict + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model +from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config +from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig +from fairseq.tasks import FairseqTask + +logging.basicConfig(level=logging.DEBUG) + + +@dataclass +class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig): + latent_embed_dim: Optional[int] = field( + default=None, metadata={"help": "latent dim (encoder w2v -> latent -> class"} + ) + pooling: str = field( + default="first_token", + metadata={"help": "pooling layer choices"}, + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + + +@register_model("wav2vec_classification", dataclass=Wav2Vec2ClassificationConfig) +class Wav2VecClassification(BaseFairseqModel): + # TODO: Can be shared/merged with ASR model class as w2v_encoder params are common. + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + w2v_encoder: BaseFairseqModel, + pooling_layer, + ): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + self.pooling_layer = pooling_layer + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = Wav2VecEncoder(cfg, None) + pooling_layer = get_pooling_layer( + cfg, + w2v_encoder.w2v_model.encoder.layers[-1].embedding_dim, + len(task.target_dictionary), + len(w2v_encoder.w2v_model.encoder.layers), + ) + return cls(cfg, w2v_encoder, pooling_layer) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + logits = net_output + + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + return net_output + + def forward(self, **kwargs): + encoder_out_dict = self.w2v_encoder(**kwargs) + w2v_encoder_out = encoder_out_dict["encoder_out"] # TxBxC + w2v_encoder_padding_mask = encoder_out_dict["padding_mask"] # BxT + # w2v_encoder_layer_results = encoder_out_dict["layer_results"] + return self.pooling_layer( + last_layer_feats=w2v_encoder_out, + padding_mask=w2v_encoder_padding_mask, + # all_layer_feats=w2v_encoder_layer_results, + ) + + # def forward_latent(self, **kwargs): + # encoder_out_dict = self.w2v_encoder(**kwargs) + # w2v_encoder_out = encoder_out_dict["encoder_out"] + # w2v_encoder_padding_mask = encoder_out_dict["encoder_padding_mask"] + # w2v_encoder_layer_results = encoder_out_dict["layer_results"] + # return self.pooling_layer.forward_latent( + # last_layer_feats=w2v_encoder_out, + # padding_mask=w2v_encoder_padding_mask, + # all_layer_feats=w2v_encoder_layer_results, + # ) + + +def get_pooling_layer( + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + encoder_layers: int, +): + assert cfg.pooling == 'mean' + if cfg.pooling == "first_token": + return FirstToken(cfg, encoder_embed_dim, num_targets) + # elif cfg.pooling == "mean": + # return MeanPooling(cfg, encoder_embed_dim, num_targets) + elif cfg.pooling == "mean": + return MeanPoolingFast(cfg, encoder_embed_dim, num_targets) + elif cfg.pooling == "mean_amsoftmax": + return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets) + elif cfg.pooling == "max": + return MaxPoolingFast(cfg, encoder_embed_dim, num_targets) + elif cfg.pooling == "elmo": + return LayerWeightedMeanPooling( + cfg, encoder_embed_dim, num_targets, encoder_layers + ) + else: + raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.") + + +class Pooling(nn.Module): + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + ): + super().__init__() + self.projection = Linear(encoder_embed_dim, num_targets) + + def forward(self, last_layer_feats, **kwargs): + raise NotImplementedError() + + +class FirstToken(Pooling): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, last_layer_feats, **kwargs): + return self.projection(last_layer_feats[:, 0]) + + +# class MeanPooling(Pooling): +# def __init__( +# self, +# cfg: Wav2VecClassificationConfig, +# encoder_embed_dim: int, +# num_targets: int, +# **kwargs, +# ): +# super().__init__(cfg, encoder_embed_dim, num_targets) +# self.activation_fn = utils.get_activation_fn(cfg.activation_fn) +# self.linear = Linear(encoder_embed_dim, encoder_embed_dim) + +# def forward(self, last_layer_feats, padding_mask, **kwargs): +# # last_layer_feats: [BxTxD] +# # padding_mask: [BxT] +# last_layer_feats = self.linear(self.activation_fn(last_layer_feats)) +# input_lengths = (1 - padding_mask.long()).sum(-1) +# pooled_feature_list = [] +# for i in range(len(last_layer_feats)): +# length = input_lengths[i] +# pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0) +# pooled_feature_list.append(pooled_feature) +# return self.projection(torch.stack(pooled_feature_list)) + + +def fn_mean(x, mask): + """ + Args: + x: TxBxD + mask: BxT + Return: + y: BxD + """ + if mask is not None: + mask = mask.t()[:, :, None] + return (x * mask).sum(0) / mask.sum(0) + else: + return x.sum(0) / x.shape[0] + + +class MeanPoolingFast(nn.Module): + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + **kwargs, + ): + super().__init__() + self.activation_fn = utils.get_activation_fn(cfg.activation_fn) + self.latent_embed_dim = ( + cfg.latent_embed_dim + if cfg.latent_embed_dim is not None + else encoder_embed_dim + ) + logging.debug(f"| {self.latent_embed_dim=}") + self.linear = Linear(encoder_embed_dim, self.latent_embed_dim) + self.projection = Linear(self.latent_embed_dim, num_targets) + + def forward(self, last_layer_feats, padding_mask, **kwargs): + """ + Arguments + features - [TxBxD] Acoustic feature with shape + padding_mask - [BxT] Padding Mask + """ + if padding_mask is not None: + feat_mask = (~padding_mask).to(last_layer_feats.dtype) + else: + feat_mask = None + feat = self.linear(last_layer_feats) + feat = fn_mean(feat, feat_mask) + feat = self.activation_fn(feat) + return self.projection(feat) + + def forward_latent(self, last_layer_feats, padding_mask, **kwargs): + """ + Arguments + features - [TxBxD] Acoustic feature with shape + padding_mask - [BxT] Padding Mask + """ + if padding_mask is not None: + feat_mask = (~padding_mask).to(last_layer_feats.dtype) + else: + feat_mask = None + feat = self.linear(last_layer_feats) + feat = fn_mean(feat, feat_mask) + return feat + + +class MeanPoolingFastAMSoftmax(MeanPoolingFast): + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + **kwargs, + ): + super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs) + self.projection = Linear(self.latent_embed_dim, num_targets, bias=False) + nn.init.xavier_normal_(self.projection.weight, gain=1) + + def forward(self, last_layer_feats, padding_mask, **kwargs): + + """ + Arguments + features - [BxTxD] Acoustic feature with shape + padding_mask - [BxT] Padding Mask + """ + feat_mask = (~padding_mask).to(last_layer_feats.dtype) # T,B -> B,T + feat = self.linear(last_layer_feats) # B,T,D + feat = fn_mean(feat, feat_mask) # B,D + feat = self.activation_fn(feat) + # normalize feat + feat_norm = F.normalize(feat, p=2, dim=-1) # B,D + weight_norm = F.normalize(self.projection.weight.t(), p=2, dim=-1) # D,K + cos_fw = feat_norm @ weight_norm + return cos_fw + + +def fn_max(x, mask): + """ + Args: + x: TxBxD + mask: BxT + Return: + y: BxD + """ + mask = mask.t()[:, :, None].to(torch.bool) + return x.masked_fill(~mask, -1e-8).max(0)[0] + + +class MaxPoolingFast(Pooling): + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + **kwargs, + ): + super().__init__(cfg, encoder_embed_dim, num_targets) + self.activation_fn = utils.get_activation_fn(cfg.activation_fn) + self.linear = Linear(encoder_embed_dim, encoder_embed_dim) + + def forward(self, last_layer_feats, padding_mask, **kwargs): + + """ + Arguments + features - [TxBxD] Acoustic feature with shape + padding_mask - [BxT] Padding Mask + """ + feat_mask = (~padding_mask).to(last_layer_feats.dtype) + feat = self.linear(last_layer_feats) + feat = fn_max(feat, feat_mask) + feat = self.activation_fn(feat) + return self.projection(feat) + + +class LayerWeightedMeanPooling(MeanPoolingFast): + """Elmo-style weighted average representation.""" + + def __init__( + self, + cfg: Wav2Vec2ClassificationConfig, + encoder_embed_dim: int, + num_targets: int, + encoder_layers: int, + ): + super().__init__(cfg, encoder_embed_dim, num_targets) + self.num_layers = encoder_layers + self.weights = nn.Parameter(torch.ones(encoder_layers)) + + def forward(self, last_layer_feats, padding_mask, all_layer_feats): + # last_layer_feats: [BxTxD] + # padding_mask: [BxT] + if not self.training: + msg = ( + f"Number of layers in input features = {len(all_layer_feats)}." + f" Expected {self.num_layers} layers." + ) + assert len(all_layer_feats) == self.num_layers, msg + + # Stack up all layers and reshape to (num_layers, features) + all_layer_feats_stacked = torch.stack(all_layer_feats, dim=0) + num_layers, *original_feat_shape = all_layer_feats_stacked.shape + all_layer_feats_stacked_flat = all_layer_feats_stacked.view(num_layers, -1) + + # Weighted average + normalized_weights = F.softmax(self.weights, dim=-1) + weighted_avg_features = ( + normalized_weights.unsqueeze(-1) * all_layer_feats_stacked_flat + ).sum(dim=0) + weighted_avg_features = weighted_avg_features.view(*original_feat_shape) + + # Mean Pooling on weighted average features. + return super().forward(weighted_avg_features, padding_mask) \ No newline at end of file diff --git a/fairseq/models/wav2vec/wav2vec2_laser.py b/fairseq/models/wav2vec/wav2vec2_laser.py new file mode 100644 index 0000000000000000000000000000000000000000..ff89759d38cc1837270a8571d279f11cf1edfa75 --- /dev/null +++ b/fairseq/models/wav2vec/wav2vec2_laser.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2_asr import ( + Wav2Vec2CtcConfig, + Wav2VecCtc, + Wav2VecEncoder, +) +from fairseq.tasks import FairseqTask + + +@register_model("wav2vec2_laser", dataclass=Wav2Vec2CtcConfig) +class Wav2VecLaser(Wav2VecCtc): + def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__(cfg, w2v_encoder) + self.num_updates = 0 + self.freeze_finetune_updates = cfg.freeze_finetune_updates + + @classmethod + def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = Wav2VecEncoder(cfg, 1024) + return cls(cfg, w2v_encoder) + + def forward(self, **kwargs): + output = super().forward(**kwargs) + x_out = output["encoder_out"] * 0.01 + out_pad_mask = output["padding_mask"] + # Set padded outputs to -inf so they are not selected by max-pooling + if out_pad_mask is not None and out_pad_mask.any(): + x_out = ( + x_out.float() + .masked_fill_(out_pad_mask.T.unsqueeze(-1), float("-inf")) + .type_as(x_out) + ) + return x_out.max(dim=0)[0] diff --git a/fairseq/models/xmod/__init__.py b/fairseq/models/xmod/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf7694920eb00bed27e17dac272611be1ab44f9 --- /dev/null +++ b/fairseq/models/xmod/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .model import * # noqa +from .transformer_layer_xmod import * # noqa diff --git a/fairseq/models/xmod/__pycache__/__init__.cpython-310.pyc b/fairseq/models/xmod/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83687d5b1825e70af3123f79ad76ec6c20e7f142 Binary files /dev/null and b/fairseq/models/xmod/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/models/xmod/__pycache__/hub_interface.cpython-310.pyc b/fairseq/models/xmod/__pycache__/hub_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..948a623e9bbb978f1e3b0a57f8e0902cabb480f3 Binary files /dev/null and b/fairseq/models/xmod/__pycache__/hub_interface.cpython-310.pyc differ diff --git a/fairseq/models/xmod/__pycache__/model.cpython-310.pyc b/fairseq/models/xmod/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5d17e73a55689e8f7e63d539f05f69431453dc6 Binary files /dev/null and b/fairseq/models/xmod/__pycache__/model.cpython-310.pyc differ diff --git a/fairseq/models/xmod/__pycache__/transformer_layer_xmod.cpython-310.pyc b/fairseq/models/xmod/__pycache__/transformer_layer_xmod.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8725b447dd33a96df4fa47df755709b909df717 Binary files /dev/null and b/fairseq/models/xmod/__pycache__/transformer_layer_xmod.cpython-310.pyc differ diff --git a/fairseq/models/xmod/hub_interface.py b/fairseq/models/xmod/hub_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..909bb423cac449b7f48f11e47126730aeadee919 --- /dev/null +++ b/fairseq/models/xmod/hub_interface.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from fairseq.models.roberta.hub_interface import RobertaHubInterface +import torch +import torch.nn.functional as F + + +class XMODHubInterface(RobertaHubInterface): + def extract_features( + self, + tokens: torch.LongTensor, + return_all_hiddens: bool = False, + lang_id=None, + ) -> torch.Tensor: + if tokens.dim() == 1: + tokens = tokens.unsqueeze(0) + if tokens.size(-1) > self.model.max_positions(): + raise ValueError( + "tokens exceeds maximum length: {} > {}".format( + tokens.size(-1), self.model.max_positions() + ) + ) + features, extra = self.model( + tokens.to(device=self.device), + features_only=True, + return_all_hiddens=return_all_hiddens, + lang_id=lang_id, + ) + if return_all_hiddens: + # convert from T x B x C -> B x T x C + inner_states = extra["inner_states"] + return [inner_state.transpose(0, 1) for inner_state in inner_states] + else: + return features # just the last layer's features + + def predict( + self, + head: str, + tokens: torch.LongTensor, + return_logits: bool = False, + lang_id=None, + ): + features = self.extract_features(tokens.to(device=self.device), lang_id=lang_id) + logits = self.model.classification_heads[head](features) + if return_logits: + return logits + return F.log_softmax(logits, dim=-1) diff --git a/fairseq/models/xmod/model.py b/fairseq/models/xmod/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6c7a8deac1656855e6cdca9d4cf156942e10b2 --- /dev/null +++ b/fairseq/models/xmod/model.py @@ -0,0 +1,742 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ..roberta.model_xlmr import XLMRModel +from fairseq.models.xmod.transformer_layer_xmod import XMODTransformerEncoderLayerBase +from ..roberta.model import base_architecture, RobertaEncoder +from fairseq.models.transformer import TransformerEncoder +from fairseq.modules.transformer_sentence_encoder import init_bert_params +from typing import Optional +from fairseq.models.xmod.hub_interface import XMODHubInterface +import torch +from fairseq.distributed import fsdp_wrap +from fairseq.models import ( + register_model, + register_model_architecture, +) + +from fairseq.modules.checkpoint_activations import checkpoint_wrapper + +DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8) + + +@register_model("xmod") +class XMODModel(XLMRModel): + @classmethod + def hub_models(cls): + return { + "xmod.base": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.81.1M.tar.gz", + "xmod.large.prenorm": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.large.prenorm.81.500k.tar.gz", + "xmod.base.13.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.13.125k.tar.gz", + "xmod.base.30.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.125k.tar.gz", + "xmod.base.30.195k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.195k.tar.gz", + "xmod.base.60.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.125k.tar.gz", + "xmod.base.60.265k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.265k.tar.gz", + "xmod.base.75.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.125k.tar.gz", + "xmod.base.75.269k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.269k.tar.gz", + } + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="sentencepiece", + **kwargs, + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + return XMODHubInterface(x["args"], x["task"], x["models"][0]) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + from omegaconf import OmegaConf + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, False) + + # make sure all arguments are present + base_architecture(args) + + if not hasattr(args, "max_positions"): + if not hasattr(args, "tokens_per_sample"): + args.tokens_per_sample = task.max_positions() + args.max_positions = args.tokens_per_sample + + encoder = XMODEncoder(args, task.source_dictionary) + + if OmegaConf.is_config(args): + OmegaConf.set_struct(args, True) + + return cls(args, encoder) + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + classification_head_name=None, + lang_id=None, + **kwargs, + ): + if classification_head_name is not None: + features_only = True + x, extra = self.encoder( + src_tokens, features_only, return_all_hiddens, lang_id=lang_id, **kwargs + ) + + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + return x, extra + + +class XMODEncoder(RobertaEncoder): + """XMOD encoder.""" + + def build_encoder(self, args, dictionary, embed_tokens): + encoder = XMODTransformerEncoder(args, dictionary, embed_tokens) + encoder.apply(init_bert_params) + return encoder + + def forward( + self, + src_tokens, + features_only=False, + return_all_hiddens=False, + masked_tokens=None, + lang_id=None, + **unused, + ): + """ + Args: + src_tokens (LongTensor): input tokens of shape `(batch, src_len)` + features_only (bool, optional): skip LM head and just return + features. If True, the output will be of shape + `(batch, src_len, embed_dim)`. + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + + Returns: + tuple: + - the LM output of shape `(batch, src_len, vocab)` + - a dictionary of additional data, where 'inner_states' + is a list of hidden states. Note that the hidden + states have shape `(src_len, batch, vocab)`. + """ + x, extra = self.extract_features( + src_tokens, return_all_hiddens=return_all_hiddens, lang_id=lang_id + ) + if not features_only: + x = self.output_layer(x, masked_tokens=masked_tokens) + return x, extra + + def extract_features( + self, src_tokens, return_all_hiddens=False, lang_id=None, **kwargs + ): + encoder_out = self.sentence_encoder( + src_tokens, + return_all_hiddens=return_all_hiddens, + lang_id=lang_id, + token_embeddings=kwargs.get("token_embeddings", None), + ) + # T x B x C -> B x T x C + features = encoder_out["encoder_out"][0].transpose(0, 1) + inner_states = encoder_out["encoder_states"] if return_all_hiddens else None + return features, {"inner_states": inner_states} + + +class XMODTransformerEncoder(TransformerEncoder): + def build_encoder_layer(self, cfg): + layer = XMODTransformerEncoderLayerBase(cfg) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + lang_id=None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + return self.forward_scriptable( + src_tokens, + src_lengths, + return_all_hiddens, + token_embeddings, + lang_id=lang_id, + ) + # TorchScript doesn't support super() method so that the scriptable Subclass + # can't access the base class model in Torchscript. + # Current workaround is to add a helper function with different name and + # call the helper function from scriptable Subclass. + + def forward_scriptable( + self, + src_tokens, + src_lengths: Optional[torch.Tensor] = None, + return_all_hiddens: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + lang_id=None, + ): + """ + Args: + src_tokens (LongTensor): tokens in the source language of shape + `(batch, src_len)` + src_lengths (torch.LongTensor): lengths of each source sentence of + shape `(batch)` + return_all_hiddens (bool, optional): also return all of the + intermediate hidden states (default: False). + token_embeddings (torch.Tensor, optional): precomputed embeddings + default `None` will recompute embeddings + + Returns: + dict: + - **encoder_out** (Tensor): the last encoder layer's output of + shape `(src_len, batch, embed_dim)` + - **encoder_padding_mask** (ByteTensor): the positions of + padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` + - **encoder_states** (List[Tensor]): all intermediate + hidden states of shape `(src_len, batch, embed_dim)`. + Only populated if *return_all_hiddens* is True. + """ + # compute padding mask + encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() + + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + + # account for padding while computing the representation + if has_pads: + x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + encoder_states = [] + + if return_all_hiddens: + encoder_states.append(x) + + # encoder layers + for layer in self.layers: + x = layer( + x, + encoder_padding_mask=encoder_padding_mask if has_pads else None, + lang_id=lang_id, + ) + if return_all_hiddens: + assert encoder_states is not None + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `forward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + src_lengths = ( + src_tokens.ne(self.padding_idx) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } + + +@register_model_architecture("xmod", "xmod_base_13") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True) + args.ln_before_adapter = getattr(args, "ln_before_adapter", True) + args.languages = getattr( + args, + "languages", + [ + "ar_AR", + "en_XX", + "fi_FI", + "fr_XX", + "hi_IN", + "id_ID", + "ka_GE", + "ko_KR", + "ru_RU", + "sw_KE", + "ta_IN", + "th_TH", + "vi_VN", + ], + ) + base_architecture(args) + + +@register_model_architecture("xmod", "xmod_base_30") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True) + args.ln_before_adapter = getattr(args, "ln_before_adapter", True) + args.languages = getattr( + args, + "languages", + [ + "ar_AR", + "cs_CZ", + "en_XX", + "eu_ES", + "fi_FI", + "fr_XX", + "hi_IN", + "hr_HR", + "hu_HU", + "hy_AM", + "id_ID", + "it_IT", + "ka_GE", + "ko_KR", + "lt_LT", + "ml_IN", + "mn_MN", + "ms_MY", + "pl_PL", + "ro_RO", + "ru_RU", + "si_LK", + "sk_SK", + "sq_AL", + "sv_SE", + "sw_KE", + "ta_IN", + "th_TH", + "tl_XX", + "vi_VN", + ], + ) + base_architecture(args) + + +@register_model_architecture("xmod", "xmod_base_60") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True) + args.ln_before_adapter = getattr(args, "ln_before_adapter", True) + args.languages = getattr( + args, + "languages", + [ + "af_ZA", + "am_ET", + "ar_AR", + "be_BY", + "bn_IN", + "ca_ES", + "cs_CZ", + "cy_GB", + "da_DK", + "en_XX", + "eo_EO", + "et_EE", + "eu_ES", + "fa_IR", + "fi_FI", + "fr_XX", + "ga_IE", + "gl_ES", + "gu_IN", + "ha_NG", + "hi_IN", + "hr_HR", + "hu_HU", + "hy_AM", + "id_ID", + "is_IS", + "it_IT", + "ka_GE", + "ko_KR", + "ku_TR", + "la_VA", + "lt_LT", + "lv_LV", + "mk_MK", + "ml_IN", + "mn_MN", + "ms_MY", + "ne_NP", + "nl_XX", + "no_XX", + "pl_PL", + "ps_AF", + "pt_XX", + "ro_RO", + "ru_RU", + "sa_IN", + "sd_PK", + "si_LK", + "sk_SK", + "sl_SI", + "so_SO", + "sq_AL", + "sr_RS", + "sv_SE", + "sw_KE", + "ta_IN", + "te_IN", + "th_TH", + "tl_XX", + "vi_VN", + ], + ) + base_architecture(args) + + +@register_model_architecture("xmod", "xmod_base_75") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True) + args.ln_before_adapter = getattr(args, "ln_before_adapter", True) + args.languages = getattr( + args, + "languages", + [ + "af_ZA", + "am_ET", + "ar_AR", + "as_IN", + "be_BY", + "bn_IN", + "br_FR", + "bs_BA", + "ca_ES", + "cs_CZ", + "cy_GB", + "da_DK", + "en_XX", + "eo_EO", + "et_EE", + "eu_ES", + "fa_IR", + "fi_FI", + "fr_XX", + "fy_NL", + "ga_IE", + "gd_GB", + "gl_ES", + "gu_IN", + "ha_NG", + "hi_IN", + "hr_HR", + "hu_HU", + "hy_AM", + "id_ID", + "is_IS", + "it_IT", + "jv_ID", + "ka_GE", + "kn_IN", + "ko_KR", + "ku_TR", + "la_VA", + "lt_LT", + "lv_LV", + "mg_MG", + "mk_MK", + "ml_IN", + "mn_MN", + "mr_IN", + "ms_MY", + "ne_NP", + "nl_XX", + "no_XX", + "om_KE", + "or_IN", + "pa_IN", + "pl_PL", + "ps_AF", + "pt_XX", + "ro_RO", + "ru_RU", + "sa_IN", + "sd_PK", + "si_LK", + "sk_SK", + "sl_SI", + "so_SO", + "sq_AL", + "sr_RS", + "su_ID", + "sv_SE", + "sw_KE", + "ta_IN", + "te_IN", + "th_TH", + "tl_XX", + "vi_VN", + "xh_ZA", + "yi_DE", + ], + ) + base_architecture(args) + + +@register_model_architecture("xmod", "xmod_base") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True) + args.ln_before_adapter = getattr(args, "ln_before_adapter", True) + args.languages = getattr( + args, + "languages", + [ + "en_XX", + "id_ID", + "vi_VN", + "ru_RU", + "fa_IR", + "sv_SE", + "ja_XX", + "fr_XX", + "de_DE", + "ro_RO", + "ko_KR", + "hu_HU", + "es_XX", + "fi_FI", + "uk_UA", + "da_DK", + "pt_XX", + "no_XX", + "th_TH", + "pl_PL", + "bg_BG", + "nl_XX", + "zh_CN", + "he_IL", + "el_GR", + "it_IT", + "sk_SK", + "hr_HR", + "tr_TR", + "ar_AR", + "cs_CZ", + "lt_LT", + "hi_IN", + "zh_TW", + "ca_ES", + "ms_MY", + "sl_SI", + "lv_LV", + "ta_IN", + "bn_IN", + "et_EE", + "az_AZ", + "sq_AL", + "sr_RS", + "kk_KZ", + "ka_GE", + "tl_XX", + "ur_PK", + "is_IS", + "hy_AM", + "ml_IN", + "mk_MK", + "be_BY", + "la_VA", + "te_IN", + "eu_ES", + "gl_ES", + "mn_MN", + "kn_IN", + "ne_NP", + "sw_KE", + "si_LK", + "mr_IN", + "af_ZA", + "gu_IN", + "cy_GB", + "eo_EO", + "km_KH", + "ky_KG", + "uz_UZ", + "ps_AF", + "pa_IN", + "ga_IE", + "ha_NG", + "am_ET", + "lo_LA", + "ku_TR", + "so_SO", + "my_MM", + "or_IN", + "sa_IN", + ], + ) + base_architecture(args) + + +@register_model_architecture("xmod", "xmod_large_prenorm") +def roberta_base_architecture(args): + args.ffn_modules = getattr(args, "ffn_modules", False) + args.adapter_modules = getattr(args, "adapter_modules", True) + args.adapter_layer_norm = getattr(args, "adapter_layer_norm", True) + args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", False) + args.ln_before_adapter = getattr(args, "ln_before_adapter", False) + # args.bottleneck = getattr(args, "bottleneck", 8) + args.bottleneck = getattr(args, "bottleneck", 4) + args.languages = getattr( + args, + "languages", + [ + "en_XX", + "id_ID", + "vi_VN", + "ru_RU", + "fa_IR", + "sv_SE", + "ja_XX", + "fr_XX", + "de_DE", + "ro_RO", + "ko_KR", + "hu_HU", + "es_XX", + "fi_FI", + "uk_UA", + "da_DK", + "pt_XX", + "no_XX", + "th_TH", + "pl_PL", + "bg_BG", + "nl_XX", + "zh_CN", + "he_IL", + "el_GR", + "it_IT", + "sk_SK", + "hr_HR", + "tr_TR", + "ar_AR", + "cs_CZ", + "lt_LT", + "hi_IN", + "zh_TW", + "ca_ES", + "ms_MY", + "sl_SI", + "lv_LV", + "ta_IN", + "bn_IN", + "et_EE", + "az_AZ", + "sq_AL", + "sr_RS", + "kk_KZ", + "ka_GE", + "tl_XX", + "ur_PK", + "is_IS", + "hy_AM", + "ml_IN", + "mk_MK", + "be_BY", + "la_VA", + "te_IN", + "eu_ES", + "gl_ES", + "mn_MN", + "kn_IN", + "ne_NP", + "sw_KE", + "si_LK", + "mr_IN", + "af_ZA", + "gu_IN", + "cy_GB", + "eo_EO", + "km_KH", + "ky_KG", + "uz_UZ", + "ps_AF", + "pa_IN", + "ga_IE", + "ha_NG", + "am_ET", + "lo_LA", + "ku_TR", + "so_SO", + "my_MM", + "or_IN", + "sa_IN", + ], + ) + + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + base_architecture(args) diff --git a/fairseq/models/xmod/transformer_layer_xmod.py b/fairseq/models/xmod/transformer_layer_xmod.py new file mode 100644 index 0000000000000000000000000000000000000000..47a91cdc2339f235efe0cecc85d6a7a260d19c6a --- /dev/null +++ b/fairseq/models/xmod/transformer_layer_xmod.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.modules.transformer_layer import TransformerEncoderLayer +from typing import Optional +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.modules import LayerNorm +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import Tensor + + +class Adapter(nn.Module): + def __init__(self, cfg, red_fac=2): + super(Adapter, self).__init__() + self.cfg = cfg + self.embed_dim = cfg.encoder_embed_dim + self.quant_noise = getattr(cfg, "quant_noise_pq", 0) + self.quant_noise_block_size = getattr(cfg, "quant_noise_pq_block_size", 8) or 8 + self.activation_fn = utils.get_activation_fn( + activation=getattr(cfg, "activation_fn", "relu") or "relu" + ) + self.fc1 = quant_noise( + nn.Linear(self.embed_dim, self.embed_dim // red_fac), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + self.fc2 = quant_noise( + nn.Linear(self.embed_dim // red_fac, self.embed_dim), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + activation_dropout_p = getattr(cfg, "activation_dropout", 0) or 0 + if activation_dropout_p == 0: + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = getattr(cfg, "relu_dropout", 0) or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + + def forward(self, x): + x = self.activation_fn(self.fc1(x)) + if not hasattr(self.cfg, "adapter_dropout") or self.cfg.adapter_dropout: + x = self.activation_dropout_module(x) + x = self.fc2(x) + return x + + +class XMODTransformerEncoderLayerBase(TransformerEncoderLayer): + """Encoder layer block. + + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.encoder.normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, cfg): + super().__init__(cfg) + if hasattr(cfg, "adapter_modules") and cfg.adapter_modules: + export = getattr(cfg, "export", False) + if cfg.adapter_layer_norm: + self.adapter_layer_norm = LayerNorm(self.embed_dim, export=export) + self.adapter_modules = nn.ModuleDict(dict()) + if hasattr(self.cfg, "bottleneck"): + bottleneck = self.cfg.bottleneck + else: + bottleneck = 2 + for language in cfg.languages: + self.adapter_modules[str(language)] = Adapter(cfg, red_fac=bottleneck) + + def lang_adapter(self, lang_id, x): + # If language adapters exist pass throught them + if hasattr(self.cfg, "adapter_modules") and self.cfg.adapter_modules: + if lang_id is None: + lang_id = ["en_XX"] * x.shape[1] + d_langs = [lang_id[0]] + lang_lengths = [1] + for lang in lang_id[1:]: + if lang == d_langs[-1]: + lang_lengths[-1] += 1 + else: + d_langs.append(lang) + lang_lengths.append(1) + + if ( + not hasattr(self.cfg, "ln_before_adapter") + or not self.cfg.ln_before_adapter + ): + residual = x + if self.cfg.adapter_layer_norm: + x = self.adapter_layer_norm(x) + elif self.cfg.adapter_reuse_layer_norm: + x = self.final_layer_norm(x) + if hasattr(self.cfg, "ln_before_adapter") and self.cfg.ln_before_adapter: + residual = x + + split_x = torch.split(x, lang_lengths, 1) + x_ = [] + for i, (lang, s_x) in enumerate(zip(d_langs, split_x)): + lang = lang.replace("_rom", "").replace("_zaw", "") + x_.append(self.adapter_modules[str(lang)](s_x)) + x = torch.cat(x_, 1) + + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + + return x + + def forward( + self, + x, + encoder_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor] = None, + lang_id: Optional[list] = None, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + # anything in original attn_mask = 1, becomes -1e8 + # anything in original attn_mask = 0, becomes 0 + # Note that we cannot use -inf here, because at some edge cases, + # the attention weight (before softmax) for some padded element in query + # will become -inf, which results in NaN in model parameters + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8) + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + attn_mask=attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + + x = self.lang_adapter(lang_id, x) + + if not self.normalize_before: + x = self.final_layer_norm(x) + return x diff --git a/fairseq/modules/__init__.py b/fairseq/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcfda9b82aa74819d99dc9305fcad437fd90456f --- /dev/null +++ b/fairseq/modules/__init__.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +from .adaptive_input import AdaptiveInput +from .adaptive_softmax import AdaptiveSoftmax +from .base_layer import BaseLayer +from .beamable_mm import BeamableMM +from .character_token_embedder import CharacterTokenEmbedder +from .conv_tbc import ConvTBC +from .cross_entropy import cross_entropy +from .downsampled_multihead_attention import DownsampledMultiHeadAttention +from .dynamic_convolution import DynamicConv, DynamicConv1dTBC, DynamicConv_scripatable +from .dynamic_crf_layer import DynamicCRF +from .ema_module import EMAModuleConfig, EMAModule +from .fairseq_dropout import FairseqDropout +from .fp32_batch_norm import Fp32BatchNorm +from .fp32_group_norm import Fp32GroupNorm +from .fp32_instance_norm import Fp32InstanceNorm +from .gelu import gelu, gelu_accurate +from .grad_multiply import GradMultiply +from .gumbel_vector_quantizer import GumbelVectorQuantizer +from .kmeans_vector_quantizer import KmeansVectorQuantizer +from .layer_drop import LayerDropModuleList +from .layer_norm import Fp32LayerNorm, LayerNorm +from .learned_positional_embedding import LearnedPositionalEmbedding +from .lightweight_convolution import LightweightConv, LightweightConv1dTBC +from .linearized_convolution import LinearizedConvolution +from .location_attention import LocationAttention +from .lstm_cell_with_zoneout import LSTMCellWithZoneOut +from .multihead_attention import MultiheadAttention +from .positional_embedding import PositionalEmbedding +from .same_pad import SamePad, SamePad2d +from .scalar_bias import ScalarBias +from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding +from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer +from .transformer_sentence_encoder import TransformerSentenceEncoder +from .transpose_last import TransposeLast +from .unfold import unfold1d +from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer +from .vggblock import VGGBlock +from .espnet_multihead_attention import ( + ESPNETMultiHeadedAttention, + RelPositionMultiHeadedAttention, + RotaryPositionMultiHeadedAttention, +) +from .rotary_positional_embedding import RotaryPositionalEmbedding +from .positional_encoding import ( + RelPositionalEncoding, +) + +__all__ = [ + "AdaptiveInput", + "AdaptiveSoftmax", + "BaseLayer", + "BeamableMM", + "CharacterTokenEmbedder", + "ConvTBC", + "cross_entropy", + "DownsampledMultiHeadAttention", + "DynamicConv1dTBC", + "DynamicConv", + "DynamicConv_scripatable", + "DynamicCRF", + "EMAModule", + "EMAModuleConfig", + "FairseqDropout", + "Fp32BatchNorm", + "Fp32GroupNorm", + "Fp32LayerNorm", + "Fp32InstanceNorm", + "gelu", + "gelu_accurate", + "GradMultiply", + "GumbelVectorQuantizer", + "KmeansVectorQuantizer", + "LayerDropModuleList", + "LayerNorm", + "LearnedPositionalEmbedding", + "LightweightConv1dTBC", + "LightweightConv", + "LinearizedConvolution", + "LocationAttention", + "LSTMCellWithZoneOut", + "MultiheadAttention", + "PositionalEmbedding", + "SamePad", + "SamePad2d", + "ScalarBias", + "SinusoidalPositionalEmbedding", + "TransformerSentenceEncoderLayer", + "TransformerSentenceEncoder", + "TransformerDecoderLayer", + "TransformerEncoderLayer", + "TransposeLast", + "VGGBlock", + "unfold1d", + "ESPNETMultiheadedAttention", + "PositionalEmbedding", + "RelPositionMultiHeadedAttention", + "RelPositionalEncoding", + "RotaryPositionalEmbedding", + "RotaryPositionMultiHeadedAttention", +] diff --git a/fairseq/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0d9c65093bd6f2136981cf257bfbb238d8ddda Binary files /dev/null and b/fairseq/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/adaptive_input.cpython-310.pyc b/fairseq/modules/__pycache__/adaptive_input.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b812a5317bc17e77877ef60a912c9ad95863e35 Binary files /dev/null and b/fairseq/modules/__pycache__/adaptive_input.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc b/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e28ef49b17b1b7f8cfbb48b24230808c1b570f4 Binary files /dev/null and b/fairseq/modules/__pycache__/adaptive_softmax.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/base_layer.cpython-310.pyc b/fairseq/modules/__pycache__/base_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1cf7e108e2559318a215cda0fc0ce0295f84b1f Binary files /dev/null and b/fairseq/modules/__pycache__/base_layer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc b/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e049b7f47fcd5c4fced4466dd26a1d6a39046afc Binary files /dev/null and b/fairseq/modules/__pycache__/beamable_mm.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/character_token_embedder.cpython-310.pyc b/fairseq/modules/__pycache__/character_token_embedder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c62d922d6be475a14972b9c5249be64067e18e7 Binary files /dev/null and b/fairseq/modules/__pycache__/character_token_embedder.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/checkpoint_activations.cpython-310.pyc b/fairseq/modules/__pycache__/checkpoint_activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a62611fa906b5905413787d8f22444269b8ec1ae Binary files /dev/null and b/fairseq/modules/__pycache__/checkpoint_activations.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/conformer_layer.cpython-310.pyc b/fairseq/modules/__pycache__/conformer_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda5641d3bc95f41aab2bd9afebfe35d28a1d136 Binary files /dev/null and b/fairseq/modules/__pycache__/conformer_layer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/conv_tbc.cpython-310.pyc b/fairseq/modules/__pycache__/conv_tbc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fc4abeac2b285824161fb303f095526c7b73a4d Binary files /dev/null and b/fairseq/modules/__pycache__/conv_tbc.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/cross_entropy.cpython-310.pyc b/fairseq/modules/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d745e8903f883bc37e93d5b80f9a67f11248c8ca Binary files /dev/null and b/fairseq/modules/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-310.pyc b/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..600103771cfdb783d079ee1586a27cb1a203276e Binary files /dev/null and b/fairseq/modules/__pycache__/downsampled_multihead_attention.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc b/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..363ab129e64656b4931927d6c16b08481d99c13d Binary files /dev/null and b/fairseq/modules/__pycache__/dynamic_convolution.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/dynamic_crf_layer.cpython-310.pyc b/fairseq/modules/__pycache__/dynamic_crf_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..050893fa308c670b3c5c687234dbefe137c60a08 Binary files /dev/null and b/fairseq/modules/__pycache__/dynamic_crf_layer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/ema_module.cpython-310.pyc b/fairseq/modules/__pycache__/ema_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b849c61b7b5f9cdd71709674e69f8f07ee129f81 Binary files /dev/null and b/fairseq/modules/__pycache__/ema_module.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/espnet_multihead_attention.cpython-310.pyc b/fairseq/modules/__pycache__/espnet_multihead_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdc37e6550fb4b87166d96513cc8b0c40a6e949f Binary files /dev/null and b/fairseq/modules/__pycache__/espnet_multihead_attention.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/fairseq_dropout.cpython-310.pyc b/fairseq/modules/__pycache__/fairseq_dropout.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e97f32e5a3d12b98b0b0c37653509be5796fe65d Binary files /dev/null and b/fairseq/modules/__pycache__/fairseq_dropout.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/fp32_batch_norm.cpython-310.pyc b/fairseq/modules/__pycache__/fp32_batch_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b08cffc18b26d50022e3a7063425f08d35c2f3ef Binary files /dev/null and b/fairseq/modules/__pycache__/fp32_batch_norm.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/fp32_group_norm.cpython-310.pyc b/fairseq/modules/__pycache__/fp32_group_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..684b941f0bfc39ad54fca2005db94ad2c13e4737 Binary files /dev/null and b/fairseq/modules/__pycache__/fp32_group_norm.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/fp32_instance_norm.cpython-310.pyc b/fairseq/modules/__pycache__/fp32_instance_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98b8361d8ecfef2a55ca93786cfc663c79aaee67 Binary files /dev/null and b/fairseq/modules/__pycache__/fp32_instance_norm.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/gelu.cpython-310.pyc b/fairseq/modules/__pycache__/gelu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f32083f7960d693fbaafc41f923cf85828b131f Binary files /dev/null and b/fairseq/modules/__pycache__/gelu.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc b/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9160d7e564e5f5242e994b54f59a78816cb441b6 Binary files /dev/null and b/fairseq/modules/__pycache__/grad_multiply.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc b/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..949fc678863e69739b010de71704a6187a4c97fb Binary files /dev/null and b/fairseq/modules/__pycache__/gumbel_vector_quantizer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc b/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c644cf6783f9909dc637e4684051c027ec786bc1 Binary files /dev/null and b/fairseq/modules/__pycache__/kmeans_vector_quantizer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/layer_drop.cpython-310.pyc b/fairseq/modules/__pycache__/layer_drop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e55988e183129ddfb13a0032c06a1a9e2d53ffb Binary files /dev/null and b/fairseq/modules/__pycache__/layer_drop.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/layer_norm.cpython-310.pyc b/fairseq/modules/__pycache__/layer_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc19b65e347a539a6c56ecb0e7b20e9e1d0c159a Binary files /dev/null and b/fairseq/modules/__pycache__/layer_norm.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc b/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27b6fd8bf4b3d36ab4af0aa49cfd8447770780f5 Binary files /dev/null and b/fairseq/modules/__pycache__/learned_positional_embedding.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc b/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..187c520aead3bfa4e344d7e3f25d6ae9af8e6317 Binary files /dev/null and b/fairseq/modules/__pycache__/lightweight_convolution.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/linearized_convolution.cpython-310.pyc b/fairseq/modules/__pycache__/linearized_convolution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75b1be77e74c795a5d8fd125f4b9d1cc86ff6bf0 Binary files /dev/null and b/fairseq/modules/__pycache__/linearized_convolution.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/location_attention.cpython-310.pyc b/fairseq/modules/__pycache__/location_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81f83178773fc3ddf035b192291c4078a59e4034 Binary files /dev/null and b/fairseq/modules/__pycache__/location_attention.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/lstm_cell_with_zoneout.cpython-310.pyc b/fairseq/modules/__pycache__/lstm_cell_with_zoneout.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca52e9494f6c60ba70fe9d9dcee77e1c96f68859 Binary files /dev/null and b/fairseq/modules/__pycache__/lstm_cell_with_zoneout.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/multihead_attention.cpython-310.pyc b/fairseq/modules/__pycache__/multihead_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..834f7d445cbb1d398ad237beafdabcd9df68c12e Binary files /dev/null and b/fairseq/modules/__pycache__/multihead_attention.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/positional_embedding.cpython-310.pyc b/fairseq/modules/__pycache__/positional_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1cb946bc8d03396469ad3c684e29616c963ac2a Binary files /dev/null and b/fairseq/modules/__pycache__/positional_embedding.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/positional_encoding.cpython-310.pyc b/fairseq/modules/__pycache__/positional_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30e9d548ffe67fc51fdea81f616d0f9198c0f1fd Binary files /dev/null and b/fairseq/modules/__pycache__/positional_encoding.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/quant_noise.cpython-310.pyc b/fairseq/modules/__pycache__/quant_noise.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac0776df14df84929d0826e5759cf0c0d1c06919 Binary files /dev/null and b/fairseq/modules/__pycache__/quant_noise.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/rotary_positional_embedding.cpython-310.pyc b/fairseq/modules/__pycache__/rotary_positional_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42605cfe25ca62d531093158148c401cc317ada0 Binary files /dev/null and b/fairseq/modules/__pycache__/rotary_positional_embedding.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/same_pad.cpython-310.pyc b/fairseq/modules/__pycache__/same_pad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08b127905eaf10e7bcf683fa2d612c147afba892 Binary files /dev/null and b/fairseq/modules/__pycache__/same_pad.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/scalar_bias.cpython-310.pyc b/fairseq/modules/__pycache__/scalar_bias.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd0305593f17fa6f06d463aac886352e386bd7da Binary files /dev/null and b/fairseq/modules/__pycache__/scalar_bias.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-310.pyc b/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..583e9f8677fe1f5e0696f3e3542061be814ec6c0 Binary files /dev/null and b/fairseq/modules/__pycache__/sinusoidal_positional_embedding.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/transformer_layer.cpython-310.pyc b/fairseq/modules/__pycache__/transformer_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4054f21505e933f5e31f560b82cbfc0927197d6 Binary files /dev/null and b/fairseq/modules/__pycache__/transformer_layer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/transformer_layer_aug.cpython-310.pyc b/fairseq/modules/__pycache__/transformer_layer_aug.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a3fc438c3176b2457268f9c0510b52b2a157816 Binary files /dev/null and b/fairseq/modules/__pycache__/transformer_layer_aug.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-310.pyc b/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d5fc3e3f9ee10323c8dca512cc869bf7b1cff03 Binary files /dev/null and b/fairseq/modules/__pycache__/transformer_sentence_encoder.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc b/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..046e59b999e4a05747749f00168b6fccc93b48db Binary files /dev/null and b/fairseq/modules/__pycache__/transformer_sentence_encoder_layer.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/transpose_last.cpython-310.pyc b/fairseq/modules/__pycache__/transpose_last.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb38999fcad5f9b73817de04cf94607ed7ff07e2 Binary files /dev/null and b/fairseq/modules/__pycache__/transpose_last.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/unfold.cpython-310.pyc b/fairseq/modules/__pycache__/unfold.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41305b7692bfde89ef1cac80d1c0d06eb001d1c5 Binary files /dev/null and b/fairseq/modules/__pycache__/unfold.cpython-310.pyc differ diff --git a/fairseq/modules/__pycache__/vggblock.cpython-310.pyc b/fairseq/modules/__pycache__/vggblock.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c84960989422e87bde51390d9b64d32fb0d15d03 Binary files /dev/null and b/fairseq/modules/__pycache__/vggblock.cpython-310.pyc differ diff --git a/fairseq/modules/adaptive_input.py b/fairseq/modules/adaptive_input.py new file mode 100644 index 0000000000000000000000000000000000000000..01ac4accaccc3e4625ee2ca6a17cf88df1db0ca1 --- /dev/null +++ b/fairseq/modules/adaptive_input.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch +from torch import nn + +from fairseq.modules.quant_noise import quant_noise + + +class AdaptiveInput(nn.Module): + def __init__( + self, + vocab_size: int, + padding_idx: int, + initial_dim: int, + factor: float, + output_dim: int, + cutoff: List[int], + q_noise: float = 0, + qn_block_size: int = 8, + ): + super().__init__() + + if vocab_size > cutoff[-1]: + cutoff = cutoff + [vocab_size] + else: + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" + + self.cutoff = cutoff + self.embedding_dim = output_dim + self.padding_idx = padding_idx + + self.embeddings = nn.ModuleList() + for i in range(len(self.cutoff)): + prev = self.cutoff[i - 1] if i > 0 else 0 + size = self.cutoff[i] - prev + dim = int(initial_dim // (factor**i)) + seq = nn.Sequential( + nn.Embedding(size, dim, self.padding_idx), + quant_noise( + nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size + ), + ) + + self.embeddings.append(seq) + self.padding_idx = None + self.padding_idx = padding_idx + + def init_weights(m): + if isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + elif hasattr(m, "weight"): + nn.init.xavier_uniform_(m.weight) + + self.apply(init_weights) + + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + def weights_for_band(self, band: int): + return self.embeddings[band][0].weight, self.embeddings[band][1].weight + + def forward(self, input: torch.Tensor): + result = self._float_tensor.new(input.shape + (self.embedding_dim,)) + for i in range(len(self.cutoff)): + mask = input.lt(self.cutoff[i]) + if i > 0: + mask.mul_(input.ge(self.cutoff[i - 1])) + chunk_input = input[mask] - self.cutoff[i - 1] + else: + chunk_input = input[mask] + if mask.any(): + result[mask] = self.embeddings[i](chunk_input) + return result diff --git a/fairseq/modules/adaptive_softmax.py b/fairseq/modules/adaptive_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0c77ba0f6ee98501306d66cbc4a948b4ade0f7 --- /dev/null +++ b/fairseq/modules/adaptive_softmax.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import operator + +import torch +import torch.nn.functional as F +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from torch import nn + + +class TiedLinear(nn.Module): + def __init__(self, weight, transpose): + super().__init__() + self.weight = weight + self.transpose = transpose + + def forward(self, input): + return F.linear(input, self.weight.t() if self.transpose else self.weight) + + +class TiedHeadModule(nn.Module): + def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size): + super().__init__() + tied_emb, _ = weights + self.num_words, emb_dim = tied_emb.size() + + self.word_proj = quant_noise( + TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size + ) + if input_dim != emb_dim: + self.word_proj = nn.Sequential( + quant_noise( + nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size + ), + self.word_proj, + ) + + self.class_proj = quant_noise( + nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size + ) + self.out_dim = self.num_words + num_classes + + self.register_buffer("_float_tensor", torch.FloatTensor(1)) + + def forward(self, input): + inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1) + out = self._float_tensor.new(inp_sz, self.out_dim) + out[:, : self.num_words] = self.word_proj(input.view(inp_sz, -1)) + out[:, self.num_words :] = self.class_proj(input.view(inp_sz, -1)) + return out + + +class AdaptiveSoftmax(nn.Module): + """ + This is an implementation of the efficient softmax approximation for + graphical processing units (GPU), described in the paper "Efficient softmax + approximation for GPUs" (http://arxiv.org/abs/1609.04309). + """ + + def __init__( + self, + vocab_size, + input_dim, + cutoff, + dropout, + factor=4.0, + adaptive_inputs=None, + tie_proj=False, + q_noise=0, + qn_block_size=8, + ): + super().__init__() + + if vocab_size > cutoff[-1]: + cutoff = cutoff + [vocab_size] + else: + assert ( + vocab_size == cutoff[-1] + ), "cannot specify cutoff larger than vocab size" + + output_dim = cutoff[0] + len(cutoff) - 1 + + self.vocab_size = vocab_size + self.cutoff = cutoff + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.input_dim = input_dim + self.factor = factor + self.q_noise = q_noise + self.qn_block_size = qn_block_size + + self.lsm = nn.LogSoftmax(dim=1) + + if adaptive_inputs is not None: + self.head = TiedHeadModule( + adaptive_inputs.weights_for_band(0), + input_dim, + len(cutoff) - 1, + self.q_noise, + self.qn_block_size, + ) + else: + self.head = quant_noise( + nn.Linear(input_dim, output_dim, bias=False), + self.q_noise, + self.qn_block_size, + ) + + self._make_tail(adaptive_inputs, tie_proj) + + def init_weights(m): + if ( + hasattr(m, "weight") + and not isinstance(m, TiedLinear) + and not isinstance(m, TiedHeadModule) + ): + nn.init.xavier_uniform_(m.weight) + + self.apply(init_weights) + + self.register_buffer("version", torch.LongTensor([1])) + + def _make_tail(self, adaptive_inputs=None, tie_proj=False): + self.tail = nn.ModuleList() + for i in range(len(self.cutoff) - 1): + dim = int(self.input_dim // self.factor ** (i + 1)) + + tied_emb, tied_proj = ( + adaptive_inputs.weights_for_band(i + 1) + if adaptive_inputs is not None + else (None, None) + ) + + if tied_proj is not None: + if tie_proj: + proj = quant_noise( + TiedLinear(tied_proj, transpose=True), + self.q_noise, + self.qn_block_size, + ) + else: + proj = quant_noise( + nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False), + self.q_noise, + self.qn_block_size, + ) + else: + proj = quant_noise( + nn.Linear(self.input_dim, dim, bias=False), + self.q_noise, + self.qn_block_size, + ) + + if tied_emb is None: + out_proj = nn.Linear( + dim, self.cutoff[i + 1] - self.cutoff[i], bias=False + ) + else: + out_proj = TiedLinear(tied_emb, transpose=False) + + m = nn.Sequential( + proj, + nn.Dropout(self.dropout_module.p), + quant_noise(out_proj, self.q_noise, self.qn_block_size), + ) + + self.tail.append(m) + + def upgrade_state_dict_named(self, state_dict, name): + version_name = name + ".version" + if version_name not in state_dict: + raise Exception("This version of the model is no longer supported") + + def adapt_target(self, target): + """ + In order to be efficient, the AdaptiveSoftMax does not compute the + scores for all the word of the vocabulary for all the examples. It is + thus necessary to call the method adapt_target of the AdaptiveSoftMax + layer inside each forward pass. + """ + + target = target.view(-1) + new_target = [target.clone()] + target_idxs = [] + + for i in range(len(self.cutoff) - 1): + mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1])) + new_target[0][mask] = self.cutoff[0] + i + + if mask.any(): + target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1)) + new_target.append(target[mask].add(-self.cutoff[i])) + else: + target_idxs.append(None) + new_target.append(None) + + return new_target, target_idxs + + def forward(self, input, target): + """ + Args: + input: (b x t x d) + target: (b x t) + Returns: + 2 lists: output for each cutoff section and new targets by cut off + """ + + input = input.contiguous().view(-1, input.size(-1)) + input = self.dropout_module(input) + + new_target, target_idxs = self.adapt_target(target) + output = [self.head(input)] + + for i in range(len(target_idxs)): + if target_idxs[i] is not None: + output.append(self.tail[i](input.index_select(0, target_idxs[i]))) + else: + output.append(None) + + return output, new_target + + def get_log_prob(self, input, target): + """ + Computes the log probabilities for all the words of the vocabulary, + given a 2D tensor of hidden vectors. + """ + + bsz, length, dim = input.size() + input = input.contiguous().view(-1, dim) + + if target is not None: + _, target_idxs = self.adapt_target(target) + else: + target_idxs = None + + head_y = self.head(input) + log_probs = head_y.new_zeros(input.size(0), self.vocab_size) + + head_sz = self.cutoff[0] + len(self.tail) + log_probs[:, :head_sz] = self.lsm(head_y) + tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone() + + for i in range(len(self.tail)): + start = self.cutoff[i] + end = self.cutoff[i + 1] + + if target_idxs is None: + tail_out = log_probs[:, start:end] + tail_out.copy_(self.tail[i](input)) + log_probs[:, start:end] = self.lsm(tail_out).add_( + tail_priors[:, i, None] + ) + elif target_idxs[i] is not None: + idxs = target_idxs[i] + tail_out = log_probs[idxs, start:end] + tail_out.copy_(self.tail[i](input[idxs])) + log_probs[idxs, start:end] = self.lsm(tail_out).add_( + tail_priors[idxs, i, None] + ) + + log_probs = log_probs.view(bsz, length, -1) + return log_probs diff --git a/fairseq/modules/base_layer.py b/fairseq/modules/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e823f7bae2c2e9b7e2277b0161cf26b453e7216c --- /dev/null +++ b/fairseq/modules/base_layer.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch +import sys +from fairseq import utils +from fairseq.distributed import utils as distributed_utils +from fairseq.modules.layer_norm import LayerNorm + + +class BaseLayer(nn.Module): + def __init__(self, args): + super().__init__() + self.num_workers = distributed_utils.get_data_parallel_world_size() + expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) + torch.nn.init.orthogonal_(expert_centroids, gain=0.1) + self.register_parameter( + "expert_centroids", torch.nn.Parameter(expert_centroids) + ) + self.expert_network = nn.Sequential( + *([BaseSublayer(args) for _ in range(args.base_sublayers)]) + ) + self.expert_id = distributed_utils.get_data_parallel_rank() + self.shuffle = args.base_shuffle + self.cpp = self.load_assignment() + + # Add a special attribute to the expert parameters, so we know not to sync their gradients + for param in self.expert_network.parameters(): + param.expert = True + + def forward(self, input_features, *args, **kwargs): + features = input_features.reshape(-1, input_features.size(-1)) + is_training = input_features.requires_grad + + if self.shuffle and is_training: + # Send each token to a random worker, to break correlations within the batch + shuffle_sort = torch.randperm(features.size(0), device=features.device) + features = All2All.apply(features[shuffle_sort]) + + with torch.no_grad(): + # Compute similarity of each token to each expert, for routing + token_expert_affinities = features.matmul( + self.expert_centroids.transpose(0, 1) + ) + + # Compute which token goes to which expert + sort_by_expert, input_splits, output_splits = ( + self.balanced_assignment(token_expert_affinities) + if is_training + else self.greedy_assignment(token_expert_affinities) + ) + # Swap these tokens for the right ones for our expert + routed_features = All2All.apply( + features[sort_by_expert], output_splits, input_splits + ) + + if routed_features.size(0) > 0: + # Mix in the expert network based on how appropriate it is for these tokens + alpha = torch.sigmoid( + routed_features.mv(self.expert_centroids[self.expert_id]) + ).unsqueeze(1) + routed_features = ( + alpha * self.expert_network(routed_features) + + (1 - alpha) * routed_features + ) + # Return to original worker and ordering + result = All2All.apply(routed_features, input_splits, output_splits)[ + self.inverse_sort(sort_by_expert) + ] + + if self.shuffle and is_training: + # Undo shuffling + result = All2All.apply(result)[self.inverse_sort(shuffle_sort)] + + # Return additional Nones for compatibility with TransformerDecoderLayer + return result.view(input_features.size()), None, None + + def inverse_sort(self, order): + # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] + return torch.empty_like(order).scatter_( + 0, order, torch.arange(0, order.size(0), device=order.device) + ) + + def balanced_assignment(self, scores): + ok = scores.isfinite() + if not ok.all(): + # NaNs here can break the assignment algorithm + scores[~ok] = scores[ok].min() + return self.cpp.balanced_assignment(scores), None, None + + # Assigns each token to the top k experts + def greedy_assignment(self, scores, k=1): + token_to_workers = torch.topk(scores, dim=1, k=k, largest=True).indices.view(-1) + token_to_workers, sort_ordering = torch.sort(token_to_workers) + worker2token = sort_ordering // k + + # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) + output_splits = torch.zeros( + (self.num_workers,), dtype=torch.long, device=scores.device + ) + workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) + output_splits[workers] = counts + # Tell other workers how many tokens to expect from us + input_splits = All2All.apply(output_splits) + return worker2token, input_splits.tolist(), output_splits.tolist() + + def load_assignment(self): + try: + from fairseq import libbase + + return libbase + + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbase. run `python setup.py build_ext --inplace`\n" + ) + raise e + + +class BaseSublayer(nn.Module): + def __init__(self, args): + super().__init__() + self.activation_fn = utils.get_activation_fn( + activation=getattr(args, "activation_fn", "relu") or "relu" + ) + self.norm = LayerNorm(args.decoder_embed_dim, export=False) + self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) + self.ff2 = torch.nn.Linear(args.decoder_ffn_embed_dim, args.decoder_embed_dim) + self.ff2.weight.data.zero_() + + def forward(self, xs): + return xs + self.ff2(self.activation_fn(self.ff1(self.norm(xs)))) + + +# Wraps torch.distributed.all_to_all_single as a function that supports autograd +class All2All(torch.autograd.Function): + @staticmethod + def forward(ctx, xs, input_splits=None, output_splits=None): + ctx.input_splits = input_splits + ctx.output_splits = output_splits + + ys = ( + torch.empty_like(xs) + if output_splits is None + else xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:])) + ) + torch.distributed.all_to_all_single( + ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits + ) + return ys + + @staticmethod + def backward(ctx, grad_output): + result = ( + torch.empty_like(grad_output) + if ctx.input_splits is None + else grad_output.new_empty( + size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]) + ) + ) + torch.distributed.all_to_all_single( + result, + grad_output, + output_split_sizes=ctx.input_splits, + input_split_sizes=ctx.output_splits, + ) + return result, None, None diff --git a/fairseq/modules/beamable_mm.py b/fairseq/modules/beamable_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..eff1a4607f600c71210e6b914985dc48731aae86 --- /dev/null +++ b/fairseq/modules/beamable_mm.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +class BeamableMM(nn.Module): + """This module provides an optimized MM for beam decoding with attention. + + It leverage the fact that the source-side of the input is replicated beam + times and the target-side of the input is of width one. This layer speeds up + inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)} + with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}. + """ + + def __init__(self, beam_size=None): + super(BeamableMM, self).__init__() + self.beam_size = beam_size + + def forward(self, input1, input2): + if ( + not self.training + and self.beam_size is not None # test mode + and input1.dim() == 3 # beam size is set + and input1.size(1) # only support batched input + == 1 # single time step update + ): + bsz, beam = input1.size(0), self.beam_size + + # bsz x 1 x nhu --> bsz/beam x beam x nhu + input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1) + + # bsz x sz2 x nhu --> bsz/beam x sz2 x nhu + input2 = input2.unfold(0, beam, beam)[:, :, :, 0] + + # use non batched operation if bsz = beam + if input1.size(0) == 1: + output = torch.mm(input1[0, :, :], input2[0, :, :]) + else: + output = input1.bmm(input2) + return output.view(bsz, 1, -1) + else: + return input1.bmm(input2) + + def set_beam_size(self, beam_size): + self.beam_size = beam_size diff --git a/fairseq/modules/character_token_embedder.py b/fairseq/modules/character_token_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..181221b61b9f76453b67e3b848b198620dce912c --- /dev/null +++ b/fairseq/modules/character_token_embedder.py @@ -0,0 +1,214 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from fairseq.data import Dictionary +from torch import nn + + +CHAR_PAD_IDX = 0 +CHAR_EOS_IDX = 257 + + +logger = logging.getLogger(__name__) + + +class CharacterTokenEmbedder(torch.nn.Module): + def __init__( + self, + vocab: Dictionary, + filters: List[Tuple[int, int]], + char_embed_dim: int, + word_embed_dim: int, + highway_layers: int, + max_char_len: int = 50, + char_inputs: bool = False, + ): + super(CharacterTokenEmbedder, self).__init__() + + self.onnx_trace = False + self.embedding_dim = word_embed_dim + self.max_char_len = max_char_len + self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0) + self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim)) + self.eos_idx, self.unk_idx = 0, 1 + self.char_inputs = char_inputs + + self.convolutions = nn.ModuleList() + for width, out_c in filters: + self.convolutions.append( + nn.Conv1d(char_embed_dim, out_c, kernel_size=width) + ) + + last_dim = sum(f[1] for f in filters) + + self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None + + self.projection = nn.Linear(last_dim, word_embed_dim) + + assert ( + vocab is not None or char_inputs + ), "vocab must be set if not using char inputs" + self.vocab = None + if vocab is not None: + self.set_vocab(vocab, max_char_len) + + self.reset_parameters() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def set_vocab(self, vocab, max_char_len): + word_to_char = torch.LongTensor(len(vocab), max_char_len) + + truncated = 0 + for i in range(len(vocab)): + if i < vocab.nspecial: + char_idxs = [0] * max_char_len + else: + chars = vocab[i].encode() + # +1 for padding + char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars)) + if len(char_idxs) > max_char_len: + truncated += 1 + char_idxs = char_idxs[:max_char_len] + word_to_char[i] = torch.LongTensor(char_idxs) + + if truncated > 0: + logger.info( + "truncated {} words longer than {} characters".format( + truncated, max_char_len + ) + ) + + self.vocab = vocab + self.word_to_char = word_to_char + + @property + def padding_idx(self): + return Dictionary().pad() if self.vocab is None else self.vocab.pad() + + def reset_parameters(self): + nn.init.xavier_normal_(self.char_embeddings.weight) + nn.init.xavier_normal_(self.symbol_embeddings) + nn.init.xavier_uniform_(self.projection.weight) + + nn.init.constant_( + self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0 + ) + nn.init.constant_(self.projection.bias, 0.0) + + def forward( + self, + input: torch.Tensor, + ): + if self.char_inputs: + chars = input.view(-1, self.max_char_len) + pads = chars[:, 0].eq(CHAR_PAD_IDX) + eos = chars[:, 0].eq(CHAR_EOS_IDX) + if eos.any(): + if self.onnx_trace: + chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars) + else: + chars[eos] = 0 + + unk = None + else: + flat_words = input.view(-1) + chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as( + input + ) + pads = flat_words.eq(self.vocab.pad()) + eos = flat_words.eq(self.vocab.eos()) + unk = flat_words.eq(self.vocab.unk()) + + word_embs = self._convolve(chars) + if self.onnx_trace: + if pads.any(): + word_embs = torch.where( + pads.unsqueeze(1), word_embs.new_zeros(1), word_embs + ) + if eos.any(): + word_embs = torch.where( + eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs + ) + if unk is not None and unk.any(): + word_embs = torch.where( + unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs + ) + else: + if pads.any(): + word_embs[pads] = 0 + if eos.any(): + word_embs[eos] = self.symbol_embeddings[self.eos_idx] + if unk is not None and unk.any(): + word_embs[unk] = self.symbol_embeddings[self.unk_idx] + + return word_embs.view(input.size()[:2] + (-1,)) + + def _convolve( + self, + char_idxs: torch.Tensor, + ): + char_embs = self.char_embeddings(char_idxs) + char_embs = char_embs.transpose(1, 2) # BTC -> BCT + + conv_result = [] + + for conv in self.convolutions: + x = conv(char_embs) + x, _ = torch.max(x, -1) + x = F.relu(x) + conv_result.append(x) + + x = torch.cat(conv_result, dim=-1) + + if self.highway is not None: + x = self.highway(x) + x = self.projection(x) + + return x + + +class Highway(torch.nn.Module): + """ + A `Highway layer `_. + Adopted from the AllenNLP implementation. + """ + + def __init__(self, input_dim: int, num_layers: int = 1): + super(Highway, self).__init__() + self.input_dim = input_dim + self.layers = nn.ModuleList( + [nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)] + ) + self.activation = nn.ReLU() + + self.reset_parameters() + + def reset_parameters(self): + for layer in self.layers: + # As per comment in AllenNLP: + # We should bias the highway layer to just carry its input forward. We do that by + # setting the bias on `B(x)` to be positive, because that means `g` will be biased to + # be high, so we will carry the input forward. The bias on `B(x)` is the second half + # of the bias vector in each Linear layer. + nn.init.constant_(layer.bias[self.input_dim :], 1) + + nn.init.constant_(layer.bias[: self.input_dim], 0) + nn.init.xavier_normal_(layer.weight) + + def forward(self, x: torch.Tensor): + for layer in self.layers: + projection = layer(x) + proj_x, gate = projection.chunk(2, dim=-1) + proj_x = self.activation(proj_x) + gate = torch.sigmoid(gate) + x = gate * x + (gate.new_tensor([1]) - gate) * proj_x + return x diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0b5929a39d5acec91038d7335276b4e5098aed --- /dev/null +++ b/fairseq/modules/checkpoint_activations.py @@ -0,0 +1,242 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +from typing import Any, Dict, List, Tuple, Union + +import torch +import torch.utils.checkpoint as checkpoint +from fairseq import utils + + +def checkpoint_wrapper(m, offload_to_cpu=False): + """ + A friendlier wrapper for performing activation checkpointing. + + Compared to the PyTorch version, this version: + - wraps an nn.Module, so that all subsequent calls will use checkpointing + - handles keyword arguments in the forward + - handles non-Tensor outputs from the forward + + Usage:: + + checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True) + a, b = checkpointed_module(x, y=3, z=torch.Tensor([1])) + """ + # should I check whether original_forward has already been set? + assert not hasattr( + m, "precheckpoint_forward" + ), "checkpoint function has already been applied?" + m.precheckpoint_forward = m.forward + m.forward = functools.partial( + _checkpointed_forward, + m.precheckpoint_forward, # original_forward + offload_to_cpu, + ) + return m + + +def unwrap_checkpoint(m: torch.nn.Module): + """ + unwrap a module and its children from checkpoint_wrapper + """ + for module in m.modules(): + if hasattr(module, "precheckpoint_forward"): + module.forward = module.precheckpoint_forward + del module.precheckpoint_forward + if hasattr(module, "old_deepcopy_method"): + module.__deepcopy__ = module.old_deepcopy_method + del module.old_deepcopy_method + return m + + +def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs): + # Autograd Functions in PyTorch work best with positional args, since + # the backward must return gradients (or None) for every input argument. + # We can flatten keyword arguments to make this easier. + kwarg_keys, flat_args = pack_kwargs(*args, **kwargs) + parent_ctx_dict = {"offload": offload_to_cpu} + output = CheckpointFunction.apply( + original_forward, parent_ctx_dict, kwarg_keys, *flat_args + ) + if isinstance(output, torch.Tensor): + return output + else: + packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"] + if packed_non_tensor_outputs: + output = unpack_non_tensors(output, packed_non_tensor_outputs) + return output + + +def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]: + """ + Usage:: + + kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4) + args, kwargs = unpack_kwargs(kwarg_keys, flat_args) + assert args == [1, 2] + assert kwargs == {"a": 3, "b": 4} + """ + kwarg_keys = [] + flat_args = list(args) + for k, v in kwargs.items(): + kwarg_keys.append(k) + flat_args.append(v) + return kwarg_keys, flat_args + + +def unpack_kwargs( + kwarg_keys: List[str], flat_args: List[Any] +) -> Tuple[List[Any], Dict[str, Any]]: + if len(kwarg_keys) == 0: + return flat_args, {} + args = flat_args[: -len(kwarg_keys)] + kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])} + return args, kwargs + + +def split_non_tensors( + mixed: Union[torch.Tensor, Tuple[Any]] +) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]: + """ + Usage:: + + x = torch.Tensor([1]) + y = torch.Tensor([2]) + tensors, packed_non_tensors = split_non_tensors((x, y, None, 3)) + recon = unpack_non_tensors(tensors, packed_non_tensors) + assert recon == (x, y, None, 3) + """ + if isinstance(mixed, torch.Tensor): + return (mixed,), None + tensors = [] + packed_non_tensors = {"is_tensor": [], "objects": []} + for o in mixed: + if isinstance(o, torch.Tensor): + packed_non_tensors["is_tensor"].append(True) + tensors.append(o) + else: + packed_non_tensors["is_tensor"].append(False) + packed_non_tensors["objects"].append(o) + return tuple(tensors), packed_non_tensors + + +def unpack_non_tensors( + tensors: Tuple[torch.Tensor], + packed_non_tensors: Dict[str, List[Any]], +) -> Tuple[Any]: + if packed_non_tensors is None: + return tensors + assert isinstance(packed_non_tensors, dict) + mixed = [] + is_tensor_list = packed_non_tensors["is_tensor"] + objects = packed_non_tensors["objects"] + assert len(tensors) + len(objects) == len(is_tensor_list) + obj_i = tnsr_i = 0 + for is_tensor in is_tensor_list: + if is_tensor: + mixed.append(tensors[tnsr_i]) + tnsr_i += 1 + else: + mixed.append(objects[obj_i]) + obj_i += 1 + return tuple(mixed) + + +class CheckpointFunction(torch.autograd.Function): + """Similar to the torch version, but support non-Tensor outputs. + + The caller is expected to provide a dict (*parent_ctx_dict*) that will hold + the non-Tensor outputs. These should be combined with the Tensor *outputs* + by calling ``unpack_non_tensors``. + """ + + @staticmethod + def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): + if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation + checkpoint.check_backward_validity(args) + + ctx.run_function = run_function + ctx.kwarg_keys = kwarg_keys + ctx.fwd_rng_state = utils.get_rng_state() + + tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args) + if parent_ctx_dict["offload"]: + ctx.fwd_device = tuple(x.device for x in tensor_inputs) + ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) + tensor_inputs = tuple( + x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs + ) + + else: + ctx.fwd_device, ctx.grad_requirements = None, None + + ctx.save_for_backward(*tensor_inputs) + ctx.packed_non_tensor_inputs = packed_non_tensor_inputs + + with torch.no_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args) + outputs = run_function(*unpacked_args, **unpacked_kwargs) + + if isinstance(outputs, torch.Tensor): + return outputs + else: + # Autograd Functions don't like non-Tensor outputs. We can split the + # non-Tensor and Tensor outputs, returning the former by reference + # through *parent_ctx_dict* and returning the latter directly. + outputs, packed_non_tensor_outputs = split_non_tensors(outputs) + parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "Checkpointing is not compatible with .grad(), please use .backward() if possible" + ) + + tensor_inputs: Tuple = ctx.saved_tensors + tensor_inputs = checkpoint.detach_variable(tensor_inputs) + if ctx.fwd_device is not None: + tensor_inputs = [ + t.to(ctx.fwd_device[i], non_blocking=True) + for i, t in enumerate(tensor_inputs) + ] + for i, need_grad in enumerate(ctx.grad_requirements): + tensor_inputs[i].requires_grad = need_grad + inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) + + # Store the current states. + bwd_rng_state = utils.get_rng_state() + + # Set the states to what it used to be before the forward pass. + utils.set_rng_state(ctx.fwd_rng_state) + + with torch.enable_grad(): + unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs) + outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs) + tensor_outputs, _ = split_non_tensors(outputs) + # Set the states back to what it was at the start of this function. + utils.set_rng_state(bwd_rng_state) + + # Run backward() with only Tensors that require grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(tensor_outputs)): + if tensor_outputs[i].requires_grad: + outputs_with_grad.append(tensor_outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "None of the outputs have requires_grad=True, " + "this checkpoint() is not necessary" + ) + + torch.autograd.backward(outputs_with_grad, args_with_grad) + + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs + ) + return (None, None, None) + grads diff --git a/fairseq/modules/conformer_layer.py b/fairseq/modules/conformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..964af243ec8fb5e09e0b61da688a9657277d39c1 --- /dev/null +++ b/fairseq/modules/conformer_layer.py @@ -0,0 +1,301 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import torch + +from fairseq.modules import ( + ESPNETMultiHeadedAttention, + LayerNorm, + MultiheadAttention, + RelPositionMultiHeadedAttention, + RotaryPositionMultiHeadedAttention, +) +from fairseq.utils import get_activation_fn + + +class ConvolutionModule(torch.nn.Module): + """Convolution block used in the conformer block""" + + def __init__( + self, + embed_dim, + channels, + depthwise_kernel_size, + dropout, + activation_fn="swish", + bias=False, + export=False, + ): + """ + Args: + embed_dim: Embedding dimension + channels: Number of channels in depthwise conv layers + depthwise_kernel_size: Depthwise conv layer kernel size + dropout: dropout value + activation_fn: Activation function to use after depthwise convolution kernel + bias: If bias should be added to conv layers + export: If layernorm should be exported to jit + """ + super(ConvolutionModule, self).__init__() + assert ( + depthwise_kernel_size - 1 + ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" + self.layer_norm = LayerNorm(embed_dim, export=export) + self.pointwise_conv1 = torch.nn.Conv1d( + embed_dim, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.glu = torch.nn.GLU(dim=1) + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + depthwise_kernel_size, + stride=1, + padding=(depthwise_kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.batch_norm = torch.nn.BatchNorm1d(channels) + self.activation = get_activation_fn(activation_fn)(channels) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x: Input of shape B X T X C + Returns: + Tensor of shape B X T X C + """ + x = self.layer_norm(x) + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = self.glu(x) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.batch_norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) + x = self.dropout(x) + return x.transpose(1, 2) + + +class FeedForwardModule(torch.nn.Module): + """Positionwise feed forward layer used in conformer""" + + def __init__( + self, + input_feat, + hidden_units, + dropout1, + dropout2, + activation_fn="swish", + bias=True, + ): + """ + Args: + input_feat: Input feature dimension + hidden_units: Hidden unit dimension + dropout1: dropout value for layer1 + dropout2: dropout value for layer2 + activation_fn: Name of activation function + bias: If linear layers should have bias + """ + + super(FeedForwardModule, self).__init__() + self.layer_norm = LayerNorm(input_feat) + self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias) + self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias) + self.dropout1 = torch.nn.Dropout(dropout1) + self.dropout2 = torch.nn.Dropout(dropout2) + self.activation = get_activation_fn(activation_fn)(hidden_units) + + def forward(self, x): + """ + Args: + x: Input Tensor of shape T X B X C + Returns: + Tensor of shape T X B X C + """ + x = self.layer_norm(x) + x = self.w_1(x) + x = self.activation(x) + x = self.dropout1(x) + x = self.w_2(x) + return self.dropout2(x) + + +class ConformerEncoderLayer(torch.nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100. We currently don't support relative positional encoding in MHA""" + + def __init__( + self, + embed_dim, + ffn_embed_dim, + attention_heads, + dropout, + use_fp16, + depthwise_conv_kernel_size=31, + activation_fn="swish", + attn_type=None, + pos_enc_type="abs", + ): + """ + Args: + embed_dim: Input embedding dimension + ffn_embed_dim: FFN layer dimension + attention_heads: Number of attention heads in MHA + dropout: dropout value + depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module + activation_fn: Activation function name to use in convulation block and feed forward block + attn_type: MHA implementation from ESPNET vs fairseq + pos_enc_type: Positional encoding type - abs, rope, rel_pos + """ + self.pos_enc_type = pos_enc_type + super(ConformerEncoderLayer, self).__init__() + + self.ffn1 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + ) + + self.self_attn_layer_norm = LayerNorm(embed_dim, export=False) + self.self_attn_dropout = torch.nn.Dropout(dropout) + if attn_type == "espnet": + if self.pos_enc_type == "rel_pos": + self.self_attn = RelPositionMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + elif self.pos_enc_type == "rope": + self.self_attn = RotaryPositionMultiHeadedAttention( + embed_dim, attention_heads, dropout=dropout, precision=use_fp16 + ) + elif self.pos_enc_type == "abs": + self.self_attn = ESPNETMultiHeadedAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + else: + raise Exception(f"Unsupported attention type {self.pos_enc_type}") + else: + # Default to fairseq MHA + self.self_attn = MultiheadAttention( + embed_dim, + attention_heads, + dropout=dropout, + ) + + self.conv_module = ConvolutionModule( + embed_dim=embed_dim, + channels=embed_dim, + depthwise_kernel_size=depthwise_conv_kernel_size, + dropout=dropout, + activation_fn=activation_fn, + ) + + self.ffn2 = FeedForwardModule( + embed_dim, + ffn_embed_dim, + dropout, + dropout, + activation_fn=activation_fn, + ) + self.final_layer_norm = LayerNorm(embed_dim, export=False) + + def forward( + self, + x, + encoder_padding_mask: Optional[torch.Tensor], + position_emb: Optional[torch.Tensor] = None, + ): + """ + Args: + x: Tensor of shape T X B X C + encoder_padding_mask: Optional mask tensor + positions: + Returns: + Tensor of shape T X B X C + """ + residual = x + x = self.ffn1(x) + x = x * 0.5 + residual + residual = x + x = self.self_attn_layer_norm(x) + if self.pos_enc_type == "rel_pos": + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + pos_emb=position_emb, + need_weights=False, + ) + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + ) + x = self.self_attn_dropout(x) + x = x + residual + + residual = x + # TBC to BTC + x = x.transpose(0, 1) + x = self.conv_module(x) + # BTC to TBC + x = x.transpose(0, 1) + x = residual + x + + residual = x + x = self.ffn2(x) + + layer_result = x + + x = x * 0.5 + residual + + x = self.final_layer_norm(x) + return x, (attn, layer_result) + + +class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer): + """Encoder layer for Wav2vec2 encoder""" + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + position_emb=None, + ): + return super().forward(x, self_attn_padding_mask, position_emb) diff --git a/fairseq/modules/conv_tbc.py b/fairseq/modules/conv_tbc.py new file mode 100644 index 0000000000000000000000000000000000000000..65e17ec94f7e595cb657b3d2daaa1052a95d0677 --- /dev/null +++ b/fairseq/modules/conv_tbc.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn.modules.utils import _single +from torch import Tensor + + +class ConvTBC(torch.nn.Module): + """1D convolution over an input of shape (time x batch x channel) + + The implementation uses gemm to perform the convolution. This implementation + is faster than cuDNN for small kernel sizes. + """ + + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _single(kernel_size) + self.padding = _single(padding) + + self.weight = torch.nn.Parameter( + torch.Tensor(self.kernel_size[0], in_channels, out_channels) + ) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_normal_(self.weight) + nn.init.zeros_(self.bias) + + def conv_tbc(self, input: Tensor): + return torch.conv_tbc( + input.contiguous(), self.weight, self.bias, self.padding[0] + ) + + def forward(self, input: Tensor): + return self.conv_tbc(input) + + def __repr__(self): + s = ( + "{name}({in_channels}, {out_channels}, kernel_size={kernel_size}" + ", padding={padding}" + ) + if self.bias is None: + s += ", bias=False" + s += ")" + return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..286c00eeccbc9d5596a793a20d9c7fdf3dd092f1 --- /dev/null +++ b/fairseq/modules/cross_entropy.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): + lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) + return F.nll_loss( + lprobs, + target, + ignore_index=ignore_index, + reduction=reduction, + ) + + +try: + import xentropy_cuda + from apex.contrib import xentropy + + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): + if logits.device == torch.device("cpu"): + return _cross_entropy_pytorch(logits, target, ignore_index, reduction) + else: + if not getattr(cross_entropy, "_has_logged_once", False): + logger.info("using fused cross entropy") + cross_entropy._has_logged_once = True + + half_to_float = logits.dtype == torch.half + losses = xentropy.SoftmaxCrossEntropyLoss.apply( + logits, + target, + 0.0, + ignore_index, + half_to_float, + ) + if reduction == "sum": + return losses.sum() + elif reduction == "mean": + if ignore_index >= 0: + return losses.sum() / target.ne(ignore_index).sum() + else: + return losses.mean() + elif reduction == "none": + return losses + else: + raise NotImplementedError + +except ImportError: + + def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): + return _cross_entropy_pytorch(logits, target, ignore_index, reduction) diff --git a/fairseq/modules/cuda_utils.cu b/fairseq/modules/cuda_utils.cu new file mode 100644 index 0000000000000000000000000000000000000000..924f852758ee654e462546274db4b5e7199a9c90 --- /dev/null +++ b/fairseq/modules/cuda_utils.cu @@ -0,0 +1,202 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +template +constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + return (a + b - 1) / b; +} + +template +__inline__ __device__ void zeroSharedMem(scalar_t* data) { + /* + Given an array of length FS + SB, zero out the first padding_l and last + (FS - padding_l) values in the array + */ + + int tid = threadIdx.x; + + if (FS < SB) { + // zero all if we have enough threads in a block to do all of them + if (tid < padding_l || tid > SB - FS + padding_l - 1) { + data[tid] = scalar_t(0.0); + } + } else { + // otherwise zero out one block at a time + const int numIterations = divUp(FS, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if (tid + offset < padding_l) { + data[tid + offset] = scalar_t(0.0); + } else if (tid + offset < FS) { + data[SB + tid + offset] = scalar_t(0.0); + } + } + } +} + +template +__inline__ __device__ scalar_t warpReduce(scalar_t data) { + /* + Reduce an array within each warp. After processing all values in warp will + caontain the sum of all original values in that warp. + + data - pointer to data to reduce + */ + data += __shfl_xor_sync(SHFL_MASK, data, 16); + data += __shfl_xor_sync(SHFL_MASK, data, 8); + data += __shfl_xor_sync(SHFL_MASK, data, 4); + data += __shfl_xor_sync(SHFL_MASK, data, 2); + data += __shfl_xor_sync(SHFL_MASK, data, 1); + return data; +} + +template +__inline__ __device__ scalar_t blockReduce(scalar_t data) { + /* + Reduce an entire array on the block level. After processing, the + first value in the array will contain the reduced sum. + + data - pointer to data to reduce + */ + + static __shared__ scalar_t warpSum[32]; + const int tid = threadIdx.x; + int wid = tid / 32; + int lane = tid % 32; + + __syncthreads(); + + // reduce each warp then write to shared memory + scalar_t sum = warpReduce(data); + if (lane == 0) { + warpSum[wid] = sum; + } + + __syncthreads(); + + scalar_t v; + // perform final sum of partial warp sums + if (tid < blockDim.x / 32) { + v = warpSum[lane]; + } else { + v = scalar_t(0.0); + } + + if (wid == 0) { + v = warpReduce(v); + } + __syncthreads(); + + return v; +} + +void checkCudaStatus(cudaError_t status, int lineNumber = -1) { + if (status != cudaSuccess) { + std::cout << cudaGetErrorString(status) << " at line " << lineNumber + << std::endl; + std::cout << "Exiting" << std::endl; + exit(1); + } +} + +template +__device__ void load_input_to_shared( + const scalar_t* input, // global memory + int inputOffset, + int sequenceLength, + int iteration, + int numIterations, + bool no_prev, + scalar_t* output /* shared memory */) { + /* + Load a block size of input into shared memory with + right and left overhang of total size FS. If previously + loaded memory, overlap will be shifted over to reduce + global memory access + + input - pointer to start of channel sequence + inputOffset - how far in the sequence to start loading + sequenceLength - total length of sequence + iteration - which block of sequence we are loading + numIterations - total number of blocks to load + no_prev - whether to load the whole block if the previous block + wasn't loaded + output - shared memory to write input to + */ + + const int tid = threadIdx.x; + + // Load the left "overhang" of input + if (iteration > 0) { + if (padding_l < SB) { + // load all at once + if (tid < padding_l) { + output[tid] = + (no_prev) ? input[inputOffset - padding_l + tid] : output[tid + SB]; + } + } else { + // load in chunks of size SB + int numIterations = divUp(padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < padding_l) { + output[tid + offset] = (no_prev) + ? input[inputOffset - padding_l + tid + offset] + : output[tid + offset + SB]; + } + } + } + } + + // Load the right "overhang" of input + if (iteration < (numIterations - 1)) { + const int elementsLeft = sequenceLength - (iteration + 1) * SB; + + if ((FS - padding_l) < SB) { + // load all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = (tid < elementsLeft) + ? input[inputOffset + SB + tid] + : scalar_t(0.0); + } + } else { + // load in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = + ((tid + offset) < elementsLeft) + ? input[inputOffset + SB + tid + offset] + : scalar_t(0.0); + } + } + } + } + + // We should also clear out the right "overhang" + if (iteration == (numIterations - 1)) { + if ((FS - padding_l) < SB) { + // clear out all at once + if (tid < (FS - padding_l)) { + output[padding_l + SB + tid] = scalar_t(0.0); + } + } else { + // clear in chunks of size SB + int numIterations = divUp(FS - padding_l, SB); + for (int i = 0; i < numIterations; i++) { + int offset = i * SB; + if ((tid + offset) < (FS - padding_l)) { + output[padding_l + SB + tid + offset] = scalar_t(0.0); + } + } + } + } + output[tid + padding_l] = ((inputOffset + tid) < sequenceLength) + ? input[inputOffset + tid] + : scalar_t(0.0); +} diff --git a/fairseq/modules/downsampled_multihead_attention.py b/fairseq/modules/downsampled_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5e42942a9f6c8a396ff848d181dd8c7885ee97bc --- /dev/null +++ b/fairseq/modules/downsampled_multihead_attention.py @@ -0,0 +1,317 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.scalar_bias import scalar_bias + + +class SingleHeadAttention(nn.Module): + """ + Single-head attention that supports Gating and Downsampling + """ + + def __init__( + self, + out_channels, + embed_dim, + head_dim, + head_index, + dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, + num_heads=1, + ): + super().__init__() + self.embed_dim = embed_dim + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.head_index = head_index + self.head_dim = head_dim + self.project_input = project_input + self.gated = gated + self.downsample = downsample + self.num_heads = num_heads + self.projection = None + + k_layers = [] + v_layers = [] + if self.downsample: + k_layers.append(Downsample(self.head_index)) + v_layers.append(Downsample(self.head_index)) + out_proj_size = self.head_dim + else: + out_proj_size = self.head_dim * self.num_heads + if self.gated: + k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) + self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias) + v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias)) + else: + k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) + self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias) + v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias)) + + self.in_proj_k = nn.Sequential(*k_layers) + self.in_proj_v = nn.Sequential(*v_layers) + + if self.downsample: + self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias) + else: + self.out_proj = Linear(out_proj_size, out_channels, bias=bias) + + self.scaling = self.head_dim**-0.5 + + def forward( + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, + ): + """Input shape: Time x Batch x Channel + Self-attention can be implemented by passing in the same arguments for + query, key and value. Future timesteps can be masked with the + `mask_future_timesteps` argument. Padding elements can be excluded from + the key by passing a binary ByteTensor (`key_padding_mask`) with shape: + batch x src_len, where padding elements are indicated by 1s. + """ + src_len, bsz, out_channels = key.size() + tgt_len = query.size(0) + assert list(query.size()) == [tgt_len, bsz, out_channels] + assert key.size() == value.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.downsample: + size = bsz + else: + size = bsz * self.num_heads + + k = key + v = value + q = query + if self.project_input: + q = self.in_proj_q(q) + k = self.in_proj_k(k) + v = self.in_proj_v(v) + src_len = k.size()[0] + q *= self.scaling + + if not self.downsample: + q = q.view(tgt_len, size, self.head_dim) + k = k.view(src_len, size, self.head_dim) + v = v.view(src_len, size, self.head_dim) + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + if mask_future_timesteps: + assert ( + query.size() == key.size() + ), "mask_future_timesteps only applies to self-attention" + attn_weights *= torch.tril( + attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(), + diagonal=-1, + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) + attn_weights += torch.triu( + attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(), + diagonal=0, + )[:, :: self.head_index + 1 if self.downsample else 1].unsqueeze(0) + tgt_size = tgt_len + if use_scalar_bias: + attn_weights = scalar_bias(attn_weights, 2) + v = scalar_bias(v, 1) + tgt_size += 1 + + if key_padding_mask is not None: + # don't attend to padding symbols + if key_padding_mask.max() > 0: + if self.downsample: + attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len) + else: + attn_weights = attn_weights.view( + size, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -math.inf, + ) + attn_weights = attn_weights.view(size, tgt_len, src_len) + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = self.dropout_module(attn_weights) + + attn = torch.bmm(attn_weights, v) + if self.downsample: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + + attn = self.out_proj(attn) + + return attn, attn_weights + + +class DownsampledMultiHeadAttention(nn.ModuleList): + """ + Multi-headed attention with Gating and Downsampling + """ + + def __init__( + self, + out_channels, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + project_input=True, + gated=False, + downsample=False, + ): + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.downsample = downsample + self.gated = gated + self.project_input = project_input + assert self.head_dim * num_heads == embed_dim + + if self.downsample: + attention_heads = [] + for index in range(self.num_heads): + attention_heads.append( + SingleHeadAttention( + out_channels, + self.embed_dim, + self.head_dim, + index, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, + ) + ) + super().__init__(modules=attention_heads) + self.out_proj = Linear(embed_dim, out_channels, bias=bias) + else: + # either we have a list of attention heads, or just one attention head + # if not being downsampled, we can do the heads with one linear layer instead of separate ones + super().__init__() + self.attention_module = SingleHeadAttention( + out_channels, + self.embed_dim, + self.head_dim, + 1, + dropout, + bias, + self.project_input, + self.gated, + self.downsample, + self.num_heads, + ) + + def forward( + self, + query, + key, + value, + mask_future_timesteps=False, + key_padding_mask=None, + use_scalar_bias=False, + ): + src_len, bsz, embed_dim = key.size() + tgt_len = query.size(0) + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + assert key.size() == value.size() + + tgt_size = tgt_len + if use_scalar_bias: + tgt_size += 1 + + attn = [] + attn_weights = [] + if self.downsample: + for attention_head_number in range(self.num_heads): + # call the forward of each attention head + _attn, _attn_weight = self[attention_head_number]( + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, + ) + attn.append(_attn) + attn_weights.append(_attn_weight) + full_attn = torch.cat(attn, dim=2) + full_attn = self.out_proj(full_attn) + return full_attn, attn_weights[0].clone() + else: + _attn, _attn_weight = self.attention_module( + query, + key, + value, + mask_future_timesteps, + key_padding_mask, + use_scalar_bias, + ) + attn.append(_attn) + attn_weights.append(_attn_weight) + full_attn = torch.cat(attn, dim=2) + full_attn_weights = torch.cat(attn_weights) + full_attn_weights = full_attn_weights.view( + bsz, self.num_heads, tgt_size, src_len + ) + full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads + return full_attn, full_attn_weights + + +class Downsample(nn.Module): + """ + Selects every nth element, where n is the index + """ + + def __init__(self, index): + super().__init__() + self.index = index + + def forward(self, x): + return x[:: self.index + 1] + + +def Linear(in_features, out_features, dropout=0.0, bias=True): + """Weight-normalized Linear layer (input: B x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) + m.bias.data.zero_() + return nn.utils.weight_norm(m) + + +def GatedLinear(in_features, out_features, dropout=0.0, bias=True): + """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units""" + return nn.Sequential( + Linear(in_features, out_features * 4, dropout, bias), + nn.GLU(), + Linear(out_features * 2, out_features * 2, dropout, bias), + nn.GLU(), + Linear(out_features, out_features, dropout, bias), + ) diff --git a/fairseq/modules/dynamic_convolution.py b/fairseq/modules/dynamic_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff02cd62a27dd34232cc7b5f1fccfe57fa13041 --- /dev/null +++ b/fairseq/modules/dynamic_convolution.py @@ -0,0 +1,526 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import ( + FairseqIncrementalState, + with_incremental_state, +) +from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import Tensor + +from .unfold import unfold1d + + +def DynamicConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, +): + if torch.cuda.is_available(): + try: + from fairseq.modules.dynamicconv_layer import DynamicconvLayer + + return DynamicconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + renorm_padding=renorm_padding, + bias=bias, + conv_bias=conv_bias, + query_size=query_size, + ) + except ImportError as e: + print(e) + return DynamicConv1dTBC( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + renorm_padding=renorm_padding, + bias=bias, + conv_bias=conv_bias, + query_size=query_size, + ) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.0) + return m + + +@with_incremental_state +class DynamicConv1dTBC(nn.Module): + """Dynamic lightweight convolution taking T x B x C inputs + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) + bias: use bias + conv_bias: bias of the convolution + query_size: specified when feeding a different input as the query + in_proj: project the input and generate the filter together + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, + ): + super().__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.weight_softmax = weight_softmax + self.renorm_padding = renorm_padding + + if in_proj: + self.weight_linear = Linear( + self.input_size, self.input_size + num_heads * kernel_size * 1 + ) + else: + self.weight_linear = Linear( + self.query_size, num_heads * kernel_size * 1, bias=bias + ) + if conv_bias: + self.conv_bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.conv_bias = None + self.reset_parameters() + + @property + def in_proj(self): + return ( + self.weight_linear.out_features + == self.input_size + self.num_heads * self.kernel_size + ) + + def reset_parameters(self): + self.weight_linear.reset_parameters() + if self.conv_bias is not None: + nn.init.constant_(self.conv_bias, 0.0) + + def forward(self, x, incremental_state=None, query=None, unfold=None): + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + query: use the specified query to predict the conv filters + """ + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None or not self.in_proj + + if query is None: + query = x + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def _forward_unfolded(self, x, incremental_state, query): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) + else: + weight = self.weight_linear(query).view(T * B * H, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = self.weight_dropout_module(weight, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_stat, query): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = ( + proj.narrow(2, self.input_size, H * K).contiguous().view(T * B * H, -1) + ) + else: + weight = self.weight_linear(query).view(T * B * H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = self.weight_dropout_module(weight, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) + else: + P = self.padding_l + # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def extra_repr(self): + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.conv_bias is not None, + self.renorm_padding, + self.in_proj, + ) + + if self.query_size != self.input_size: + s += ", query_size={}".format(self.query_size) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) + return s + + +class DynamicConv_scripatable(nn.Module, FairseqIncrementalState): + """Dynamic lightweight convolution taking T x B x C inputs + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + renorm_padding: re-normalize the filters to ignore the padded part (only the non-padding parts sum up to 1) + bias: use bias + conv_bias: bias of the convolution + query_size: specified when feeding a different input as the query + in_proj: project the input and generate the filter together + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + renorm_padding=False, + bias=False, + conv_bias=False, + query_size=None, + in_proj=False, + ): + super().__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.weight_softmax = weight_softmax + self.renorm_padding = renorm_padding + + if in_proj: + self.weight_linear = Linear( + self.input_size, self.input_size + num_heads * kernel_size * 1 + ) + else: + self.weight_linear = Linear( + self.query_size, num_heads * kernel_size * 1, bias=bias + ) + self.in_proj = ( + self.weight_linear.out_features + == self.input_size + self.num_heads * self.kernel_size + ) + self.has_conv_bias = conv_bias + self.conv_bias = nn.Parameter(torch.Tensor(input_size).view(1, 1, -1)) + self.init_incremental_state() + + self.reset_parameters() + + def reset_parameters(self): + self.weight_linear.reset_parameters() + if self.has_conv_bias: + nn.init.constant_(self.conv_bias, 0.0) + + def forward( + self, + x, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + query: Optional[Tensor] = None, + ): + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + query: use the specified query to predict the conv filters + """ + assert query is None or not self.in_proj + + if query is None: + query = x + + output = self._forward_unfolded(x, incremental_state, query) + + if self.has_conv_bias: + output = output + self.conv_bias + return output + + def _forward_unfolded( + self, + x, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + query, + ): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + TxBxH = T * B * H + + if self.in_proj: + proj = self.weight_linear(x) + x = proj.narrow(2, 0, self.input_size).contiguous() + weight = proj.narrow(2, self.input_size, H * K).contiguous().view(TxBxH, -1) + else: + weight = self.weight_linear(query).view(TxBxH, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + else: + x_unfold = x.unsqueeze(3).clone() + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(TxBxH, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0.0) + x_unfold = x_unfold.view(TxBxH, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -(x_unfold.size(2)) :] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = self.weight_dropout_module(weight, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T x B x H x R x 1 + output = output.view(T, B, C) + return output + + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order: Tensor, + ): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ): + result = self.get_incremental_state(incremental_state, "input_buffer") + if result is not None and "input_buffer" in result: + return result["input_buffer"] + else: + return None + + def _set_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_buffer: Optional[Tensor], + ): + result = self.set_incremental_state( + incremental_state, "input_buffer", {"input_buffer": new_buffer} + ) + if result is not None: + incremental_state = result + return incremental_state + + def extra_repr(self): + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, conv_bias={}, renorm_padding={}, in_proj={}".format( # noqa + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.conv_bias is not None, + self.renorm_padding, + self.in_proj, + ) + + if self.query_size != self.input_size: + s += ", query_size={}".format(self.query_size) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) + return s diff --git a/fairseq/modules/dynamic_crf_layer.py b/fairseq/modules/dynamic_crf_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..8fcc6b8d2672d2eacc6d01b9688bac44d5e1ce26 --- /dev/null +++ b/fairseq/modules/dynamic_crf_layer.py @@ -0,0 +1,189 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This file is to re-implemented the low-rank and beam approximation of CRF layer +Proposed by: + +Sun, Zhiqing, et al. +Fast Structured Decoding for Sequence Models +https://arxiv.org/abs/1910.11555 + +The CRF implementation is mainly borrowed from +https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py + +""" + +import numpy as np +import torch +import torch.nn as nn + + +def logsumexp(x, dim=1): + return torch.logsumexp(x.float(), dim=dim).type_as(x) + + +class DynamicCRF(nn.Module): + """Dynamic CRF layer is used to approximate the traditional + Conditional Random Fields (CRF) + $P(y | x) = 1/Z(x) exp(sum_i s(y_i, x) + sum_i t(y_{i-1}, y_i, x))$ + + where in this function, we assume the emition scores (s) are given, + and the transition score is a |V| x |V| matrix $M$ + + in the following two aspects: + (1) it used a low-rank approximation for the transition matrix: + $M = E_1 E_2^T$ + (2) it used a beam to estimate the normalizing factor Z(x) + """ + + def __init__(self, num_embedding, low_rank=32, beam_size=64): + super().__init__() + + self.E1 = nn.Embedding(num_embedding, low_rank) + self.E2 = nn.Embedding(num_embedding, low_rank) + + self.vocb = num_embedding + self.rank = low_rank + self.beam = beam_size + + def extra_repr(self): + return "vocab_size={}, low_rank={}, beam_size={}".format( + self.vocb, self.rank, self.beam + ) + + def forward(self, emissions, targets, masks, beam=None): + """ + Compute the conditional log-likelihood of a sequence of target tokens given emission scores + + Args: + emissions (`~torch.Tensor`): Emission score are usually the unnormalized decoder output + ``(batch_size, seq_len, vocab_size)``. We assume batch-first + targets (`~torch.LongTensor`): Sequence of target token indices + ``(batch_size, seq_len) + masks (`~torch.ByteTensor`): Mask tensor with the same size as targets + + Returns: + `~torch.Tensor`: approximated log-likelihood + """ + numerator = self._compute_score(emissions, targets, masks) + denominator = self._compute_normalizer(emissions, targets, masks, beam) + return numerator - denominator + + def forward_decoder(self, emissions, masks=None, beam=None): + """ + Find the most likely output sequence using Viterbi algorithm. + + Args: + emissions (`~torch.Tensor`): Emission score are usually the unnormalized decoder output + ``(batch_size, seq_len, vocab_size)``. We assume batch-first + masks (`~torch.ByteTensor`): Mask tensor with the same size as targets + + Returns: + `~torch.LongTensor`: decoded sequence from the CRF model + """ + return self._viterbi_decode(emissions, masks, beam) + + def _compute_score(self, emissions, targets, masks=None): + batch_size, seq_len = targets.size() + emission_scores = emissions.gather(2, targets[:, :, None])[:, :, 0] # B x T + transition_scores = (self.E1(targets[:, :-1]) * self.E2(targets[:, 1:])).sum(2) + + scores = emission_scores + scores[:, 1:] += transition_scores + + if masks is not None: + scores = scores * masks.type_as(scores) + return scores.sum(-1) + + def _compute_normalizer(self, emissions, targets=None, masks=None, beam=None): + # HACK: we include "target" which is a hueristic for training + # HACK: we use a beam of tokens to approximate the normalizing factor (which is bad?) + + beam = beam if beam is not None else self.beam + batch_size, seq_len = emissions.size()[:2] + if targets is not None: + _emissions = emissions.scatter(2, targets[:, :, None], np.float("inf")) + beam_targets = _emissions.topk(beam, 2)[1] + beam_emission_scores = emissions.gather(2, beam_targets) + else: + beam_emission_scores, beam_targets = emissions.topk(beam, 2) + beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_matrix = torch.bmm( + beam_transition_score1.view(-1, beam, self.rank), + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) + beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) + + # compute the normalizer in the log-space + score = beam_emission_scores[:, 0] # B x K + for i in range(1, seq_len): + next_score = score[:, :, None] + beam_transition_matrix[:, i - 1] + next_score = logsumexp(next_score, dim=1) + beam_emission_scores[:, i] + + if masks is not None: + score = torch.where(masks[:, i : i + 1], next_score, score) + else: + score = next_score + + # Sum (log-sum-exp) over all possible tags + return logsumexp(score, dim=1) + + def _viterbi_decode(self, emissions, masks=None, beam=None): + # HACK: we use a beam of tokens to approximate the normalizing factor (which is bad?) + + beam = beam if beam is not None else self.beam + batch_size, seq_len = emissions.size()[:2] + beam_emission_scores, beam_targets = emissions.topk(beam, 2) + beam_transition_score1 = self.E1(beam_targets[:, :-1]) # B x (T-1) x K x D + beam_transition_score2 = self.E2(beam_targets[:, 1:]) # B x (T-1) x K x D + beam_transition_matrix = torch.bmm( + beam_transition_score1.view(-1, beam, self.rank), + beam_transition_score2.view(-1, beam, self.rank).transpose(1, 2), + ) + beam_transition_matrix = beam_transition_matrix.view(batch_size, -1, beam, beam) + + traj_tokens, traj_scores = [], [] + finalized_tokens, finalized_scores = [], [] + + # compute the normalizer in the log-space + score = beam_emission_scores[:, 0] # B x K + dummy = ( + torch.arange(beam, device=score.device).expand(*score.size()).contiguous() + ) + + for i in range(1, seq_len): + traj_scores.append(score) + _score = score[:, :, None] + beam_transition_matrix[:, i - 1] + _score, _index = _score.max(dim=1) + _score = _score + beam_emission_scores[:, i] + + if masks is not None: + score = torch.where(masks[:, i : i + 1], _score, score) + index = torch.where(masks[:, i : i + 1], _index, dummy) + else: + score, index = _score, _index + traj_tokens.append(index) + + # now running the back-tracing and find the best + best_score, best_index = score.max(dim=1) + finalized_tokens.append(best_index[:, None]) + finalized_scores.append(best_score[:, None]) + + for idx, scs in zip(reversed(traj_tokens), reversed(traj_scores)): + previous_index = finalized_tokens[-1] + finalized_tokens.append(idx.gather(1, previous_index)) + finalized_scores.append(scs.gather(1, previous_index)) + + finalized_tokens.reverse() + finalized_tokens = torch.cat(finalized_tokens, 1) + finalized_tokens = beam_targets.gather(2, finalized_tokens[:, :, None])[:, :, 0] + + finalized_scores.reverse() + finalized_scores = torch.cat(finalized_scores, 1) + finalized_scores[:, 1:] = finalized_scores[:, 1:] - finalized_scores[:, :-1] + + return finalized_scores, finalized_tokens diff --git a/fairseq/modules/dynamicconv_layer/__init__.py b/fairseq/modules/dynamicconv_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22dc6f403d2a0ecdb1b9e7e69ed96bd560e93b2c --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .dynamicconv_layer import DynamicconvLayer # noqa diff --git a/fairseq/modules/dynamicconv_layer/cuda_function_gen.py b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..9304f99eb8169a614f39babc830c84cac80e080b --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/cuda_function_gen.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + blocks = [32, 64, 128, 256] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_forward(at::Tensor input, at::Tensor weight, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + switch = """ + switch(filterSize) { +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "dynamicconv_forward", ([&] {{ + dynamicconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break;\n +""" + + end = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } + + return {output}; +} +""" + + with open("dynamicconv_cuda_forward.cu", "w") as forward: + forward.write(head) + forward.write(switch) + for k in kernels: + b_size = 32 + for b in blocks: + if b > k: + b_size = b + break + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=b_size, pad=pad)) + forward.write(bad_padding) + forward.write(end) + + +def gen_backward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + thresh = [512, 512, 512, 512, 512, 380, 256, 256] + min_block = [64, 64, 64, 64, 64, 64, 128, 256] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "dynamicconv_cuda.cuh" + +std::vector dynamicconv_cuda_backward(at::Tensor gradOutput, int padding_l, at::Tensor input, at::Tensor weight) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = weight.size(1); + const auto filterSize = weight.size(2); + + const auto numFiltersInBlock = numFeatures / numHeads; + auto numChunks = 1; + + auto gradInput = at::zeros_like(input); + auto gradWeight = at::zeros_like(weight); + auto stream = at::cuda::getCurrentCUDAStream(); + + dim3 blocks(minibatch, numHeads, numChunks); +""" + + sequence_if = """ + if (sequenceLength < {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + chunks_reset = """ + numChunks = int(ceilf(sequenceLength/float({b_size}))); + blocks = dim3(minibatch, numHeads, numChunks); +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradOutput.scalar_type(), "dynamicconv_backward", ([&] {{ + dynamicconv_backward_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + input.data(), + weight.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + gradWeight.data(), + gradInput.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } + break;\n +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradWeight}; +} +""" + + with open("dynamicconv_cuda_backward.cu", "w") as backward: + backward.write(head) + for seq in seqs: + backward.write(sequence_if.format(seq=seq)) + for k, t, m in zip(kernels, thresh, min_block): + backward.write(case_k.format(k=k)) + if seq <= t: + b_size = seq + else: + b_size = m + backward.write(chunks_reset.format(b_size=b_size)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=b_size, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(con_else) + backward.write(final_else) + for k, m in zip(kernels, min_block): + backward.write(case_k.format(k=k)) + backward.write(chunks_reset.format(b_size=m)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=m, p=p)) + backward.write(bad_padding) + backward.write(bad_filter) + backward.write(last_return) + + +if __name__ == "__main__": + gen_forward() + gen_backward() diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..744c363e550231b8e0fbb94f998d46039daf5c00 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +std::vector +dynamicconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); + +std::vector dynamicconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters); + +#define CHECK_CUDA(x) \ + AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector +dynamicconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return dynamicconv_cuda_forward(input, filters, padding_l); +} + +std::vector dynamicconv_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return dynamicconv_cuda_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CUDA)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CUDA)"); +} diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..44baf21bdd2d4a7a692ae6f7953a413ea6513268 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda.cuh @@ -0,0 +1,50 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define SHFL_MASK 0xffffffff + +template +__global__ void dynamicconv_forward_kernel( + const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output); + +template +__global__ void dynamicconv_backward_kernel( + const scalar_t* gradOutput, // B * C * T + const scalar_t* input, // B * C * T + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* gradWeight, + scalar_t* gradInput); // B * H * k * T diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..4630f1e9826aea61973f2e82feb57ac0a4390735 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_cuda_kernel.cu @@ -0,0 +1,176 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "../cuda_utils.cu" +#include "dynamicconv_cuda.cuh" +#include "dynamicconv_cuda_backward.cu" +#include "dynamicconv_cuda_forward.cu" + +// FS is filter size and kernels are specialized for filter sizes +template +__global__ void dynamicconv_forward_kernel( + const scalar_t* input, + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* output) { + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int head = featureIdx / numFiltersInBlock; + + const int IOOffset = + batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + + scalar_t filter[FS]; + + __shared__ scalar_t tempInput[SB + FS]; + zeroSharedMem(tempInput); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + __syncthreads(); + const int inputOffset = i * SB; + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); + __syncthreads(); + if (inputOffset + tid < sequenceLength) { +#pragma unroll + for (int k = 0; k < FS; ++k) { + const int filterOffset = batchIdx * numHeads * FS * sequenceLength + + head * FS * sequenceLength + k * sequenceLength + i * SB + tid; + filter[k] = weight[filterOffset]; + } + + scalar_t out = scalar_t(0.0); +#pragma unroll + for (int k = 0; k < FS; ++k) { + out += filter[k] * tempInput[tid + k]; + } + + outputFeature[inputOffset + tid] = out; + } + } +} + +template +__global__ void dynamicconv_backward_kernel( + const scalar_t* gradOutput, // B * C * T + const scalar_t* input, // B * C * T + const scalar_t* weight, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + scalar_t* gradWeight, + scalar_t* gradInput) { // B * H * k * T + + assert(blockDim.x == SB); + + // each block operates on a single batch and filter head + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int headIdx = blockIdx.y; + const int chunkIdx = blockIdx.z; + + const int numChunks = divUp(sequenceLength, SB); + const int inputOffset = chunkIdx * SB; + + // initialize shared memory for output gradient and input + __shared__ scalar_t tempGradOutput[SB + FS]; + __shared__ scalar_t tempInput[SB + FS]; + const int padding = FS - padding_l - 1; + + zeroSharedMem(tempGradOutput); + zeroSharedMem(tempInput); + + // initialize local filter and weight gradient sum arrays + scalar_t tempGradSum[FS]; + scalar_t bfilter[FS]; + for (int k = 0; k < FS; ++k) { + tempGradSum[k] = scalar_t(0.0); + + int idxOffset = inputOffset + tid + k - padding; + if (idxOffset >= 0 && idxOffset < sequenceLength) { + int bfilterOffset = batchIdx * numHeads * FS * sequenceLength + + headIdx * FS * sequenceLength + (FS - k - 1) * sequenceLength + + idxOffset; + bfilter[k] = weight[bfilterOffset]; + } else { + bfilter[k] = scalar_t(0.0); + } + } + + // iterate over filter block + for (int featureIdx = 0; featureIdx < numFiltersInBlock; ++featureIdx) { + __syncthreads(); + + // load input and output gradient for this channel and chunk + const int IOOffset = batchIdx * numFeatures * sequenceLength + + (headIdx * numFiltersInBlock + featureIdx) * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + const scalar_t* gradOutputFeature = &gradOutput[IOOffset]; + scalar_t* gradInputFeature = &gradInput[IOOffset]; + + load_input_to_shared( + gradOutputFeature, + inputOffset, + sequenceLength, + chunkIdx, + numChunks, + true, + tempGradOutput); + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + chunkIdx, + numChunks, + true, + tempInput); + __syncthreads(); + + // sum input and weight gradients + scalar_t out = scalar_t(0.0); +#pragma unroll + for (int k = 0; k < FS; ++k) { + tempGradSum[k] += tempInput[tid + k] * tempGradOutput[tid + padding]; + out += bfilter[k] * tempGradOutput[tid + k]; + } + + if (inputOffset + tid < sequenceLength) { + gradInputFeature[inputOffset + tid] = out; + } + } + + const int gradOffset = + batchIdx * numHeads * FS * sequenceLength + headIdx * FS * sequenceLength; + scalar_t* gradWeightFeature = &gradWeight[gradOffset]; + + // write weight gradient + if (inputOffset + tid < sequenceLength) { + for (int k = 0; k < FS; ++k) { + const int outputOffset = k * sequenceLength + inputOffset + tid; + gradWeightFeature[outputOffset] = tempGradSum[k]; + } + } +} diff --git a/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..711ed03483f4089dbe91964a89021b49eeffbedc --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamicconv_layer.py @@ -0,0 +1,227 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dynamicconv_cuda +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d +from torch import nn +from torch.autograd import Function + + +class dynamicconvFunction(Function): + @staticmethod + def forward(ctx, x, weights, padding_l): + ctx.padding_l = padding_l + outputs = dynamicconv_cuda.forward(x, weights, padding_l) + variables = [x, weights] + ctx.save_for_backward(*variables) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output): + outputs = dynamicconv_cuda.backward( + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) + grad_input, grad_weights = outputs + return grad_input, grad_weights, None + + +@with_incremental_state +class DynamicconvLayer(nn.Module): + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, + renorm_padding=False, + conv_bias=False, + query_size=None, + ): + + super(DynamicconvLayer, self).__init__() + self.input_size = input_size + self.query_size = input_size if query_size is None else query_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_softmax = weight_softmax + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.renorm_padding = renorm_padding + self.bias = bias + + self.weight_linear = nn.Linear(input_size, num_heads * kernel_size, bias) + if conv_bias: + self.conv_bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.conv_bias = None + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight_linear.weight) + if self.conv_bias is not None: + nn.init.constant_(self.conv_bias, 0.0) + nn.init.constant_(self.weight_linaer.bias, 0.0) + + def forward(self, x, incremental_state=None, query=None, unfold=None): + + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + # R = C // H + + # during inference time, incremental BMM is faster + if incremental_state is not None: + unfold = ( + x.size(0) > 512 if unfold is None else unfold + ) # use unfold mode as default for long sequence to save memory + unfold = unfold or (incremental_state is not None) + assert query is None + + if query is None: + query = x + if unfold: + output = self._forward_unfolded(x, incremental_state, query) + else: + output = self._forward_expanded(x, incremental_state, query) + + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + + return output + + # during training time, use CUDA kernel + else: + weight = self.weight_linear(x).view(T, B, H, K) + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + if self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) + + weight = weight.permute(1, 2, 3, 0).contiguous() + self.filters = weight + x = x.permute(1, 2, 0).contiguous() + output = dynamicconvFunction.apply(x, weight, self.padding_l).permute( + 2, 0, 1 + ) + if self.conv_bias is not None: + output = output + self.conv_bias.view(1, 1, -1) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def _forward_unfolded(self, x, incremental_state, query): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight_linear(query).view(T * B * H, -1) + + # renorm_padding is only implemented in _forward_expanded + assert not self.renorm_padding or incremental_state is not None + + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + padding_l = self.padding_l + if K > T and padding_l == K - 1: + weight = weight.narrow(1, K - T, T) + K, padding_l = T, T - 1 + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, K, padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax and not self.renorm_padding: + weight = F.softmax(weight, dim=1) + weight = weight.narrow(1, 0, K) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + if self.weight_softmax and self.renorm_padding: + weight = F.softmax(weight, dim=1) + + weight = self.weight_dropout_module(weight, inplace=False) + + output = torch.bmm(x_unfold, weight.unsqueeze(2)) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_stat, query): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + weight = self.weight_linear(query).view(T * B * H, -1) + + if not self.renorm_padding: + if self.weight_softmax: + weight = F.softmax(weight, dim=1) + weight = self.weight_dropout_module(weight, inplace=False) + weight = weight.narrow(1, 0, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + if self.weight_softmax and self.renorm_padding: + # turn the convolution filters into band matrices + weight_expanded = weight.new(B * H, T, T + K - 1).fill_(float("-inf")) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, self.padding_l, T) + # normalize the weight over valid positions like self-attention + weight_expanded = F.softmax(weight_expanded, dim=2) + weight_expanded = self.weight_dropout_module(weight_expanded, inplace=False) + else: + P = self.padding_l + # For efficiency, we cut the kernel size and reduce the padding when the kernel is larger than the length + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided( + (B * H, T, K), (T * (T + K - 1), T + K, 1) + ).copy_(weight) + weight_expanded = weight_expanded.narrow(2, P, T) # B*H x T x T + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output diff --git a/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7e57c859085f98ec10960330ca763ae2764585a --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/dynamiconv_cpu.cpp @@ -0,0 +1,29 @@ +#include +#include + +std::vector +dynamicconv_cpu_forward(float* input, float* filters, int padding_l); + +std::vector dynamicconv_cpu_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters); + +std::vector +dynamicconv_forward(float* input, float* filters, int padding_l) { + return dynamicconv_cpu_forward(input, filters, padding_l); +} + +std::vector dynamicconv_backward( + float* gradOutput, + int padding_l, + float* input, + float* filters) { + return dynamicconv_cpu_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &dynamicconv_forward, "dynamicconv forward (CPU)"); + m.def("backward", &dynamicconv_backward, "dynamicconv backward (CPU)"); +} diff --git a/fairseq/modules/dynamicconv_layer/setup.py b/fairseq/modules/dynamicconv_layer/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6a21f7e2ee0840a3b251522275a0b32a856951d7 --- /dev/null +++ b/fairseq/modules/dynamicconv_layer/setup.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name="dynamicconv_layer", + ext_modules=[ + CUDAExtension( + name="dynamicconv_cuda", + sources=[ + "dynamicconv_cuda.cpp", + "dynamicconv_cuda_kernel.cu", + ], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/ema_module.py b/fairseq/modules/ema_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ece842d435597cba5ed287de1961e0929639e0 --- /dev/null +++ b/fairseq/modules/ema_module.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 + +""" +Used for EMA tracking a given pytorch module. The user is responsible for calling step() +and setting the appropriate decay +""" + +import copy +from dataclasses import dataclass, field +import logging + +import torch + +from omegaconf import II +from fairseq.dataclass import FairseqDataclass + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +logger = logging.getLogger(__name__) + + +@dataclass +class EMAModuleConfig(FairseqDataclass): + ema_decay: float = field( + default=0.9999, metadata={"help": "decay for exponential moving average model"} + ) + ema_fp32: bool = field( + default=False, + metadata={"help": "If true, store EMA model in fp32 even if model is in fp16"}, + ) + add_missing_params: bool = True + log_norms: bool = False + + +class EMAModule: + """Exponential Moving Average of Fairseq Models""" + + def __init__( + self, + model, + config: EMAModuleConfig, + copy_model=True, + device=None, + skip_keys=None, + ): + """ + @param model model to initialize the EMA with + @param config EMAConfig object with configuration like + ema_decay, ema_update_freq, ema_fp32 + @param device If provided, copy EMA to this device (e.g. gpu). + Otherwise EMA is in the same device as the model. + """ + + self.config = config + + if copy_model: + self.model = copy.deepcopy(model) + self.model.requires_grad_(False) + else: + self.model = model + + self.config = config + self.decay = config.ema_decay + self.skip_keys = skip_keys or set() + self.add_missing_params = config.add_missing_params + self.fp32_params = {} + + if device is not None: + logging.info(f"Copying EMA model to device {device}") + self.model = self.model.to(device=device) + + if self.config.ema_fp32: + self.build_fp32_params() + + self.log_norms = config.log_norms and multi_tensor_l2norm_available + self.logs = {} + + def build_fp32_params(self, state_dict=None): + """ + Store a copy of the EMA params in fp32. + If state dict is passed, the EMA params is copied from + the provided state dict. Otherwise, it is copied from the + current EMA model parameters. + """ + if not self.config.ema_fp32: + raise RuntimeError( + "build_fp32_params should not be called if ema_fp32=False. " + "Use ema_fp32=True if this is really intended." + ) + + if state_dict is None: + state_dict = self.model.state_dict() + + def _to_float(t): + return t.float() if torch.is_floating_point(t) else t + + for param_key in state_dict: + if param_key in self.fp32_params: + if param_key == "__sq_mom": + self.fp32_params[param_key] = state_dict[param_key] + else: + self.fp32_params[param_key].copy_(state_dict[param_key]) + else: + self.fp32_params[param_key] = _to_float(state_dict[param_key]) + if "__sq_mom" in self.fp32_params: + self.fp32_params["__sq_mom"][param_key] = torch.zeros_like( + self.fp32_params[param_key] + ) + + def restore(self, state_dict, build_fp32_params=False): + """Load data from a model spec into EMA model""" + self.model.load_state_dict(state_dict, strict=False) + if build_fp32_params: + self.build_fp32_params(state_dict) + + def set_decay(self, decay, weight_decay=None): + self.decay = decay + if weight_decay is not None: + self.weight_decay = weight_decay + + def get_decay(self): + return self.decay + + def _step_internal(self, new_model): + """One update of the EMA model based on new model weights""" + decay = self.decay + + ema_state_dict = {} + ema_params = ( + self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + ) + + new_p = [] + ema_p = [] + + for key, param in new_model.named_parameters(): + if isinstance(param, dict): + continue + + if not self.add_missing_params and key not in ema_params: + continue + + try: + ema_param = ema_params[key] + except KeyError: + ema_param = ( + param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ) + ema_params[key] = ema_param + + if param.shape != ema_param.shape: + raise ValueError( + "incompatible tensor shapes between model param and ema param" + + "{} vs. {}".format(param.shape, ema_param.shape) + ) + + if "version" in key: + # Do not decay a model.version pytorch param + continue + + lr = 1 - decay + + if key in self.skip_keys or not param.requires_grad: + ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) + ema_param = ema_params[key] + else: + if self.log_norms: + new_p.append(param) + ema_p.append(ema_param) + + ema_param.mul_(1 - lr) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=lr) + + ema_state_dict[key] = ema_param + + for key, param in new_model.named_buffers(): + ema_state_dict[key] = param + + if self.log_norms: + if "model_norm" in self.logs: + self.prev_model_norm = self.logs["model_norm"] + + chunk_size = 2048 * 32 + has_inf = torch.zeros( + (1, 1), dtype=torch.int, device=next(new_model.parameters()).device + ) + + new_norm = multi_tensor_l2norm(chunk_size, has_inf, [new_p], False) + old_norm = multi_tensor_l2norm(chunk_size, has_inf, [ema_p], False) + + self.logs["model_norm"] = new_norm[0] + self.logs["ema_norm"] = old_norm[0] + + self.restore(ema_state_dict, build_fp32_params=False) + + @torch.no_grad() + def step(self, new_model): + self._step_internal(new_model) + + def reverse(self, model): + """ + Load the model parameters from EMA model. + Useful for inference or fine-tuning from the EMA model. + """ + d = self.model.state_dict() + if "_ema" in d: + del d["_ema"] + + model.load_state_dict(d, strict=False) + return model diff --git a/fairseq/modules/espnet_multihead_attention.py b/fairseq/modules/espnet_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..82bc0d7b452b47b3135548f89c098272c8699c0a --- /dev/null +++ b/fairseq/modules/espnet_multihead_attention.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import torch +from torch import nn + +from fairseq.modules.rotary_positional_embedding import ( + RotaryPositionalEmbedding, + apply_rotary_pos_emb, +) + + +class ESPNETMultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + """ + + def __init__(self, n_feat, n_head, dropout): + """Construct an MultiHeadedAttention object.""" + super(ESPNETMultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward_qkv(self, query, key, value, **kwargs): + """Transform query, key and value. + Args: + query: Query tensor B X T1 X C + key: Key tensor B X T2 X C + value: Value tensor B X T2 X C + Returns: + torch.Tensor: Transformed query tensor B X n_head X T1 X d_k + torch.Tensor: Transformed key tensor B X n_head X T2 X d_k + torch.Tensor: Transformed value tensor B X n_head X T2 X d_k + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + Args: + value: Transformed value B X n_head X T2 X d_k. + scores: Attention score B X n_head X T1 X T2 + mask: Mask T2 X B + Returns: + torch.Tensor: Transformed value B X T1 X d_model + weighted by the attention score B X T1 X T2 + """ + n_batch = value.size(0) + if mask is not None: + scores = scores.masked_fill( + mask.unsqueeze(1).unsqueeze(2).to(bool), + float("-inf"), # (batch, head, time1, time2) + ) + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query (torch.Tensor): Query tensor T X B X C + key (torch.Tensor): Key tensor T X B X C + value (torch.Tensor): Value tensor T X B X C + mask (torch.Tensor): Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head: The number of heads. + n_feat: The number of features. + dropout: Dropout rate. + zero_triu: Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_feat, n_head, dropout, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + Args: + x: Input tensor B X n_head X T X 2T-1 + Returns: + torch.Tensor: Output tensor. + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs): + """Compute scaled dot product attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + pos_emb: Positional embedding tensor B X 2T-1 X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X C. + """ + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + pos_emb = pos_emb.transpose(0, 1) + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None + + +class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention): + def __init__( + self, + n_feat, + n_head, + dropout, + precision, + rotary_emd_base=10000, + ): + """Construct an RotaryPositionMultiHeadedAttention object.""" + super().__init__(n_feat, n_head, dropout) + precision = torch.float + self.rotary_ndims = self.d_k # also try self.d_k//2 + if precision == "fp16": + precision = torch.half + + self.rotary_emb = RotaryPositionalEmbedding( + self.rotary_ndims, base=rotary_emd_base, precision=precision + ) + + def forward(self, query, key, value, key_padding_mask=None, **kwargs): + """Compute rotary position attention. + Args: + query: Query tensor T X B X C + key: Key tensor T X B X C + value: Value tensor T X B X C + key_padding_mask: Mask tensor T X B + Returns: + torch.Tensor: Output tensor T X B X D. + Notes: + Assumes self attn + """ + + T, B, C = value.size() + query = query.view(T, B, self.h, self.d_k) + key = key.view(T, B, self.h, self.d_k) + value = value.view(T, B, self.h, self.d_k) + cos, sin = self.rotary_emb(value, seq_len=T) + query, key = apply_rotary_pos_emb( + query, key, cos, sin, offset=0 + ) # offset is based on layer_past + + query = query.view(T, B, self.h * self.d_k) + key = key.view(T, B, self.h * self.d_k) + value = value.view(T, B, self.h * self.d_k) + + # TBD to BTD + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + scores = self.forward_attention(v, scores, key_padding_mask) + scores = scores.transpose(0, 1) + return scores, None diff --git a/fairseq/modules/fairseq_dropout.py b/fairseq/modules/fairseq_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..3cddca77186f5ddd5cfb9c0ed6def9bafdf3bf1e --- /dev/null +++ b/fairseq/modules/fairseq_dropout.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import List, Optional + +import torch.nn as nn +import torch.nn.functional as F + + +logger = logging.getLogger(__name__) + + +class FairseqDropout(nn.Module): + def __init__(self, p, module_name=None): + super().__init__() + self.p = p + self.module_name = module_name + self.apply_during_inference = False + + def forward(self, x, inplace: bool = False): + if self.p > 0 and (self.training or self.apply_during_inference): + return F.dropout(x, p=self.p, training=True, inplace=inplace) + else: + return x + + def make_generation_fast_( + self, + name: str, + retain_dropout: bool = False, + retain_dropout_modules: Optional[List[str]] = None, + **kwargs + ): + if retain_dropout: + if retain_dropout_modules is not None and self.module_name is None: + logger.warning( + "Cannot enable dropout during inference for module {} " + "because module_name was not set".format(name) + ) + elif ( + retain_dropout_modules is None # if None, apply to all modules + or self.module_name in retain_dropout_modules + ): + logger.info( + "Enabling dropout during inference for module: {}".format(name) + ) + self.apply_during_inference = True + else: + logger.info("Disabling dropout for module: {}".format(name)) diff --git a/fairseq/modules/fp32_batch_norm.py b/fairseq/modules/fp32_batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..c560f338fdb79870c49aa70ec3f65fc9d6184989 --- /dev/null +++ b/fairseq/modules/fp32_batch_norm.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +batch norm done in fp32 (for fp16 training) +""" +import torch +import torch.nn as nn + + +class Fp32BatchNorm(nn.Module): + def __init__(self, sync=False, *args, **kwargs): + super().__init__() + + if sync: + from fairseq.distributed import utils + + if utils.get_global_world_size() == 1: + sync = False + + if sync: + self.bn = nn.SyncBatchNorm(*args, **kwargs) + else: + self.bn = nn.BatchNorm1d(*args, **kwargs) + + self.sync = sync + + def forward(self, input): + if self.bn.running_mean.dtype != torch.float: + if self.sync: + self.bn.running_mean = self.bn.running_mean.float() + self.bn.running_var = self.bn.running_var.float() + if self.bn.affine: + try: + self.bn.weight = self.bn.weight.float() + self.bn.bias = self.bn.bias.float() + except: + self.bn.float() + else: + self.bn.float() + + output = self.bn(input.float()) + return output.type_as(input) diff --git a/fairseq/modules/fp32_group_norm.py b/fairseq/modules/fp32_group_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d03aac022e30c8c14a600062d1d86429504ba003 --- /dev/null +++ b/fairseq/modules/fp32_group_norm.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Layer norm done in fp32 (for fp16 training) +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/fairseq/modules/fp32_instance_norm.py b/fairseq/modules/fp32_instance_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..30a54496ded633ee005cf11cbdbd2cc2ebb29ca0 --- /dev/null +++ b/fairseq/modules/fp32_instance_norm.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Layer norm done in fp32 (for fp16 training) +""" + +import torch.nn as nn +import torch.nn.functional as F + + +class Fp32InstanceNorm(nn.InstanceNorm1d): + def __init__(self, *args, **kwargs): + self.transpose_last = "transpose_last" in kwargs and kwargs["transpose_last"] + if "transpose_last" in kwargs: + del kwargs["transpose_last"] + super().__init__(*args, **kwargs) + + def forward(self, input): + if self.transpose_last: + input = input.transpose(1, 2) + output = F.instance_norm( + input.float(), + running_mean=self.running_mean, + running_var=self.running_var, + weight=self.weight.float() if self.weight is not None else None, + bias=self.bias.float() if self.bias is not None else None, + use_input_stats=self.training or not self.track_running_stats, + momentum=self.momentum, + eps=self.eps, + ) + if self.transpose_last: + output = output.transpose(1, 2) + return output.type_as(input) diff --git a/fairseq/modules/gelu.py b/fairseq/modules/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f1ecff4a3ae3de3eb7d327b9163c46b18a15ed --- /dev/null +++ b/fairseq/modules/gelu.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +See "Gaussian Error Linear Units (GELUs)" by Dan Hendrycks and Kevin Gimpel with +the corresponding GitHub repo: https://github.com/hendrycks/GELUs +""" + +import math + +import torch +import torch.nn as nn + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) diff --git a/fairseq/modules/grad_multiply.py b/fairseq/modules/grad_multiply.py new file mode 100644 index 0000000000000000000000000000000000000000..08d15f55dfda9c61a1cf8641ea31424fe1d97f57 --- /dev/null +++ b/fairseq/modules/grad_multiply.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..867b019f676d72a51db8f8ea54e08fab2b535bfc --- /dev/null +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -0,0 +1,212 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GumbelVectorQuantizer(nn.Module): + def __init__( + self, + dim, + num_vars, + temp, + groups, + combine_groups, + vq_dim, + time_first, + activation=nn.GELU(), + weight_proj_depth=1, + weight_proj_factor=1, + hard=True, + std=0, + ): + """Vector quantization using gumbel softmax + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor) + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1 + weight_proj_depth: number of layers (with activation in between) to project input before computing logits + weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of + projections by this factor + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.time_first = time_first + self.hard = hard + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim)) + if std == 0: + nn.init.uniform_(self.vars) + else: + nn.init.normal_(self.vars, mean=0, std=std) + + if weight_proj_depth > 1: + + def block(input_dim, output_dim): + return nn.Sequential(nn.Linear(input_dim, output_dim), activation) + + inner_dim = self.input_dim * weight_proj_factor + self.weight_proj = nn.Sequential( + *[ + block(self.input_dim if i == 0 else inner_dim, inner_dim) + for i in range(weight_proj_depth - 1) + ], + nn.Linear(inner_dim, groups * num_vars), + ) + else: + self.weight_proj = nn.Linear(self.input_dim, groups * num_vars) + nn.init.normal_(self.weight_proj.weight, mean=0, std=1) + nn.init.zeros_(self.weight_proj.bias) + + if isinstance(temp, str): + import ast + + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" + + self.max_temp, self.min_temp, self.temp_decay = temp + self.curr_temp = self.max_temp + self.codebook_indices = None + + def set_num_updates(self, num_updates): + self.curr_temp = max( + self.max_temp * self.temp_decay**num_updates, self.min_temp + ) + + def get_codebook_indices(self): + if self.codebook_indices is None: + from itertools import product + + p = [range(self.num_vars)] * self.groups + inds = list(product(*p)) + self.codebook_indices = torch.tensor( + inds, dtype=torch.long, device=self.vars.device + ).flatten() + + if not self.combine_groups: + self.codebook_indices = self.codebook_indices.view( + self.num_vars**self.groups, -1 + ) + for b in range(1, self.groups): + self.codebook_indices[:, b] += self.num_vars * b + self.codebook_indices = self.codebook_indices.flatten() + return self.codebook_indices + + def codebook(self): + indices = self.get_codebook_indices() + return ( + self.vars.squeeze(0) + .index_select(0, indices) + .view(self.num_vars**self.groups, -1) + ) + + def sample_from_codebook(self, b, n): + indices = self.get_codebook_indices() + indices = indices.view(-1, self.groups) + cb_size = indices.size(0) + assert ( + n < cb_size + ), f"sample size {n} is greater than size of codebook {cb_size}" + sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,)) + indices = indices[sample_idx] + + z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1) + return z + + def to_codebook_index(self, indices): + res = indices.new_full(indices.shape[:-1], 0) + for i in range(self.groups): + exponent = self.groups - i - 1 + res += indices[..., i] * (self.num_vars**exponent) + return res + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars * self.groups} + + if not self.time_first: + x = x.transpose(1, 2) + + bsz, tsz, fsz = x.shape + x = x.reshape(-1, fsz) + x = self.weight_proj(x) + x = x.view(bsz * tsz * self.groups, -1) + + with torch.no_grad(): + _, k = x.max(-1) + hard_x = ( + x.new_zeros(*x.shape) + .scatter_(-1, k.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + avg_probs = torch.softmax( + x.view(bsz * tsz, self.groups, -1).float(), dim=-1 + ).mean(dim=0) + result["prob_perplexity"] = torch.exp( + -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) + ).sum() + + result["temp"] = self.curr_temp + + if self.training: + x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=self.hard).type_as( + x + ) + else: + x = hard_x + + x = x.view(bsz * tsz, -1) + + vars = self.vars + if self.combine_groups: + vars = vars.repeat(1, self.groups, 1) + + if produce_targets: + result["targets"] = ( + x.view(bsz * tsz * self.groups, -1) + .argmax(dim=-1) + .view(bsz, tsz, self.groups) + .detach() + ) + + x = x.unsqueeze(-1) * vars + x = x.view(bsz * tsz, self.groups, self.num_vars, -1) + x = x.sum(-2) + x = x.view(bsz, tsz, -1) + + if not self.time_first: + x = x.transpose(1, 2) # BTC -> BCT + + result["x"] = x + + return result diff --git a/fairseq/modules/kmeans_attention.py b/fairseq/modules/kmeans_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0088d1ebdca0331cc51fa785bab9eefce939b782 --- /dev/null +++ b/fairseq/modules/kmeans_attention.py @@ -0,0 +1,744 @@ +import math +from functools import reduce, wraps +from inspect import isfunction +from operator import mul + +import torch +import torch.nn as nn +import torch.nn.functional as F +from aml.multimodal_video.utils.einops.lib import rearrange, repeat +from aml.multimodal_video.utils.einops.lib.layers.torch import Rearrange + +from fairseq.modules.local_attention import LocalAttention + +# constants + +TOKEN_SELF_ATTN_VALUE = -5e4 +KMEAN_INIT_ITERS = 10 + +# helper functions + + +def exists(val): + return val is not None + + +def identity(x, *args, **kwargs): + return x + + +def default(x, d): + if not exists(x): + return d if not isfunction(d) else d() + return x + + +def cast_tuple(x): + return x if isinstance(x, tuple) else (x,) + + +def cache_fn(f): + cache = None + + @wraps(f) + def cached_fn(*args, **kwargs): + nonlocal cache + if exists(cache): + return cache + cache = f(*args, **kwargs) + return cache + + return cached_fn + + +def to(t): + return {"device": t.device, "dtype": t.dtype} + + +def find_modules(nn_module, type): + return [module for module in nn_module.modules() if isinstance(module, type)] + + +def is_empty(t): + return t.nelement() == 0 + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def batched_index_select(values, indices): + last_dim = values.shape[-1] + return values.gather(2, expand_dim(indices, -1, last_dim)) + + +def merge_dims(ind_from, ind_to, tensor): + shape = list(tensor.shape) + arr_slice = slice(ind_from, ind_to + 1) + shape[arr_slice] = [reduce(mul, shape[arr_slice])] + return tensor.reshape(*shape) + + +def expand_dim(t, dim, k): + t = t.unsqueeze(dim) + expand_shape = [-1] * len(t.shape) + expand_shape[dim] = k + return t.expand(*expand_shape) + + +def scatter_mean(src, t, index, dim, eps=1e-5): + numer = src.scatter_add(dim, index, t) + denom = src.scatter_add(dim, index, torch.ones_like(t)) + return numer / (denom + eps) + + +def split_at_index(dim, index, t): + pre_slices = (slice(None),) * dim + l = (*pre_slices, slice(None, index)) + r = (*pre_slices, slice(index, None)) + return t[l], t[r] + + +def reshape_dim(t, dim, split_dims): + shape = list(t.shape) + num_dims = len(shape) + dim = (dim + num_dims) % num_dims + shape[dim : dim + 1] = split_dims + return t.reshape(shape) + + +def ema(old, new, decay): + if not exists(old): + return new + return old * decay + new * (1 - decay) + + +def ema_inplace(moving_avg, new, decay): + if is_empty(moving_avg): + moving_avg.data.copy_(new) + return + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +# helper classes + + +def map_first_tuple_or_el(x, fn): + if isinstance(x, tuple): + return (fn(x[0]),) + x[1:] + return fn(x) + + +class Chunk(nn.Module): + def __init__(self, chunks, fn, along_dim=-1): + super().__init__() + self.dim = along_dim + self.chunks = chunks + self.fn = fn + + def forward(self, x, **kwargs): + if self.chunks <= 1: + return self.fn(x, **kwargs) + chunks = x.chunk(self.chunks, dim=self.dim) + return torch.cat([self.fn(c, **kwargs) for c in chunks], dim=self.dim) + + +class PreNorm(nn.ModuleList): + def __init__(self, norm_class, dim, fn): + super().__init__() + self.norm = norm_class(dim) + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class ReZero(nn.Module): + def __init__(self, fn): + super().__init__() + self.residual_weight = nn.Parameter(torch.zeros(1)) + self.fn = fn + + def forward(self, x, **kwargs): + x = self.fn(x, **kwargs) + return map_first_tuple_or_el(x, lambda t: t * self.residual_weight) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.g = nn.Parameter(torch.ones(1)) + self.eps = eps + + def forward(self, x): + def norm(t): + n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps) + return t / n * self.g + + return map_first_tuple_or_el(x, norm) + + +class ProjectInOut(nn.Module): + def __init__(self, fn, dim_in, dim_out, project_out=True): + super().__init__() + self.fn = fn + self.project_in = nn.Linear(dim_in, dim_out) + self.project_out = nn.Linear(dim_out, dim_in) if project_out else identity + + def forward(self, x, **kwargs): + x = self.project_in(x) + x, loss = self.fn(x, **kwargs) + x = self.project_out(x) + return x, loss + + +class MatrixMultiply(nn.Module): + def __init__(self, tensor, transpose=False): + super().__init__() + self.tensor = tensor + self.transpose = transpose + + def forward(self, x): + tensor = self.tensor + if self.transpose: + tensor = tensor.t() + return x @ tensor + + +# positional embeddings + + +class DepthWiseConv1d(nn.Module): + def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False): + super().__init__() + self.padding = ( + ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2) + ) + + self.net = nn.Sequential( + nn.Conv1d( + dim_in, + dim_in, + kernel_size=kernel_size, + groups=dim_in, + stride=stride, + bias=bias, + ), + nn.Conv1d(dim_in, dim_out, 1, bias=bias), + ) + + def forward(self, x): + x = F.pad(x, self.padding, value=0.0) + return self.net(x) + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + position = torch.arange(0, max_seq_len, dtype=torch.float) + sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + self.register_buffer("emb", emb) + + def forward(self, x): + return self.emb[None, : x.shape[1], :].to(x) + + +def rotate_every_two(x): + x = rearrange(x, "... (d j) -> ... d j", j=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d j -> ... (d j)") + + +def apply_rotary_pos_emb(q, k, sinu_pos): + sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2) + sin, cos = sinu_pos.unbind(dim=-2) + sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos)) + q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) + return q, k + + +# kmeans related function and class + + +def update_kmeans_on_backwards(module): + module.kmean_modules = find_modules(module, Kmeans) + + def hook(_, grad_in, grad_out): + for m in module.kmean_modules: + m.update() + + return module.register_backward_hook(hook) + + +def similarity(x, means): + return torch.einsum("bhld,hcd->bhlc", x, means) + + +def dists_and_buckets(x, means): + dists = similarity(x, means) + _, buckets = torch.max(dists, dim=-1) + return dists, buckets + + +def batched_bincount(index, num_classes, dim=-1): + shape = list(index.shape) + shape[dim] = num_classes + out = index.new_zeros(shape) + out.scatter_add_(dim, index, torch.ones_like(index, dtype=index.dtype)) + return out + + +def kmeans_iter(x, means, buckets=None): + b, h, _, d, dtype, num_clusters = *x.shape, x.dtype, means.shape[1] + + if not exists(buckets): + _, buckets = dists_and_buckets(x, means) + + bins = batched_bincount(buckets, num_clusters).sum(0, keepdim=True) + zero_mask = bins.long() == 0 + + means_ = buckets.new_zeros(b, h, num_clusters, d, dtype=dtype) + means_.scatter_add_(-2, expand_dim(buckets, -1, d), x) + means_ = F.normalize(means_.sum(0, keepdim=True), dim=-1).type(dtype) + + means = torch.where(zero_mask.unsqueeze(-1), means, means_) + means = means.squeeze(0) + return means + + +def distribution(dists, window_size): + _, topk_indices = dists.topk(k=window_size, dim=-2) + indices = topk_indices.transpose(-2, -1) + return indices.reshape(*indices.size()[:2], -1) + + +class Kmeans(nn.Module): + def __init__( + self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4 + ): + super().__init__() + self.commitment = commitment + self.ema_decay = ema_decay + + self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim)) + self.register_buffer("initted", torch.tensor(False)) + self.num_new_means = 0 + self.new_means = None + + @torch.no_grad() + def init(self, x): + if self.initted: + return + _, h, _, d, device, _ = *x.shape, x.device, x.dtype + + num_clusters = self.means.shape[1] + + means = x.transpose(0, 1).contiguous().view(h, -1, d) + num_samples = means.shape[1] + + if num_samples >= num_clusters: + indices = torch.randperm(num_samples, device=device)[:num_clusters] + else: + indices = torch.randint(0, num_samples, (num_clusters,), device=device) + + means = means[:, indices] + + for _ in range(KMEAN_INIT_ITERS): + means = kmeans_iter(x, means) + + self.num_new_means = 0 + self.means.data.copy_(means) + self.initted.data.copy_(torch.tensor(True)) + + @torch.no_grad() + def update(self, new_means=None): + new_means = default(new_means, self.new_means) + assert exists(new_means), "new kmeans has not been supplied" + ema_inplace(self.means, new_means, self.ema_decay) + + del self.new_means + self.new_means = None + self.num_new_means = 0 + + def forward(self, x, update_means=False): + self.init(x) + + b, dtype = x.shape[0], x.dtype + means = self.means.type(dtype) + x = F.normalize(x, 2, dim=-1).type(dtype) + + with torch.no_grad(): + dists, buckets = dists_and_buckets(x, means) + + routed_means = batched_index_select(expand_dim(means, 0, b), buckets) + loss = F.mse_loss(x, routed_means) * self.commitment + + if update_means: + with torch.no_grad(): + means = kmeans_iter(x, means, buckets) + self.new_means = ema( + self.new_means, means, self.num_new_means / (self.num_new_means + 1) + ) + self.num_new_means += 1 + + return dists, loss + + +# kmeans attention class + + +class KmeansAttention(nn.Module): + def __init__( + self, + num_clusters, + window_size, + num_heads, + head_dim, + causal=False, + dropout=0.0, + ema_decay=0.999, + commitment=1e-4, + context_window_size=None, + receives_context=False, + num_mem_kv=0, + shared_qk=False, + ): + super().__init__() + self.num_heads = num_heads + self.num_clusters = num_clusters + self.head_dim = head_dim + + self.window_size = window_size + self.context_window_size = default(context_window_size, window_size) + self.causal = causal + + self.shared_qk = shared_qk + self.receives_context = receives_context + self.kmeans = Kmeans(num_heads, head_dim, num_clusters, ema_decay, commitment) + self.dropout = nn.Dropout(dropout) + + self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0) + self.mem_key = nn.Parameter( + torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim) + ) + self.mem_value = nn.Parameter( + torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim) + ) + + def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs): + b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = ( + *q.shape, + k.shape[2], + self.window_size, + self.context_window_size, + self.num_clusters, + q.device, + q.dtype, + ) + is_reverse = kwargs.pop("_reverse", False) + + out = torch.zeros_like(q, dtype=dtype) + + update_kmeans = self.training and not is_reverse + + key_mask = ( + default(key_mask, query_mask) if not self.receives_context else key_mask + ) + kv_wsz = wsz if not self.receives_context else c_wsz + + wsz = min(wsz, t) + kv_wsz = min(kv_wsz, kv_t) + + if not self.shared_qk or self.receives_context: + dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans) + q_dists, k_dists = split_at_index(2, t, dists) + indices = distribution(q_dists, wsz) + kv_indices = distribution(k_dists, kv_wsz) + else: + dists, aux_loss = self.kmeans(q, update_kmeans) + k = F.normalize(k, dim=-1).to(q) + indices = distribution(dists, wsz) + kv_indices = indices + + q = batched_index_select(q, indices) + k = batched_index_select(k, kv_indices) + v = batched_index_select(v, kv_indices) + + reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d) + q, k, v = map(reshape_with_window, (q, k, v)) + + m_k, m_v = map( + lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value) + ) + k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v))) + + dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d**-0.5) + + mask_value = max_neg_value(dots) + + if exists(query_mask) or exists(key_mask): + query_mask = default( + query_mask, lambda: torch.ones((b, t), device=device).bool() + ) + key_mask = default( + key_mask, lambda: torch.ones((b, kv_t), device=device).bool() + ) + + q_mask = expand_dim(query_mask, 1, h).gather(2, indices) + kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices) + q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (q_mask, kv_mask)) + mask = q_mask[:, :, :, :, None] * kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=1) + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + q_mask, kv_mask = map( + lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices) + ) + mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=1) + dots.masked_fill_(~mask, mask_value) + del mask + + if self.shared_qk: + q_mask, kv_mask = map( + lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices) + ) + mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :] + mask = F.pad(mask, (self.num_mem_kv, 0), value=0) + dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) + del mask + + dots = dots.softmax(dim=-1) + dots = self.dropout(dots) + + bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v) + so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype) + out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2) + return out, aux_loss + + +# feedforward + + +class GELU_(nn.Module): + def forward(self, x): + return ( + 0.5 + * x + * ( + 1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))) + ) + ) + + +GELU = nn.GELU if hasattr(nn, "GELU") else GELU_ + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False): + super().__init__() + activation = default(activation, GELU) + + self.glu = glu + self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1)) + self.act = activation() + self.dropout = nn.Dropout(dropout) + self.w2 = nn.Linear(dim * mult, dim) + + def forward(self, x, **kwargs): + if not self.glu: + x = self.w1(x) + x = self.act(x) + else: + x, v = self.w1(x).chunk(2, dim=-1) + x = self.act(x) * v + + x = self.dropout(x) + x = self.w2(x) + return x + + +# self attention + + +class SelfAttention(nn.Module): + def __init__( + self, + dim, + max_seq_len, + heads, + local_attn_heads, + window_size, + dim_head=None, + local_attn_window_size=None, + local_attn_radius_blocks=1, + causal=False, + attn_dropout=0.0, + dropout=0.0, + kmeans_ema_decay=0.999, + commitment_factor=1e-4, + receives_context=False, + context_window_size=None, + rel_pos_emb=True, + num_mem_kv=0, + shared_qk=False, + conv_query_kernel=9, + ): + super().__init__() + assert ( + dim_head or (dim % heads) == 0 + ), "hidden dimension must be divisible by number of heads" + assert ( + max_seq_len % window_size + ) == 0, "maximum sequence length must be divisible by the target window size" + assert ( + local_attn_heads <= heads + ), "number of local attention heads must be less than total heads" + assert not ( + receives_context and local_attn_heads > 0 + ), "local attention cannot be used for self attention with context" + assert not ( + receives_context and causal + ), "contextual attention layer cannot be causal" + + local_attn_window_size = default(local_attn_window_size, window_size) + context_window_size = default(context_window_size, window_size) + + self.shared_qk = shared_qk + self.receives_context = receives_context + self.heads = heads + self.local_attn_heads = local_attn_heads + self.global_attn_heads = heads - local_attn_heads + + self.causal = causal + self.window_size = window_size + + dim_head = default(dim_head, dim // heads) + dim_heads = dim_head * heads + self.dim_head = dim_head + + num_clusters = max_seq_len // window_size + + # local + + local_dim_heads = dim_head * self.local_attn_heads + + if self.local_attn_heads > 0: + rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None + self.local_attn = LocalAttention( + dim=dim_head, + window_size=local_attn_window_size, + causal=causal, + dropout=attn_dropout, + rel_pos_emb_config=rel_pos_emb_config, + look_backward=local_attn_radius_blocks, + look_forward=0 if causal else local_attn_radius_blocks, + ) + self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads) + + # global + + global_dim_heads = dim_head * self.global_attn_heads + + if self.global_attn_heads > 0: + self.global_attn = KmeansAttention( + num_clusters, + window_size, + self.global_attn_heads, + dim_head, + causal=causal, + dropout=attn_dropout, + ema_decay=kmeans_ema_decay, + commitment=commitment_factor, + receives_context=receives_context, + num_mem_kv=num_mem_kv, + shared_qk=shared_qk, + ) + + self.to_q = nn.Sequential( + Rearrange("b n c -> b c n"), + DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal), + Rearrange("b c n -> b n c"), + ) + + self.to_v = nn.Linear(dim, global_dim_heads, bias=False) + + if not self.shared_qk: + self.to_k = nn.Linear(dim, global_dim_heads, bias=False) + + # out + + self.to_out = nn.Linear(dim_heads, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + query, + key, + value, + context=None, + key_padding_mask=None, + context_mask=None, + pos_emb=None, + **kwargs + ): + assert not ( + self.receives_context and not exists(context) + ), "context must be passed if self attention is set to receive context" + input_mask = key_padding_mask + x = query.transpose(0, 1) + b, t, _, h, dh = *x.shape, self.heads, self.dim_head + has_local, has_global = map( + lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads) + ) + + split_heads = ( + lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous() + ) + + if has_local: + local_qkv = self.local_to_qkv(x).chunk(3, dim=-1) + lq, lk, lv = map(split_heads, local_qkv) + + if has_global: + kv_input = x if not self.receives_context else context + + q, v = self.to_q(x), self.to_v(kv_input) + + if not self.shared_qk: + k = self.to_k(kv_input) + else: + k = self.to_q(kv_input) if self.receives_context else q + + q, k, v = map(split_heads, (q, k, v)) + + out = [] + total_loss = torch.tensor(0.0, requires_grad=True, **to(x)) + + if has_local: + local_out = self.local_attn(lq, lk, lv, input_mask=input_mask) + out.append(local_out) + + if has_global: + if not self.receives_context and exists(pos_emb): + q, k = apply_rotary_pos_emb(q, k, pos_emb) + + global_out, loss = self.global_attn( + q, k, v, query_mask=input_mask, key_mask=context_mask + ) + total_loss = total_loss + loss + + out.append(global_out) + + out = torch.cat(out, dim=1) + out = out.reshape(b, h, t, -1).transpose(1, 2).reshape(b, t, -1) + out = self.dropout(out.transpose(0, 1)) + # out = self.to_out(out) + return out, total_loss diff --git a/fairseq/modules/kmeans_vector_quantizer.py b/fairseq/modules/kmeans_vector_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1015c389995d16e433d08303e3f37ce6f29653da --- /dev/null +++ b/fairseq/modules/kmeans_vector_quantizer.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from fairseq.modules import Fp32GroupNorm + + +class KmeansVectorQuantizer(nn.Module): + def __init__( + self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25 + ): + """Vector quantization using straight pass-through estimator (i.e. kmeans) + + Args: + dim: input dimension (channels) + num_vars: number of quantized vectors per group + groups: number of groups for vector quantization + combine_groups: whether to use the vectors for all groups + vq_dim: dimensionality of the resulting quantized vector + time_first: if true, expect input in BxTxC format, otherwise in BxCxT + gamma: commitment loss coefficient + """ + super().__init__() + + self.groups = groups + self.combine_groups = combine_groups + self.input_dim = dim + self.num_vars = num_vars + self.vq_dim = vq_dim + self.time_first = time_first + + assert ( + vq_dim % groups == 0 + ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation" + + self.var_dim = vq_dim // groups + num_groups = groups if not combine_groups else 1 + + self.embedding = nn.Parameter( + 0.01 * torch.randn(num_vars, num_groups, self.var_dim) + ) + self.projection = nn.Sequential( + nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False), + Fp32GroupNorm(groups, dim), + ) + self.gamma = gamma + self.mse_mean = nn.MSELoss(reduction="mean") + + def _pass_grad(self, x, y): + """Manually set gradient for backward pass. + for y = f(x), ensure that during the backward pass, + dL/dy = dL/dx regardless of f(x). + Returns: + y, with the gradient forced to be dL/dy = dL/dx. + """ + + return y.detach() + (x - x.detach()) + + @property + def expand_embedding(self): + if self.combine_groups: + return self.embedding.expand(self.num_vars, self.groups, self.var_dim) + return self.embedding + + def forward_idx(self, x): + res = self.forward(x, produce_targets=True) + return res["x"], res["targets"] + + def forward(self, x, produce_targets=False): + + result = {"num_vars": self.num_vars} + + if self.time_first: + x = x.transpose(1, 2) + + bsz, fsz, tsz = x.shape + + ze = self.projection(x) + ze_ = ze.view(bsz, self.groups, self.var_dim, tsz).permute(0, 3, 1, 2) + d = ( + (ze_.unsqueeze(0) - self.expand_embedding.unsqueeze(1).unsqueeze(1)) + .view(self.num_vars, bsz, tsz, self.groups, -1) + .norm(dim=-1, p=2) + ) + idx = d.argmin(dim=0) + zq = ( + torch.stack( + [ + self.expand_embedding[idx[..., group], group] + for group in range(self.groups) + ], + dim=-2, + ) + .view(bsz, tsz, self.groups * self.var_dim) + .permute(0, 2, 1) + ) + assert ze.shape == zq.shape, (ze.shape, zq.shape) + x = self._pass_grad(ze, zq) + + with torch.no_grad(): + hard_x = ( + idx.new_zeros(bsz * tsz * self.groups, self.num_vars) + .scatter_(-1, idx.view(-1, 1), 1.0) + .view(bsz * tsz, self.groups, -1) + ) + hard_probs = torch.mean(hard_x.float(), dim=0) + result["code_perplexity"] = torch.exp( + -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) + ).sum() + + if produce_targets: + result["targets"] = idx + + if self.time_first: + x = x.transpose(1, 2) # BCT -> BTC + result["x"] = x + + ze = ze.float() + zq = zq.float() + latent_loss = self.mse_mean(zq, ze.detach()) + commitment_loss = self.mse_mean(ze, zq.detach()) + + result["kmeans_loss"] = latent_loss + self.gamma * commitment_loss + + return result diff --git a/fairseq/modules/layer_drop.py b/fairseq/modules/layer_drop.py new file mode 100644 index 0000000000000000000000000000000000000000..8961d8bcbc492c40c6b30973234416ce5a414f5a --- /dev/null +++ b/fairseq/modules/layer_drop.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +LayerDrop as described in https://arxiv.org/abs/1909.11556. +""" + +import torch +import torch.nn as nn + + +class LayerDropModuleList(nn.ModuleList): + """ + A LayerDrop implementation based on :class:`torch.nn.ModuleList`. + + We refresh the choice of which layers to drop every time we iterate + over the LayerDropModuleList instance. During evaluation we always + iterate over all layers. + + Usage:: + + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p, modules=None): + super().__init__(modules) + self.p = p + + def __iter__(self): + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m diff --git a/fairseq/modules/layer_norm.py b/fairseq/modules/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..0b276ce02fc6bcb9619c9e8a0f7ec10cd28bc420 --- /dev/null +++ b/fairseq/modules/layer_norm.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) + +except ImportError: + has_fused_layernorm = False + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) diff --git a/fairseq/modules/learned_positional_embedding.py b/fairseq/modules/learned_positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..378d0f707183dd344dbb9288dda394b11053acf0 --- /dev/null +++ b/fairseq/modules/learned_positional_embedding.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from torch import Tensor + + +class LearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + Padding ids are ignored by either offsetting based on padding_idx + or by setting padding_idx to None and ensuring that the appropriate + position ids are passed to the forward function. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.onnx_trace = False + if self.padding_idx is not None: + self.max_positions = self.num_embeddings - self.padding_idx - 1 + else: + self.max_positions = self.num_embeddings + + def forward( + self, + input: Tensor, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + positions: Optional[Tensor] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + assert (positions is None) or ( + self.padding_idx is None + ), "If positions is pre-computed then padding_idx should not be set." + + if positions is None: + if incremental_state is not None: + # positions is the same for every token when decoding a single step + # Without the int() cast, it doesn't work in some cases when exporting to ONNX + positions = torch.zeros( + (1, 1), device=input.device, dtype=input.dtype + ).fill_(int(self.padding_idx + input.size(1))) + else: + positions = utils.make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + return F.embedding( + positions, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/fairseq/modules/lightconv_layer/__init__.py b/fairseq/modules/lightconv_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2a99c1227f827768911e5e22e79f6865ffbfd3 --- /dev/null +++ b/fairseq/modules/lightconv_layer/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .lightconv_layer import LightconvLayer # noqa diff --git a/fairseq/modules/lightconv_layer/cuda_function_gen.py b/fairseq/modules/lightconv_layer/cuda_function_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..a25433dd8edae2f0b52d7d0eeeb829cabc6b4b89 --- /dev/null +++ b/fairseq/modules/lightconv_layer/cuda_function_gen.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def gen_forward(): + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l) { + + at::DeviceGuard g(input.device()); + const auto minibatch = input.size(0); + const auto numFeatures = input.size(1); + const auto sequenceLength = input.size(2); + + const auto numHeads = filters.size(0); + const auto filterSize = filters.size(1); + + const auto numFiltersInBlock = numFeatures / numHeads; + + const dim3 blocks(minibatch, numFeatures); + + auto output = at::zeros_like(input); + auto stream = at::cuda::getCurrentCUDAStream(); +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ + switch(filterSize) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {pad}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_forward", ([&] {{ + lightconv_forward_kernel<{k}, {b_size}, {pad}, scalar_t> + <<>>( + input.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + output.data()); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping forward pass" << std::endl; + } + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping forward pass" << std::endl; + } +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + final_return = """ + } + + return {output}; +} +""" + + with open("lightconv_cuda_forward.cu", "w") as forward: + forward.write(head) + for seq in seqs: + forward.write(sequence_if.format(seq=seq)) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(con_else) + + forward.write(final_else) + for k in kernels: + forward.write(case_k.format(k=k)) + for pad in [k // 2, k - 1]: + forward.write(main_block.format(k=k, b_size=seq, pad=pad)) + forward.write(bad_padding) + forward.write(bad_filter) + forward.write(final_return) + + +def gen_backward(): + + head = """ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "lightconv_cuda.cuh" + +std::vector lightconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + + // gradWrtInput + const int minibatch = input.size(0); + const int numFeatures = input.size(1); + const int sequenceLength = input.size(2); + + const int numHeads = filters.size(0); + const int filterSize = filters.size(1); + + const dim3 gradBlocks(minibatch, numFeatures); + const dim3 weightGradFirstpassShortBlocks(minibatch, numHeads); + const dim3 weightGradSecondpassBlocks(numHeads, filterSize); + + const int numFiltersInBlock = numFeatures / numHeads; + + auto gradInput = at::zeros_like(input); + auto gradFilters = at::zeros_like(filters); + + at::DeviceGuard g(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + + switch(filterSize) { +""" + + sequence_if = """ + if (sequenceLength <= {seq}) {{ +""" + + case_k = """ + case {k}: +""" + + main_block = """ + if (padding_l == {p}) {{ + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "lightconv_backward", ([&] {{ + lightconv_grad_wrt_input_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + gradOutput.data(), + filters.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + gradInput.data()); + +""" + + weight_grad_short = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numHeads, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_short_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + numHeads, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_short_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + weight_grad = """ + at::Tensor tempSumGradFilters = at::zeros({{minibatch, numFeatures, filterSize}}, input.options().dtype(at::kFloat)); + lightconv_grad_wrt_weights_firstpass_kernel<{k}, {b_size}, {p}, scalar_t> + <<>>( + input.data(), + gradOutput.data(), + minibatch, + sequenceLength, + numFeatures, + numFiltersInBlock, + tempSumGradFilters.data() + ); + + lightconv_grad_wrt_weights_secondpass_kernel<{k}, {b_size}, scalar_t> + <<>>( + tempSumGradFilters.data(), + minibatch, + numFiltersInBlock, + gradFilters.data() + ); + }})); + }} else +""" + + bad_padding = """ + { + std::cout << "WARNING: Unsupported padding size - skipping backward pass" << std::endl; + } +""" + + breakout = """ + break; +""" + + bad_filter = """ + default: + std::cout << "WARNING: Unsupported filter length passed - skipping backward pass" << std::endl; +""" + + con_else = """ + } else +""" + + final_else = """ + { + switch(filterSize) { +""" + + last_return = """ + } + return {gradInput, gradFilters}; +} +""" + + kernels = [3, 5, 7, 15, 31, 63, 127, 255] + seqs = [32 * x for x in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]] + thresh = [32, 32, 64, 128, 256, -1, -1, -1] + max_mem = [-1, -1, -1, -1, -1, 192, 96, 64] + + with open("lightconv_cuda_backward.cu", "w") as backward: + backward.write(head) + for (k, t, mem) in zip(kernels, thresh, max_mem): + backward.write(case_k.format(k=k)) + for seq in seqs: + if (t == -1 or seq <= t) and (mem == -1 or seq < mem): + backward.write(sequence_if.format(seq=seq)) + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=seq, p=p)) + backward.write(weight_grad_short.format(k=k, b_size=seq, p=p)) + backward.write(bad_padding) + else: + for p in [k // 2, k - 1]: + backward.write(main_block.format(k=k, b_size=32, p=p)) + backward.write(weight_grad.format(k=k, b_size=32, p=p)) + backward.write(bad_padding) + backward.write(breakout) + break + backward.write(con_else) + backward.write(bad_filter) + backward.write(last_return) + + +if __name__ == "__main__": + gen_forward() + gen_backward() diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cpp b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ece47a8d908b93cec102743070c9057986d39d3f --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +std::vector +lightconv_cuda_forward(at::Tensor input, at::Tensor filters, int padding_l); + +std::vector lightconv_cuda_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters); + +#define CHECK_CUDA(x) \ + AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector +lightconv_forward(at::Tensor input, at::Tensor filters, int padding_l) { + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return lightconv_cuda_forward(input, filters, padding_l); +} + +std::vector lightconv_backward( + at::Tensor gradOutput, + int padding_l, + at::Tensor input, + at::Tensor filters) { + CHECK_INPUT(gradOutput); + CHECK_INPUT(input); + CHECK_INPUT(filters); + + return lightconv_cuda_backward(gradOutput, padding_l, input, filters); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &lightconv_forward, "lighconv forward (CUDA)"); + m.def("backward", &lightconv_backward, "lighconv backward (CUDA)"); +} diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda.cuh b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..610ab399e9b201cd8b0fb87a91e09b6f7aab9803 --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda.cuh @@ -0,0 +1,79 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#define SHFL_MASK 0xffffffff + +template +__global__ void lightconv_forward_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output); + +template +__global__ void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output); + +template +__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output); + +template +__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); + +template +__global__ void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output); + +template +__global__ void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output); diff --git a/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..cdf31d5d2df2d3433c66e167f098e48a99f96db2 --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_cuda_kernel.cu @@ -0,0 +1,400 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "../cuda_utils.cu" +#include "lightconv_cuda.cuh" +#include "lightconv_cuda_backward.cu" +#include "lightconv_cuda_forward.cu" + +template +__global__ void lightconv_forward_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output) { + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = + numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; +#pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[i]; + } + + __shared__ scalar_t temp[SB + FS]; + zeroSharedMem(temp); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + (numIterations == 1), + temp); + + __syncthreads(); + + scalar_t out = 0; +#pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +template +__global__ void lightconv_grad_wrt_input_kernel( + const scalar_t* input, + const scalar_t* filters, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + scalar_t* output) { + // input grad kernel is similar to forward kernel + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + + const int IOOffset = + numFeatures * sequenceLength * batchIdx + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + scalar_t* outputFeature = &output[IOOffset]; + const scalar_t* inputFilter = &filters[filterIdx * FS]; + + assert(blockDim.x == SB); + + scalar_t filter[FS]; + +// The only change is loading the filter in reverse +#pragma unroll + for (int i = 0; i < FS; ++i) { + filter[i] = inputFilter[FS - i - 1]; + } + + __shared__ scalar_t temp[SB + FS]; + const int padding = FS - padding_l - 1; + zeroSharedMem(temp); + + __syncthreads(); + + const int numIterations = divUp(sequenceLength, SB); + + for (int i = 0; i < numIterations; ++i) { + // Read input into shared memory + const int inputOffset = i * SB; + + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + temp); + + __syncthreads(); + + scalar_t out = 0; +#pragma unroll + for (int j = 0; j < FS; ++j) { + out += filter[j] * temp[tid + j]; + } + + // Write output + const int outputOffset = inputOffset; + if ((outputOffset + tid) < sequenceLength) { + outputFeature[outputOffset + tid] = out; + } + + __syncthreads(); + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ void lightconv_grad_wrt_weights_firstpass_short_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + int numHeads, + float* output) { + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int filterIdx = blockIdx.y; + + const int numIterations = divUp(sequenceLength, SB); + + float* tempOutputGradWeight = &output[filterIdx * FS * minibatch]; + + assert(blockDim.x == SB); + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + + // local weight accumulation + float accumWeights[FS]; + + // Initialize memory + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + // loop over each sequence within filterblock + for (int idxInFilterBlock = 0; idxInFilterBlock < numFiltersInBlock; + ++idxInFilterBlock) { + const int featureOffset = batchIdx * numFeatures * sequenceLength + + (filterIdx * numFiltersInBlock + idxInFilterBlock) * sequenceLength; + const scalar_t* inputFeature = &input[featureOffset]; + const scalar_t* gradInputFeature = &gradInput[featureOffset]; + + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + for (int i = 0; i < numIterations; ++i) { + const int inputOffset = i * SB; + + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); + load_input_to_shared( + gradInputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempGradInput); + + __syncthreads(); + + const int gradIndex = (FS / 2) + tid; + scalar_t tempGrad = tempGradInput[gradIndex]; + +#pragma unroll + for (int j = 0; j < FS; j++) { + const int inputIndex = tid + j; + accumWeights[j] += tempInput[inputIndex] * tempGrad; + } + + __syncthreads(); + } + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + float temp; + if (tid < sequenceLength) { + temp = accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + const int outputOffset = filterWeightIdx * minibatch + batchIdx; + + temp = blockReduce(temp); + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ void lightconv_grad_wrt_weights_secondpass_short_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + + const int filterIdx = blockIdx.x; + const int filterWeightIdx = blockIdx.y; + + const int inputOffset = + filterIdx * FS * minibatch + filterWeightIdx * minibatch; + const float* tempInput = &input[inputOffset]; + + // read into shared memory for reduction + int readIndex = tid; + + float sum = 0.0; + while (readIndex < minibatch) { + sum += tempInput[readIndex]; + readIndex += SB; + } + + float temp = blockReduce(sum); + + if (tid == 0) { + output[blockIdx.x * FS + blockIdx.y] = temp; + } +} + +// This is by far the most expensive kernel in terms of time taken. +// Can be 16x slower than the forward or grad_wrt_input when filter size is 31 +template +__global__ void lightconv_grad_wrt_weights_firstpass_kernel( + const scalar_t* input, + const scalar_t* gradInput, + int minibatch, + int sequenceLength, + int numFeatures, + int numFiltersInBlock, + float* output) { + assert(blockDim.x == SB); + + const int tid = threadIdx.x; + const int batchIdx = blockIdx.x; + const int featureIdx = blockIdx.y; + const int filterIdx = featureIdx / numFiltersInBlock; + const int idxInFilterBlock = featureIdx % numFiltersInBlock; + + const int numIterations = divUp(sequenceLength, SB); + + float temp; + + __shared__ scalar_t tempInput[SB + FS]; + __shared__ scalar_t tempGradInput[SB + FS]; + zeroSharedMem(tempInput); + zeroSharedMem(tempGradInput); + __syncthreads(); + + float accumWeights[FS]; + + for (int i = 0; i < FS; ++i) { + accumWeights[i] = float(0.0); + } + + const int IOOffset = + batchIdx * numFeatures * sequenceLength + featureIdx * sequenceLength; + const scalar_t* inputFeature = &input[IOOffset]; + const scalar_t* gradInputFeature = &gradInput[IOOffset]; + float* tempOutputGradWeight = + &output[filterIdx * FS * minibatch * numFiltersInBlock]; + + for (int i = 0; i < numIterations; ++i) { + const int inputOffset = i * SB; + + load_input_to_shared( + inputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempInput); + load_input_to_shared( + gradInputFeature, + inputOffset, + sequenceLength, + i, + numIterations, + false, + tempGradInput); + __syncthreads(); + +#pragma unroll + for (int j = 0; j < FS; ++j) { + accumWeights[j] += tempInput[tid + j] * tempGradInput[tid + (FS / 2)]; + } + + __syncthreads(); + } + + // Row-major sum + for (int filterWeightIdx = 0; filterWeightIdx < FS; ++filterWeightIdx) { + // Write to shared memory before reduction + if (tid < sequenceLength) { + temp = accumWeights[filterWeightIdx]; + } else { + temp = float(0.0); + } + + temp = blockReduce(temp); + + const int outputOffset = filterWeightIdx * minibatch * numFiltersInBlock + + batchIdx * numFiltersInBlock + idxInFilterBlock; + + if (tid == 0) { + tempOutputGradWeight[outputOffset] = temp; + } + } +} + +template +__global__ void lightconv_grad_wrt_weights_secondpass_kernel( + const float* input, + const int minibatch, + const int numFiltersInBlock, + scalar_t* output) { + assert(blockDim.x == SB); + const int tid = threadIdx.x; + + // What is the id within a minibatch + const int filterIdx = blockIdx.x; + const int filterWeightIdx = blockIdx.y; + + const int inputOffset = filterIdx * FS * minibatch * numFiltersInBlock + + filterWeightIdx * minibatch * numFiltersInBlock; + const float* tempInput = &input[inputOffset]; + + int readIndex = tid; + + float sum = float(0.0); + while (readIndex < (minibatch * numFiltersInBlock)) { + sum += tempInput[readIndex]; + readIndex += SB; + } + + float temp = blockReduce(sum); + + if (tid == 0) { + output[blockIdx.x * FS + blockIdx.y] = temp; + } +} diff --git a/fairseq/modules/lightconv_layer/lightconv_layer.py b/fairseq/modules/lightconv_layer/lightconv_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e597f4749c591b057d776aacec39b44d99c037 --- /dev/null +++ b/fairseq/modules/lightconv_layer/lightconv_layer.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import lightconv_cuda +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from torch import nn +from torch.autograd import Function + + +class lightconvFunction(Function): + @staticmethod + def forward(ctx, x, weights, padding_l): + ctx.padding_l = padding_l + outputs = lightconv_cuda.forward(x, weights, padding_l) + variables = [x, weights] + ctx.save_for_backward(*variables) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output): + outputs = lightconv_cuda.backward( + grad_output.contiguous(), ctx.padding_l, *ctx.saved_tensors + ) + grad_input, grad_weights = outputs + return grad_input, grad_weights, None + + +@with_incremental_state +class LightconvLayer(nn.Module): + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + weight_softmax=False, + num_heads=1, + weight_dropout=0.0, + bias=False, + ): + super(LightconvLayer, self).__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_softmax = weight_softmax + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + + self.weight = nn.Parameter(torch.Tensor(num_heads, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + self.reset_parameters() + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + for k, v in state_dict.items(): + if k.endswith(prefix + "weight"): + if v.dim() == 3 and v.size(1) == 1: + state_dict[k] = v.squeeze(1) + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, x, incremental_state=None): + + # during inference time, incremental BMM is faster + if incremental_state is not None: + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + + weight = self.weight + if self.weight_softmax: + weight = F.softmax(weight.float(), dim=1).type_as(weight) + + weight = weight[:, -x_unfold.size(2) :] + + K = weight.size(1) + + weight = ( + weight.view(1, H, K) + .expand(T * B, H, K) + .contiguous() + .view(T * B * H, K, 1) + ) + + weight = self.weight_dropout_module(weight) + output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + # during training time, use CUDA kernel + else: + x = x.permute(1, 2, 0).contiguous() + weight = self.weight + if self.weight_softmax: + weight = F.softmax(self.weight, -1) + if self.weight_dropout_module.p: + weight = self.weight_dropout_module(weight) + return lightconvFunction.apply(x, weight, self.padding_l).permute(2, 0, 1) + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def half(self): + return self._apply(lambda t: t.half() if t.is_floating_point() else t) diff --git a/fairseq/modules/lightconv_layer/setup.py b/fairseq/modules/lightconv_layer/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..052635be79b466d0ad56cf5cf607bd10c2297ecf --- /dev/null +++ b/fairseq/modules/lightconv_layer/setup.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + +setup( + name="lightconv_layer", + ext_modules=[ + CUDAExtension( + "lightconv_cuda", + [ + "lightconv_cuda.cpp", + "lightconv_cuda_kernel.cu", + ], + ), + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/fairseq/modules/lightweight_convolution.py b/fairseq/modules/lightweight_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..ec11a9507951c9e8f3564753841dd9c74a4900e0 --- /dev/null +++ b/fairseq/modules/lightweight_convolution.py @@ -0,0 +1,310 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.unfold import unfold1d + + +def LightweightConv( + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, +): + if torch.cuda.is_available(): + try: + from fairseq.modules.lightconv_layer import LightconvLayer + + return LightconvLayer( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + except ImportError as e: + print(e) + return LightweightConv1dTBC( + input_size, + kernel_size=kernel_size, + padding_l=padding_l, + num_heads=num_heads, + weight_dropout=weight_dropout, + weight_softmax=weight_softmax, + bias=bias, + ) + + +class LightweightConv1d(nn.Module): + """Lightweight Convolution assuming the input is BxCxT + This is just an example that explains LightConv clearer than the TBC version. + We don't use this module in the model. + + Args: + input_size: # of channels of the input and output + kernel_size: convolution channels + padding: padding + num_heads: number of heads used. The weight is of shape + `(num_heads, 1, kernel_size)` + weight_softmax: normalize the weight with softmax before the convolution + + Shape: + Input: BxCxT, i.e. (batch_size, input_size, timesteps) + Output: BxCxT, i.e. (batch_size, input_size, timesteps) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding=0, + num_heads=1, + weight_softmax=False, + bias=False, + weight_dropout=0.0, + ): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.num_heads = num_heads + self.padding = padding + self.weight_softmax = weight_softmax + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, input): + """ + input size: B x C x T + output size: B x C x T + """ + B, C, T = input.size() + H = self.num_heads + + weight = self.weight + if self.weight_softmax: + weight = F.softmax(weight, dim=-1) + + weight = self.weight_dropout_module(weight) + # Merge every C/H entries into the batch dimension (C = self.input_size) + # B x C x T -> (B * C/H) x H x T + # One can also expand the weight to C x 1 x K by a factor of C/H + # and do not reshape the input instead, which is slow though + input = input.view(-1, H, T) + output = F.conv1d(input, weight, padding=self.padding, groups=self.num_heads) + output = output.view(B, C, T) + if self.bias is not None: + output = output + self.bias.view(1, -1, 1) + + return output + + +@with_incremental_state +class LightweightConv1dTBC(nn.Module): + """Lightweight Convolution assuming the input is TxBxC + Args: + input_size: # of channels of the input + kernel_size: convolution channels + padding_l: padding to the left when using "same" padding + num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size) + weight_dropout: the drop rate of the DropConnect to drop the weight + weight_softmax: normalize the weight with softmax before the convolution + bias: use bias + + Shape: + Input: TxBxC, i.e. (timesteps, batch_size, input_size) + Output: TxBxC, i.e. (timesteps, batch_size, input_size) + + Attributes: + weight: the learnable weights of the module of shape + `(num_heads, 1, kernel_size)` + bias: the learnable bias of the module of shape `(input_size)` + """ + + def __init__( + self, + input_size, + kernel_size=1, + padding_l=None, + num_heads=1, + weight_dropout=0.0, + weight_softmax=False, + bias=False, + ): + super().__init__() + self.input_size = input_size + self.kernel_size = kernel_size + self.padding_l = padding_l + self.num_heads = num_heads + self.weight_dropout_module = FairseqDropout( + weight_dropout, module_name=self.__class__.__name__ + ) + self.weight_softmax = weight_softmax + + self.weight = nn.Parameter(torch.Tensor(num_heads, 1, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(input_size)) + else: + self.bias = None + + self.reset_parameters() + self.onnx_trace = False + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.constant_(self.bias, 0.0) + + def forward(self, x, incremental_state=None, unfold=False): + """Assuming the input, x, of the shape T x B x C and producing an output in the shape T x B x C + args: + x: Input of shape T x B x C, i.e. (timesteps, batch_size, input_size) + incremental_state: A dict to keep the state + unfold: unfold the input or not. If not, we use the matrix trick instead + """ + unfold = unfold or (incremental_state is not None) + + if unfold: + output = self._forward_unfolded(x, incremental_state) + else: + output = self._forward_expanded(x, incremental_state) + + if self.bias is not None: + output = output + self.bias.view(1, 1, -1) + return output + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def _forward_unfolded(self, x, incremental_state): + """The conventional implementation of convolutions. + Unfolding the input by having a window shifting to the right.""" + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if incremental_state is not None: + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = x.new() + x_unfold = torch.cat([input_buffer, x.unsqueeze(3)], dim=3) + if self.kernel_size > 1: + self._set_input_buffer( + incremental_state, x_unfold[:, :, :, -self.kernel_size + 1 :] + ) + x_unfold = x_unfold.view(T * B * H, R, -1) + else: + # unfold the input: T x B x C --> T' x B x C x K + x_unfold = unfold1d(x, self.kernel_size, self.padding_l, 0) + x_unfold = x_unfold.view(T * B * H, R, K) + + if self.weight_softmax: + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) + + if incremental_state is not None: + weight = weight[:, -x_unfold.size(2) :] + K = weight.size(1) + + weight = ( + weight.view(1, H, K).expand(T * B, H, K).contiguous().view(T * B * H, K, 1) + ) + + weight = self.weight_dropout_module(weight) + output = torch.bmm(x_unfold, weight) # T*B*H x R x 1 + output = output.view(T, B, C) + return output + + def _forward_expanded(self, x, incremental_state): + """Turn the convolution filters into band matrices and do matrix multiplication. + This is faster when the sequence is short, but less memory efficient. + This is not used in the decoder during inference. + """ + T, B, C = x.size() + K, H = self.kernel_size, self.num_heads + R = C // H + assert R * H == C == self.input_size + + weight = self.weight.view(H, K) + if self.weight_softmax: + weight = utils.softmax(weight, dim=1, onnx_trace=self.onnx_trace).type_as( + weight + ) + weight = weight.view(1, H, K).expand(T * B, H, K).contiguous() + weight = weight.view(T, B * H, K).transpose(0, 1) + + x = x.view(T, B * H, R).transpose(0, 1) + P = self.padding_l + if K > T and P == K - 1: + weight = weight.narrow(2, K - T, T) + K, P = T, T - 1 + # turn the convolution filters into band matrices + weight_expanded = weight.new_zeros(B * H, T, T + K - 1, requires_grad=False) + weight_expanded.as_strided((B * H, T, K), (T * (T + K - 1), T + K, 1)).copy_( + weight + ) + weight_expanded = weight_expanded.narrow(2, P, T) + weight_expanded = self.weight_dropout_module(weight_expanded) + + output = torch.bmm(weight_expanded, x) + output = output.transpose(0, 1).contiguous().view(T, B, C) + return output + + def reorder_incremental_state(self, incremental_state, new_order): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(1, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + def _get_input_buffer(self, incremental_state): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + def _set_input_buffer(self, incremental_state, new_buffer): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + def extra_repr(self): + s = "{}, kernel_size={}, padding_l={}, num_heads={}, weight_softmax={}, bias={}".format( + self.input_size, + self.kernel_size, + self.padding_l, + self.num_heads, + self.weight_softmax, + self.bias is not None, + ) + if self.weight_dropout_module.p > 0.0: + s += ", weight_dropout={}".format(self.weight_dropout_module.p) + return s diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7a9f09acd734eaee7a4c825eb99c4357f9cd05 --- /dev/null +++ b/fairseq/modules/linearized_convolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from fairseq import utils +from fairseq.incremental_decoding_utils import with_incremental_state + +from .conv_tbc import ConvTBC + +from typing import Dict, Optional +from torch import Tensor + + +@with_incremental_state +class LinearizedConvolution(ConvTBC): + """An optimized version of nn.Conv1d. + + At training time, this module uses ConvTBC, which is an optimized version + of Conv1d. At inference time, it optimizes incremental generation (i.e., + one time step at a time) by replacing the convolutions with linear layers. + Note that the input order changes from training to inference. + """ + + def __init__(self, in_channels, out_channels, kernel_size, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, **kwargs) + self._linearized_weight = None + self.register_backward_hook(self._clear_linearized_weight) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = ConvTBC.state_dict(self, destination, prefix, keep_vars=keep_vars) + # don't store redundant _linearized_weight in checkpoints + if prefix + "_linearized_weight" in state: + del state[prefix + "_linearized_weight"] + return state + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + if prefix + "_linearized_weight" in state_dict: + del state_dict[prefix + "_linearized_weight"] + + @torch.jit.export + def forward( + self, + input, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): + """ + Args: + incremental_state: Used to buffer signal; if not None, then input is + expected to contain a single frame. If the input order changes + between time steps, call reorder_incremental_state. + Input: + Time x Batch x Channel during training + Batch x Time x Channel during inference + """ + if incremental_state is None: + output = self.conv_tbc(input) + if self.kernel_size[0] > 1 and self.padding[0] > 0: + # remove future timesteps added by padding + output = output[: -self.padding[0], :, :] + return output + + # reshape weight + weight = self._get_linearized_weight() + kw = self.kernel_size[0] + + bsz = input.size(0) # input: bsz x len x dim + if kw > 1: + input = input.data + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is None: + input_buffer = input.new(bsz, kw, input.size(2)).zero_() + self._set_input_buffer(incremental_state, input_buffer) + else: + # shift buffer + input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone() + # append next input + input_buffer[:, -1, :] = input[:, -1, :] + input = input_buffer + with torch.no_grad(): + output = F.linear(input.view(bsz, -1), weight, self.bias) + return output.view(bsz, 1, -1) + + @torch.jit.unused + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + input_buffer = input_buffer.index_select(0, new_order) + self._set_input_buffer(incremental_state, input_buffer) + + @torch.jit.unused + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ): + return utils.get_incremental_state(self, incremental_state, "input_buffer") + + @torch.jit.unused + def _set_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_buffer, + ): + return utils.set_incremental_state( + self, incremental_state, "input_buffer", new_buffer + ) + + @torch.jit.unused + def _get_linearized_weight(self): + if self._linearized_weight is None: + kw = self.kernel_size[0] + weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() + assert weight.size() == (self.out_channels, kw, self.in_channels) + return weight.view(self.out_channels, -1) + return self._linearized_weight + + @torch.jit.unused + def _clear_linearized_weight(self, *args): + self._linearized_weight = None diff --git a/fairseq/modules/location_attention.py b/fairseq/modules/location_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..dbbbfb9f2d4b64a2ccc0281cc1679d789ef8374f --- /dev/null +++ b/fairseq/modules/location_attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import torch +import torch.nn.functional as F + + +class LocationAttention(nn.Module): + """ + Attention-Based Models for Speech Recognition + https://arxiv.org/pdf/1506.07503.pdf + + :param int encoder_dim: # projection-units of encoder + :param int decoder_dim: # units of decoder + :param int attn_dim: attention dimension + :param int conv_dim: # channels of attention convolution + :param int conv_kernel_size: filter size of attention convolution + """ + + def __init__( + self, + attn_dim, + encoder_dim, + decoder_dim, + attn_state_kernel_size, + conv_dim, + conv_kernel_size, + scaling=2.0, + ): + super(LocationAttention, self).__init__() + self.attn_dim = attn_dim + self.decoder_dim = decoder_dim + self.scaling = scaling + self.proj_enc = nn.Linear(encoder_dim, attn_dim) + self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False) + self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False) + self.conv = nn.Conv1d( + attn_state_kernel_size, + conv_dim, + 2 * conv_kernel_size + 1, + padding=conv_kernel_size, + bias=False, + ) + self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1)) + + self.proj_enc_out = None # cache + + def clear_cache(self): + self.proj_enc_out = None + + def forward(self, encoder_out, encoder_padding_mask, decoder_h, attn_state): + """ + :param torch.Tensor encoder_out: padded encoder hidden state B x T x D + :param torch.Tensor encoder_padding_mask: encoder padding mask + :param torch.Tensor decoder_h: decoder hidden state B x D + :param torch.Tensor attn_prev: previous attention weight B x K x T + :return: attention weighted encoder state (B, D) + :rtype: torch.Tensor + :return: previous attention weights (B x T) + :rtype: torch.Tensor + """ + bsz, seq_len, _ = encoder_out.size() + if self.proj_enc_out is None: + self.proj_enc_out = self.proj_enc(encoder_out) + + # B x K x T -> B x C x T + attn = self.conv(attn_state) + # B x C x T -> B x T x C -> B x T x D + attn = self.proj_attn(attn.transpose(1, 2)) + + if decoder_h is None: + decoder_h = encoder_out.new_zeros(bsz, self.decoder_dim) + dec_h = self.proj_dec(decoder_h).view(bsz, 1, self.attn_dim) + + out = self.proj_out(attn + self.proj_enc_out + dec_h).squeeze(2) + out.masked_fill_(encoder_padding_mask, -float("inf")) + + w = F.softmax(self.scaling * out, dim=1) + c = torch.sum(encoder_out * w.view(bsz, seq_len, 1), dim=1) + return c, w diff --git a/fairseq/modules/lstm_cell_with_zoneout.py b/fairseq/modules/lstm_cell_with_zoneout.py new file mode 100644 index 0000000000000000000000000000000000000000..273308951f7b3d1924675fde4a62359602a11bca --- /dev/null +++ b/fairseq/modules/lstm_cell_with_zoneout.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + + +class LSTMCellWithZoneOut(nn.Module): + """ + Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations + https://arxiv.org/abs/1606.01305 + """ + + def __init__( + self, prob: float, input_size: int, hidden_size: int, bias: bool = True + ): + super(LSTMCellWithZoneOut, self).__init__() + self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) + self.prob = prob + if prob > 1.0 or prob < 0.0: + raise ValueError( + "zoneout probability must be in the range from " "0.0 to 1.0." + ) + + def zoneout(self, h, next_h, prob): + if isinstance(h, tuple): + return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]) + + if self.training: + mask = h.new_zeros(*h.size()).bernoulli_(prob) + return mask * h + (1 - mask) * next_h + + return prob * h + (1 - prob) * next_h + + def forward(self, x, h): + return self.zoneout(h, self.lstm_cell(x, h), self.prob) diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..262132dfe7575d99705048dd6c6092505eb6898e --- /dev/null +++ b/fairseq/modules/multihead_attention.py @@ -0,0 +1,910 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + +try: + from xformers.components.attention import build_attention + from xformers.components.attention.utils import maybe_merge_masks + + _xformers_available = True +except ImportError: + _xformers_available = False + +from fairseq import utils +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise +from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder + + +# TODO: move this into xformers? +# TODO: uint8 input type should just output a bool +def _mask_for_xformers(mask: Tensor, to_dtype: Optional[torch.dtype] = None): + """ + call to pytorch multihead accepts three mask types: + - ByteTensor where non-zero means to mask + - FloatTensor which is an additive mask + - BoolTensor where True means to mask + xFormers currently accepts boolean and additive maks. For boolean masks + the values have opposite meaning. For a BoolTensor True mean to keep the value. + """ + float_types = [torch.float, torch.float16] + # If an input mask is a float it is an additive mask. Otherwise it is either uint8 or bool. + additive = mask.dtype in float_types + # If to_dype is not specified, keep same dtype as mask. + to_dtype = mask.dtype if to_dtype is None else to_dtype + to_additive = to_dtype in float_types + + if additive: + if to_additive: + return mask.to(to_dtype) + mask = mask < 0 + + if to_additive: + # return additive mask + new_mask = torch.zeros_like(mask, dtype=to_dtype) + new_mask = new_mask.masked_fill_(mask, -float("inf")) + return new_mask + + # In xFormers True is value to keep rather than value to mask + mask = ~mask.to(torch.bool) + mask = mask.to(to_dtype) + return mask + + +class MultiheadAttention(FairseqIncrementalDecoder): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + dictionary=None, + q_noise=0.0, + qn_block_size=8, + # TODO: pass in config rather than string. + # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + super().__init__(dictionary) + + xformers_att_config = utils.eval_str_dict(xformers_att_config) + self.use_xformers = xformers_att_config is not None + if self.use_xformers and not _xformers_available: + raise ImportError("\n\n Please install xFormers.") + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + self.k_proj = quant_noise( + nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.beam_size = 1 + self.reset_parameters() + + if self.use_xformers: + xformers_att_config["dropout"] = xformers_att_config.get("dropout", dropout) + xformers_att_config["num_heads"] = xformers_att_config.get( + "num_heads", num_heads + ) + + if xformers_blocksparse_layout is not None: + # Could be part of a single config passed only once + xformers_att_config["block_size"] = xformers_blocksparse_blocksize + xformers_att_config["layout"] = xformers_blocksparse_layout + xformers_att_config["name"] = "blocksparse" + + self.attention = build_attention(xformers_att_config) + + self.onnx_trace = False + self.skip_embed_dim_check = False + self.init_incremental_state() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def _get_reserve_head_index(self, num_heads_to_keep: int): + k_proj_heads_norm = [] + q_proj_heads_norm = [] + v_proj_heads_norm = [] + + for i in range(self.num_heads): + start_idx = i * self.head_dim + end_idx = (i + 1) * self.head_dim + k_proj_heads_norm.append( + torch.sum( + torch.abs( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() + ) + q_proj_heads_norm.append( + torch.sum( + torch.abs( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() + ) + v_proj_heads_norm.append( + torch.sum( + torch.abs( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + ).tolist() + + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() + ) + + heads_norm = [] + for i in range(self.num_heads): + heads_norm.append( + k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] + ) + + sorted_head_index = sorted( + range(self.num_heads), key=lambda k: heads_norm[k], reverse=True + ) + reserve_head_index = [] + for i in range(num_heads_to_keep): + start = sorted_head_index[i] * self.head_dim + end = (sorted_head_index[i] + 1) * self.head_dim + reserve_head_index.append((start, end)) + return reserve_head_index + + def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): + new_q_weight = [] + new_q_bias = [] + new_k_weight = [] + new_k_bias = [] + new_v_weight = [] + new_v_bias = [] + new_out_proj_weight = [] + + for ele in reserve_head_index: + start_idx, end_idx = ele + new_q_weight.append( + self.q_proj.weight[ + start_idx:end_idx, + ] + ) + new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) + + new_k_weight.append( + self.k_proj.weight[ + start_idx:end_idx, + ] + ) + + new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) + + new_v_weight.append( + self.v_proj.weight[ + start_idx:end_idx, + ] + ) + new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) + + new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) + + new_q_weight = torch.cat(new_q_weight).detach() + new_k_weight = torch.cat(new_k_weight).detach() + new_v_weight = torch.cat(new_v_weight).detach() + new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() + new_q_weight.requires_grad = True + new_k_weight.requires_grad = True + new_v_weight.requires_grad = True + new_out_proj_weight.requires_grad = True + + new_q_bias = torch.cat(new_q_bias).detach() + new_q_bias.requires_grad = True + + new_k_bias = torch.cat(new_k_bias).detach() + new_k_bias.requires_grad = True + + new_v_bias = torch.cat(new_v_bias).detach() + new_v_bias.requires_grad = True + + self.q_proj.weight = torch.nn.Parameter(new_q_weight) + self.q_proj.bias = torch.nn.Parameter(new_q_bias) + + self.k_proj.weight = torch.nn.Parameter(new_k_weight) + self.k_proj.bias = torch.nn.Parameter(new_k_bias) + + self.v_proj.weight = torch.nn.Parameter(new_v_weight) + self.v_proj.bias = torch.nn.Parameter(new_v_bias) + + self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) + + self.num_heads = len(reserve_head_index) + self.embed_dim = self.head_dim * self.num_heads + self.q_proj.out_features = self.embed_dim + self.k_proj.out_features = self.embed_dim + self.v_proj.out_features = self.embed_dim + + def _set_skip_embed_dim_check(self): + self.skip_embed_dim_check = True + + def _pad_masks( + self, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + if attn_mask is not None: + shape = attn_mask.size()[:-1] + torch.Size([1]) + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) + if key_padding_mask is not None: + shape = key_padding_mask.size()[:-1] + torch.Size([1]) + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(shape), + ], + dim=-1, + ) + return key_padding_mask, attn_mask + + def _add_bias( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + bsz: int, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + assert self.bias_k is not None + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _append_zero_attn( + self, + k: Tensor, + v: Tensor, + key_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] + k = torch.cat( + [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 + ) + v = torch.cat( + [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 + ) + key_padding_mask, attn_mask = self._pad_masks( + key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + return k, v, key_padding_mask, attn_mask + + def _xformers_attn_forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + + tgt_len, bsz, embed_dim = query.size() + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == tgt_len + + if self.self_attention: + key = query + value = query + elif self.encoder_decoder_attention: + value = key + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + def fold_heads(x): + return ( + x.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + def split_heads(x): + return ( + x.contiguous() + .view(-1, bsz, self.num_heads, self.head_dim) + .transpose(0, 1) + .transpose(1, 2) + ) + + massage = split_heads if self.attention.requires_head_dimension else fold_heads + q = massage(q) + if k is not None: + k = massage(k) + if v is not None: + v = massage(v) + + if self.add_zero_attn: + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + kwargs = {} + + if attn_mask is not None and self.attention.supports_attention_mask: + attn_mask = _mask_for_xformers(attn_mask, to_dtype=q.dtype) + kwargs["att_mask"] = attn_mask + + if key_padding_mask is not None: + key_padding_mask = _mask_for_xformers(key_padding_mask, to_dtype=torch.bool) + if not self.attention.requires_separate_masks: + attn_mask = maybe_merge_masks( + attn_mask, + key_padding_mask, + batch_size=bsz, + src_len=k.size(-2), + tgt_len=q.size(-2), + num_heads=self.num_heads, + ) + key_padding_mask = None + kwargs["att_mask"] = attn_mask + if self.attention.supports_key_padding_mask: + kwargs["key_padding_mask"] = key_padding_mask + + y = self.attention(q, k, v, **kwargs) + + y = ( + y.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .flatten(start_dim=2, end_dim=3) + .transpose(0, 1) + ) + assert list(y.size()) == [tgt_len, bsz, embed_dim] + + # Dropout not needed because already applied in attention. + # It is applied to the attention weights before matmul with v. + y = self.out_proj(y) + + # TODO: support returning attention weights if needed. + return y, None + + def forward( + self, + query: Tensor, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. + # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask.bool() if key_padding_mask is not None else None, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn: Optional[Tensor] = None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order: Tensor, + ): + """Reorder buffered internal state (for incremental generation).""" + input_buffer = self._get_input_buffer(incremental_state) + if input_buffer is not None: + for k in input_buffer.keys(): + input_buffer_k = input_buffer[k] + if input_buffer_k is not None: + if self.encoder_decoder_attention: + if input_buffer_k.size(0) * self.beam_size == new_order.size(0): + return incremental_state + elif self.beam_size > 1: + input_buffer[k] = input_buffer_k.index_select( + 0, + new_order.reshape(-1, self.beam_size)[:, 0] + // self.beam_size, + ) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + else: + input_buffer[k] = input_buffer_k.index_select(0, new_order) + incremental_state = self._set_input_buffer(incremental_state, input_buffer) + return incremental_state + + def set_beam_size(self, beam_size): + """Used for effiecient beamable enc-dec attention""" + self.beam_size = beam_size + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights + + def upgrade_state_dict_named(self, state_dict, name): + prefix = name + "." if name != "" else "" + items_to_add = {} + keys_to_remove = [] + for k in state_dict.keys(): + if k.endswith(prefix + "in_proj_weight"): + # in_proj_weight used to be q + k + v with same dimensions + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] + items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] + items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] + + keys_to_remove.append(k) + + k_bias = prefix + "in_proj_bias" + if k_bias in state_dict.keys(): + dim = int(state_dict[k].shape[0] / 3) + items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] + items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ + dim : 2 * dim + ] + items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] + + keys_to_remove.append(prefix + "in_proj_bias") + + for k in keys_to_remove: + del state_dict[k] + + for key, value in items_to_add.items(): + state_dict[key] = value diff --git a/fairseq/modules/positional_embedding.py b/fairseq/modules/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..97cd474b512f4ab851b4a135398a4d52a2b64533 --- /dev/null +++ b/fairseq/modules/positional_embedding.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from .learned_positional_embedding import LearnedPositionalEmbedding +from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding + + +def PositionalEmbedding( + num_embeddings: int, + embedding_dim: int, + padding_idx: int, + learned: bool = False, +): + if learned: + # if padding_idx is specified then offset the embedding ids by + # this index and adjust num_embeddings appropriately + # TODO: The right place for this offset would be inside + # LearnedPositionalEmbedding. Move this there for a cleaner implementation. + if padding_idx is not None: + num_embeddings = num_embeddings + padding_idx + 1 + m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + else: + m = SinusoidalPositionalEmbedding( + embedding_dim, + padding_idx, + init_size=num_embeddings + padding_idx + 1, + ) + return m diff --git a/fairseq/modules/positional_encoding.py b/fairseq/modules/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..67f635353918d7b485da3c31ec9a836dedec6b7b --- /dev/null +++ b/fairseq/modules/positional_encoding.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +import math +import torch + + +class PositionalEncoding(nn.Module): + """Positional encoding. + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + reverse: Whether to reverse the input position. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input tensor B X T X C + Returns: + torch.Tensor: Encoded tensor B X T X C + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(nn.Module): + """Relative positional encoding module (new implementation). + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + """ + + def __init__(self, max_len, d_model): + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i 0: + # given an empty cluster, find most populated cluster and split it into two + k = random.choice(list(empty_clusters)) + m = counts.most_common(1)[0][0] + e = torch.randn_like(self.centroids[m]) * self.eps + self.centroids[k] = self.centroids[m].clone() + self.centroids[k] += e + self.centroids[m] -= e + + # recompute assignments + distances = self.compute_distances() # (n_centroids x out_features) + self.assignments = torch.argmin(distances, dim=0) # (out_features) + + # check for empty clusters + counts = Counter(map(lambda x: x.item(), self.assignments)) + empty_clusters = set(range(self.n_centroids)) - set(counts.keys()) + + # increment tentatives + if tentatives == self.max_tentatives: + logging.info( + f"Could not resolve all empty clusters, {len(empty_clusters)} remaining" + ) + raise EmptyClusterResolveError + tentatives += 1 + + return n_empty_clusters + + def compute_distances(self): + """ + For every centroid m, computes + + ||M - m[None, :]||_2 + + Remarks: + - We rely on PyTorch's broadcasting to speed up computations + and reduce the memory overhead + - Without chunking, the sizes in the broadcasting are modified as: + (n_centroids x n_samples x out_features) -> (n_centroids x out_features) + - The broadcasting computation is automatically chunked so that + the tensors fit into the memory of the GPU + """ + + nb_centroids_chunks = 1 + + while True: + try: + return torch.cat( + [ + (self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1) + for centroids_c in self.centroids.chunk( + nb_centroids_chunks, dim=0 + ) + ], + dim=0, + ) + except RuntimeError: + nb_centroids_chunks *= 2 + + def assign(self): + """ + Assigns each column of W to its closest centroid, thus essentially + performing the E-step in train(). + + Remarks: + - The function must be called after train() or after loading + centroids using self.load(), otherwise it will return empty tensors + """ + + distances = self.compute_distances() # (n_centroids x out_features) + self.assignments = torch.argmin(distances, dim=0) # (out_features) + + def save(self, path, layer): + """ + Saves centroids and assignments. + + Args: + - path: folder used to save centroids and assignments + """ + + torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer))) + torch.save( + self.assignments, os.path.join(path, "{}_assignments.pth".format(layer)) + ) + torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer))) + + def load(self, path, layer): + """ + Loads centroids and assignments from a given path + + Args: + - path: folder use to load centroids and assignments + """ + + self.centroids = torch.load( + os.path.join(path, "{}_centroids.pth".format(layer)) + ) + self.assignments = torch.load( + os.path.join(path, "{}_assignments.pth".format(layer)) + ) + self.objective = torch.load( + os.path.join(path, "{}_objective.pth".format(layer)) + ) + + +class EmptyClusterResolveError(Exception): + pass diff --git a/fairseq/modules/quantization/pq/modules/__init__.py b/fairseq/modules/quantization/pq/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b67c8e8ad691aa01e9e10e904d69d94595387668 --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .qconv import PQConv2d # NOQA +from .qemb import PQEmbedding # NOQA +from .qlinear import PQLinear # NOQA diff --git a/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c436bea8ce21e1fcd9130fb7b01ba46f23e20f Binary files /dev/null and b/fairseq/modules/quantization/pq/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc b/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c94b821a171ad6ffa7f0bb58262e4e971c15842 Binary files /dev/null and b/fairseq/modules/quantization/pq/modules/__pycache__/qconv.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc b/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b57a736527f40b5cb7d6ece983e891921119f765 Binary files /dev/null and b/fairseq/modules/quantization/pq/modules/__pycache__/qemb.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/pq/modules/__pycache__/qlinear.cpython-310.pyc b/fairseq/modules/quantization/pq/modules/__pycache__/qlinear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbc68855c0480c6ce3739e98b4d6cbbf777e044b Binary files /dev/null and b/fairseq/modules/quantization/pq/modules/__pycache__/qlinear.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/pq/modules/qconv.py b/fairseq/modules/quantization/pq/modules/qconv.py new file mode 100644 index 0000000000000000000000000000000000000000..d15ec192e8cda6265a198e583a9bf7fb194dd129 --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qconv.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair + + +class PQConv2d(nn.Module): + """ + Quantized counterpart of nn.Conv2d module. Stores the centroid, the assignments + and the non-quantized biases. The full weight is re-instantiated at each forward + pass and autograd automatically computes the gradients with respect to the + centroids. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_channels x n_blocks + - bias: the non-quantized bias, must be either torch.Tensor or None + + Remarks: + - We refer the reader to the official documentation of the nn.Conv2d module + for the other arguments and the behavior of the module. + - Performance tests on GPU show that this implementation is 10% slower than + the non-quantized nn.Conv2d module for a standard training loop. + - During the backward, the gradients are averaged by cluster and not summed. + This explains the hook registered to the centroids. + """ + + def __init__( + self, + centroids, + assignments, + bias, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + ): + super(PQConv2d, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.padding_mode = padding_mode + # check compatibility + if in_channels // groups * np.prod(self.kernel_size) % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % out_channels != 0: + raise ValueError("Wrong PQ sizes") + if in_channels % groups != 0: + raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.register_parameter("bias", None) + # register hook for averaging gradients per centroids instead of summing + self.centroids.register_hook(lambda x: x / self.counts[:, None]) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_channels, self.block_size) + .permute(1, 0, 2) + .reshape( + self.out_channels, self.in_channels // self.groups, *self.kernel_size + ) + ) + + def forward(self, x): + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def extra_repr(self): + s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is None: + s += ", bias=False" + if self.padding_mode != "zeros": + s += ", padding_mode={padding_mode}" + s += ", n_centroids={n_centroids}, block_size={block_size}" + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/pq/modules/qemb.py b/fairseq/modules/quantization/pq/modules/qemb.py new file mode 100644 index 0000000000000000000000000000000000000000..3a74ad3c4c7c9d3203d26e7885864ba578951bfe --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qemb.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PQEmbedding(nn.Module): + """ + Quantized counterpart of nn.Embedding module. Stores the centroids and + the assignments. The full weight is re-instantiated at each forward + pass. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_features x n_blocks + - bias: the non-quantized bias + + Remarks: + - We refer the reader to the official documentation of the nn.Embedding module + for the other arguments and the behavior of the module + - Performance tests on GPU show that this implementation is 10% slower than + the non-quantized nn.Embedding module for a standard training loop. + """ + + def __init__( + self, + centroids, + assignments, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + ): + super(PQEmbedding, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + # check compatibility + if self.embedding_dim % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % self.num_embeddings != 0: + raise ValueError("Wrong PQ sizes") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.num_embeddings, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + def forward(self, input): + return F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + def extra_repr(self): + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + s += ", n_centroids={n_centroids}, block_size={block_size}" + + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/pq/modules/qlinear.py b/fairseq/modules/quantization/pq/modules/qlinear.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdd25a8685bb7c7b32e1f02372aaeb26d8ba53a --- /dev/null +++ b/fairseq/modules/quantization/pq/modules/qlinear.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PQLinear(nn.Module): + """ + Quantized counterpart of nn.Linear module. Stores the centroid, the assignments + and the non-quantized biases. The full weight is re-instantiated at each forward + pass. + + Args: + - centroids: centroids of size n_centroids x block_size + - assignments: assignments of the centroids to the subvectors + of size self.out_features x n_blocks + - bias: the non-quantized bias + + Remarks: + - We refer the reader to the official documentation of the nn.Linear module + for the other arguments and the behavior of the module + - Performance tests on GPU show that this implementation is 15% slower than + the non-quantized nn.Linear module for a standard training loop. + """ + + def __init__(self, centroids, assignments, bias, in_features, out_features): + super(PQLinear, self).__init__() + self.block_size = centroids.size(1) + self.n_centroids = centroids.size(0) + self.in_features = in_features + self.out_features = out_features + # check compatibility + if self.in_features % self.block_size != 0: + raise ValueError("Wrong PQ sizes") + if len(assignments) % self.out_features != 0: + raise ValueError("Wrong PQ sizes") + # define parameters + self.centroids = nn.Parameter(centroids, requires_grad=True) + self.register_buffer("assignments", assignments) + self.register_buffer("counts", torch.bincount(assignments).type_as(centroids)) + if bias is not None: + self.bias = nn.Parameter(bias) + else: + self.register_parameter("bias", None) + + @property + def weight(self): + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_features, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + def forward(self, x): + return F.linear( + x, + self.weight, + self.bias, + ) + + def extra_repr(self): + return f"in_features={self.in_features},\ + out_features={self.out_features},\ + n_centroids={self.n_centroids},\ + block_size={self.block_size},\ + bias={self.bias is not None}" diff --git a/fairseq/modules/quantization/pq/pq.py b/fairseq/modules/quantization/pq/pq.py new file mode 100644 index 0000000000000000000000000000000000000000..eddc2eb34602403f10979f54cd23a45bc2f104d5 --- /dev/null +++ b/fairseq/modules/quantization/pq/pq.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .em import EM, EmptyClusterResolveError + + +class PQ(EM): + """ + Quantizes the layer weights W with the standard Product Quantization + technique. This learns a codebook of codewords or centroids of size + block_size from W. For further reference on using PQ to quantize + neural networks, see "And the Bit Goes Down: Revisiting the Quantization + of Neural Networks", Stock et al., ICLR 2020. + + PQ is performed in two steps: + (1) The matrix W (weights or fully-connected or convolutional layer) + is reshaped to (block_size, -1). + - If W is fully-connected (2D), its columns are split into + blocks of size block_size. + - If W is convolutional (4D), its filters are split along the + spatial dimension. + (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix. + + Args: + - W: weight matrix to quantize of size (in_features x out_features) + - block_size: size of the blocks (subvectors) + - n_centroids: number of centroids + - n_iter: number of k-means iterations + - eps: for cluster reassignment when an empty cluster is found + - max_tentatives for cluster reassignment when an empty cluster is found + - verbose: print information after each iteration + + Remarks: + - block_size be compatible with the shape of W + """ + + def __init__( + self, + W, + block_size, + n_centroids=256, + n_iter=20, + eps=1e-6, + max_tentatives=30, + verbose=True, + ): + self.block_size = block_size + W_reshaped = self._reshape(W) + super(PQ, self).__init__( + W_reshaped, + n_centroids=n_centroids, + n_iter=n_iter, + eps=eps, + max_tentatives=max_tentatives, + verbose=verbose, + ) + + def _reshape(self, W): + """ + Reshapes the matrix W as expained in step (1). + """ + + # fully connected: by convention the weight has size out_features x in_features + if len(W.size()) == 2: + self.out_features, self.in_features = W.size() + assert ( + self.in_features % self.block_size == 0 + ), "Linear: n_blocks must be a multiple of in_features" + return ( + W.reshape(self.out_features, -1, self.block_size) + .permute(2, 1, 0) + .flatten(1, 2) + ) + + # convolutional: we reshape along the spatial dimension + elif len(W.size()) == 4: + self.out_channels, self.in_channels, self.k_h, self.k_w = W.size() + assert ( + self.in_channels * self.k_h * self.k_w + ) % self.block_size == 0, ( + "Conv2d: n_blocks must be a multiple of in_channels * k_h * k_w" + ) + return ( + W.reshape(self.out_channels, -1, self.block_size) + .permute(2, 1, 0) + .flatten(1, 2) + ) + # not implemented + else: + raise NotImplementedError(W.size()) + + def encode(self): + """ + Performs self.n_iter EM steps. + """ + + self.initialize_centroids() + for i in range(self.n_iter): + try: + self.step(i) + except EmptyClusterResolveError: + break + + def decode(self): + """ + Returns the encoded full weight matrix. Must be called after + the encode function. + """ + + # fully connected case + if "k_h" not in self.__dict__: + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_features, self.block_size) + .permute(1, 0, 2) + .flatten(1, 2) + ) + + # convolutional case + else: + return ( + self.centroids[self.assignments] + .reshape(-1, self.out_channels, self.block_size) + .permute(1, 0, 2) + .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w) + ) diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eceeef8ba3eb5657374256caf371080ee7d97aa4 --- /dev/null +++ b/fairseq/modules/quantization/pq/utils.py @@ -0,0 +1,376 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import re +from operator import attrgetter, itemgetter +import torch +import numpy as np +import torch.distributed as dist +import torch.nn as nn + +from .modules import PQConv2d, PQEmbedding, PQLinear +from .pq import PQ + + +def quantize_model_( + model, + size_tracker, + layers_to_quantize, + block_sizes_config, + n_centroids_config, + step=0, + n_iter=15, + eps=1e-6, + max_tentatives=100, + remove_weights=False, + verbose=True, + state_dict=None, +): + """ + Quantize a model in-place by stages. All the targeted + layers are replaced by their quantized counterpart, + and the model is ready for the finetuning of the + centroids in a standard training loop (no modifications + required). Note that we do not quantize biases. + + Args: + - model: a nn.Module + - size_tracker: useful for tracking quatization statistics + - layers_to_quantize: a list containing regexps for + filtering the layers to quantize at each stage according + to their name (as in model.named_parameters()) + - block_sizes_config: dict like + { + 'Conv2d': ('kernel_size', {'(3, 3)': 9, '(1, 1)': 4}), + 'Linear': ('in_features', {'*': 8}) + } + For instance, all conv2d layers with kernel size 3x3 have + a block size of 9 and all Linear layers are quantized with + a block size of 8, irrespective of their size. + - n_centroids_config: dict like + { + 'Conv2d': ('kernel_size', {'*': 256}), + 'Linear': ('in_features', {'*': 256}) + } + For instance, all conv2d layers are quantized with 256 centroids + - step: the layers to quantize inplace corresponding + to layers_to_quantize[step] + """ + + quantized_layers = get_layers( + model, layers_to_quantize[step], remove_weights=remove_weights + ) + + for layer in quantized_layers: + + # book-keeping + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) + verbose = verbose and is_master_process + + # get block size and centroids + module = attrgetter(layer)(model) + block_size = get_param(module, layer, block_sizes_config) + n_centroids = get_param(module, layer, n_centroids_config) + if verbose: + logging.info( + f"Quantizing layer {layer} with block size {block_size} and {n_centroids} centroids" + ) + + # quantize layer + weight = module.weight.data.clone() + is_bias = "bias" in [x[0] for x in module.named_parameters()] + bias = module.bias.data.clone() if is_bias else None + quantizer = PQ( + weight, + block_size, + n_centroids=n_centroids, + n_iter=n_iter, + eps=eps, + max_tentatives=max_tentatives, + verbose=verbose, + ) + + # quantization performed on all GPUs with same seed + quantizer.encode() + centroids = quantizer.centroids.contiguous() + assignments = quantizer.assignments.contiguous() + + # If n_iter = 0 and state_dict is provided, then + # we initialize random assignments and centroids to + # random values of the appropriate dimensions + # because the quantized model parameters will + # overwritten by the state_dict later on. + if n_iter == 0 and state_dict: + # Initialize random centroids of the correct size + centroids = torch.rand(centroids.size()) + centroids.cuda() + # Get counts and assignment keys from layer in loaded checkpoint. + counts_key = layer + "." + "counts" + assignment_key = layer + "." + "assignments" + # Get number of different bins to include. + counts = list(state_dict[counts_key].shape)[0] + print(layer) + print(state_dict[counts_key]) + print(counts) + # Initialize random assignments of the correct size + # with an appropriate number of bins. + num_assignments = list(state_dict[assignment_key].shape)[0] + num_extra = num_assignments - counts + print(num_assignments) + print(num_extra) + assignments_bins = torch.arange(counts) + assignments_rand = torch.randint(0, counts - 1, (num_extra,)) + assignments = torch.cat((assignments_bins, assignments_rand), 0) + # assignments = assignments.type(torch.IntTensor) + assignments.cuda() + print("assignments") + print(assignments) + + # broadcast results to make sure weights are up-to-date + if dist.is_initialized(): + dist.broadcast(centroids, 0) + dist.broadcast(assignments, 0) + + # instantiate the quantized counterpart + if isinstance(module, nn.Linear): + out_features, in_features = map( + lambda k: module.__dict__[k], ["out_features", "in_features"] + ) + quantized_module = PQLinear( + centroids, assignments, bias, in_features, out_features + ) + elif isinstance(module, nn.Embedding): + num_embeddings, embedding_dim = map( + lambda k: module.__dict__[k], ["num_embeddings", "embedding_dim"] + ) + quantized_module = PQEmbedding( + centroids, assignments, num_embeddings, embedding_dim + ) + elif isinstance(module, nn.Conv2d): + out_channels, in_channels, kernel_size = map( + lambda k: module.__dict__[k], + ["out_channels", "in_channels", "kernel_size"], + ) + stride, padding, dilation, groups, padding_mode = map( + lambda k: module.__dict__[k], + ["stride", "padding", "dilation", "groups", "padding_mode"], + ) + + quantized_module = PQConv2d( + centroids, + assignments, + bias, + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + padding_mode=padding_mode, + ) + else: + raise ValueError(f"Module {module} not yet supported for quantization") + + # replace layer by its quantized counterpart + attrsetter(layer)(model, quantized_module) + + # update statistics + size_tracker.update(weight, block_size, n_centroids) + + # return name of quantized layers + return quantized_layers + + +def get_layers(model, filter_regexp, remove_weights=False): + """ + Filters out the layers according to a regexp. Note that + we omit biases. + + Args: + - model: a nn.Module + - filter_regexp: a regexp to filter the layers to keep + according to their name in model.named_parameters(). + For instance, the regexp: + + down_layers\\.[123456]\\.(conv[12]|identity\\.conv)) + + is keeping blocks down_layers from 1 to 6, and inside + each block is keeping conv1, conv2 and identity.conv. + + Remarks: + - We add (module\\.)? at the beginning of the regexp to + account for the possible use of nn.parallel.DataParallel + """ + + # get all parameter names + all_layers = map(itemgetter(0), model.named_parameters()) + + # remove biases + all_layers = filter(lambda x: "bias" not in x, all_layers) + + # remove .weight in all other names (or .weight_orig is spectral norm) + all_layers = map(lambda x: x.replace(".weight_orig", ""), all_layers) + # remove weights indicates whether the weights extension should be removed, in addition to + # weight_orig and weight extension on names + if remove_weights: + all_layers = map(lambda x: x.replace(".weights", ""), all_layers) + all_layers = map(lambda x: x.replace(".weight", ""), all_layers) + + # return filtered layers + filter_regexp = "(module\\.)?" + "(" + filter_regexp + ")" + r = re.compile(filter_regexp) + + return list(filter(r.match, all_layers)) + + +def get_param(module, layer_name, param_config): + """ + Given a quantization configuration, get the right parameter + for the module to be quantized. + + Args: + - module: a nn.Module + - layer_name: the name of the layer + - param_config: a dict like + { + 'Conv2d': ('kernel_size', {'(3, 3)': 9, '(1, 1)': 4}), + 'Linear': ('in_features', {'*': 8}) + } + For instance, all conv2d layers with kernel size 3x3 have + a block size of 9 and all Linear layers are quantized with + a block size of 8, irrespective of their size. + + Remarks: + - if 'fuzzy_name' is passed as a parameter, layers whose layer_name + include 'fuzzy_name' will be assigned the given parameter. + In the following example, conv.expand layers will have a block + size of 9 while conv.reduce will have a block size of 4 and all + other layers will have a block size of 2. + { + 'Conv2d': ('fuzzy_name', {'expand': 9, 'reduce': 4, '*': 2}), + 'Linear': ('fuzzy_name', {'classifier': 8, 'projection': 4}) + } + + """ + + layer_type = module.__class__.__name__ + + if layer_type not in param_config: + raise KeyError(f"Layer type {layer_type} not in config for layer {module}") + + feature, params = param_config[module.__class__.__name__] + + if feature != "fuzzy_name": + feature_value = str(getattr(module, feature)) + if feature_value not in params: + if "*" in params: + feature_value = "*" + else: + raise KeyError( + f"{feature}={feature_value} not in config for layer {module}" + ) + else: + feature_values = [name for name in params if name in layer_name] + if len(feature_values) == 0: + if "*" in params: + feature_value = "*" + else: + raise KeyError(f"name={layer_name} not in config for {module}") + else: + feature_value = feature_values[0] + + return params[feature_value] + + +class SizeTracker(object): + """ + Class to keep track of the compressed network size with iPQ. + + Args: + - model: a nn.Module + + Remarks: + - The compressed size is the sum of three components + for each layer in the network: + (1) Storing the centroids given by iPQ in fp16 + (2) Storing the assignments of the blocks in int8 + (3) Storing all non-compressed elements such as biases + - This cost in only valid if we use 256 centroids (then + indexing can indeed by done with int8). + """ + + def __init__(self, model): + self.model = model + self.size_non_compressed_model = self.compute_size() + self.size_non_quantized = self.size_non_compressed_model + self.size_index = 0 + self.size_centroids = 0 + self.n_quantized_layers = 0 + + def compute_size(self): + """ + Computes the size of the model (in MB). + """ + + res = 0 + for _, p in self.model.named_parameters(): + res += p.numel() + return res * 4 / 1024 / 1024 + + def update(self, W, block_size, n_centroids): + """ + Updates the running statistics when quantizing a new layer. + """ + + # bits per weights + bits_per_weight = np.log2(n_centroids) / block_size + self.n_quantized_layers += 1 + + # size of indexing the subvectors of size block_size (in MB) + size_index_layer = bits_per_weight * W.numel() / 8 / 1024 / 1024 + self.size_index += size_index_layer + + # size of the centroids stored in float16 (in MB) + size_centroids_layer = n_centroids * block_size * 2 / 1024 / 1024 + self.size_centroids += size_centroids_layer + + # size of non-compressed layers, e.g. LayerNorms or biases (in MB) + size_uncompressed_layer = W.numel() * 4 / 1024 / 1024 + self.size_non_quantized -= size_uncompressed_layer + + def __repr__(self): + size_compressed = ( + self.size_index + self.size_centroids + self.size_non_quantized + ) + compression_ratio = self.size_non_compressed_model / size_compressed # NOQA + return ( + f"Non-compressed model size: {self.size_non_compressed_model:.2f} MB. " + f"After quantizing {self.n_quantized_layers} layers, size " + f"(indexing + centroids + other): {self.size_index:.2f} MB + " + f"{self.size_centroids:.2f} MB + {self.size_non_quantized:.2f} MB = " + f"{size_compressed:.2f} MB, compression ratio: {compression_ratio:.2f}x" + ) + + +def attrsetter(*items): + def resolve_attr(obj, attr): + attrs = attr.split(".") + head = attrs[:-1] + tail = attrs[-1] + + for name in head: + obj = getattr(obj, name) + return obj, tail + + def g(obj, val): + for attr in items: + resolved_obj, resolved_attr = resolve_attr(obj, attr) + setattr(resolved_obj, resolved_attr, val) + + return g diff --git a/fairseq/modules/quantization/quantization_options.py b/fairseq/modules/quantization/quantization_options.py new file mode 100644 index 0000000000000000000000000000000000000000..b46d682c0edaeaaf2a230e51d50da2a32d4bda98 --- /dev/null +++ b/fairseq/modules/quantization/quantization_options.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +def parse_config_yaml(yaml_data): + # Initialize to default options. + quantization_options = { + "n_centroids": { + "Linear": ["in_features", {"*": 256}], + "Embedding": ["embedding_dim", {"*": 256}], + }, + "block_sizes": { + "Linear": ["fuzzy_name", {"fc": 8, "attn": 4, "emb": 4}], + "Embedding": ["fuzzy_name", {"emb": 8}], + }, + "layers_to_quantize": [ + "decoder\\.layers\\.\\d+\\.fc[12]", + "decoder\\.embed_tokens\\.embeddings\\.[012]\\.[01]", + "decoder\\.layers\\.\\d+\\.self_attn\\.(k_proj|v_proj|q_proj|out_proj)", + ], + } + + if "n_centroids" in yaml_data: + quantization_options["n_centroids"] = { + layer: convert_yaml_to_tuple(layer_data) + for layer, layer_data in yaml_data["n_centroids"].items() + } + if "block_sizes" in yaml_data: + quantization_options["block_sizes"] = { + layer: convert_yaml_to_tuple(layer_data) + for layer, layer_data in yaml_data["block_sizes"].items() + } + if "layers_to_quantize" in yaml_data: + quantization_options["layers_to_quantize"] = yaml_data["layers_to_quantize"] + + return quantization_options + + +def convert_yaml_to_tuple(yaml_dictionary): + """Converts a yaml dictionary with two keys: `key` and `value` into a two + argument tuple of those values.""" + return (yaml_dictionary["key"], yaml_dictionary["value"]) diff --git a/fairseq/modules/quantization/scalar/__init__.py b/fairseq/modules/quantization/scalar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..143834f3d036780eb6844c82f0c6f2d10cfe2f61 --- /dev/null +++ b/fairseq/modules/quantization/scalar/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .utils import quantize_model_ # NOQA diff --git a/fairseq/modules/quantization/scalar/__pycache__/__init__.cpython-310.pyc b/fairseq/modules/quantization/scalar/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7514868ee61aab08c2bc4a080ef8c83394ed70f2 Binary files /dev/null and b/fairseq/modules/quantization/scalar/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc b/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed2e31ac56ec453e15466ea8a8ea20dc770cc541 Binary files /dev/null and b/fairseq/modules/quantization/scalar/__pycache__/ops.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/__pycache__/utils.cpython-310.pyc b/fairseq/modules/quantization/scalar/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f06f079ad058819f4340491e017a367aaa1827bc Binary files /dev/null and b/fairseq/modules/quantization/scalar/__pycache__/utils.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/__init__.py b/fairseq/modules/quantization/scalar/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8031d9cdb23f2bc72596f8bc9cfa4965f96e3e6c --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .qact import ActivationQuantizer # NOQA +from .qconv import IntConv2d # NOQA +from .qemb import IntEmbedding # NOQA +from .qlinear import IntLinear # NOQA diff --git a/fairseq/modules/quantization/scalar/modules/__pycache__/__init__.cpython-310.pyc b/fairseq/modules/quantization/scalar/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0173400269a5943ad7d8611e7cf9296aa22ff50 Binary files /dev/null and b/fairseq/modules/quantization/scalar/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/__pycache__/qact.cpython-310.pyc b/fairseq/modules/quantization/scalar/modules/__pycache__/qact.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abded740986330b2df8162b689768750637101c9 Binary files /dev/null and b/fairseq/modules/quantization/scalar/modules/__pycache__/qact.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/__pycache__/qconv.cpython-310.pyc b/fairseq/modules/quantization/scalar/modules/__pycache__/qconv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16624f7b3a7e4978604c049bb820bfd69853c984 Binary files /dev/null and b/fairseq/modules/quantization/scalar/modules/__pycache__/qconv.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc b/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94d884ed6118cd9f4138b0dc19b9f5506aef11e1 Binary files /dev/null and b/fairseq/modules/quantization/scalar/modules/__pycache__/qemb.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/__pycache__/qlinear.cpython-310.pyc b/fairseq/modules/quantization/scalar/modules/__pycache__/qlinear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7a79a71174fdae50163f08990c64b01f4ae7d78 Binary files /dev/null and b/fairseq/modules/quantization/scalar/modules/__pycache__/qlinear.cpython-310.pyc differ diff --git a/fairseq/modules/quantization/scalar/modules/qact.py b/fairseq/modules/quantization/scalar/modules/qact.py new file mode 100644 index 0000000000000000000000000000000000000000..b362c30dc7982e3717d1b764ae737ec1b24ea78e --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qact.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from ..ops import emulate_int + + +class ActivationQuantizer: + """ + Fake scalar quantization of the activations using a forward hook. + + Args: + - module. a nn.Module for which we quantize the *post-activations* + - p: proportion of activations to quantize, set by default to 1 + - update_step: to recompute quantization parameters + - bits: number of bits for quantization + - method: choose among {"tensor", "histogram", "channel"} + - clamp_threshold: to prevent gradients overflow + + Remarks: + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - For the list of quantization methods and number of bits, see ops.py + - To remove the hook from the module, simply call self.handle.remove() + - At test time, the activations are fully quantized + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - The activations are hard-clamped in [-clamp_threshold, clamp_threshold] + to prevent overflow during the backward pass + """ + + def __init__( + self, + module, + p=1, + update_step=1000, + bits=8, + method="histogram", + clamp_threshold=5, + ): + self.module = module + self.p = p + self.update_step = update_step + self.counter = 0 + self.bits = bits + self.method = method + self.clamp_threshold = clamp_threshold + self.handle = None + self.register_hook() + + def register_hook(self): + # forward hook + def quantize_hook(module, x, y): + + # update parameters every 1000 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.module.training else 1 + + # quantize activations + y_q, self.scale, self.zero_point = emulate_int( + y.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(y) + mask.bernoulli_(1 - p) + noise = (y_q - y).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2**self.bits - 1 - self.zero_point) + return torch.clamp(y, clamp_low.item(), clamp_high.item()) + noise.detach() + + # register hook + self.handle = self.module.register_forward_hook(quantize_hook) diff --git a/fairseq/modules/quantization/scalar/modules/qconv.py b/fairseq/modules/quantization/scalar/modules/qconv.py new file mode 100644 index 0000000000000000000000000000000000000000..29744744ecffaec6187888e343f6d6031660292f --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qconv.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.utils import _pair + +from ..ops import emulate_int + + +class IntConv2d(_ConvNd): + """ + Quantized counterpart of the nn.Conv2d module that applies QuantNoise during training. + + Args: + - standard nn.Conv2d parameters + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-thgourh estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode="zeros", + p=0, + bits=8, + method="histogram", + update_step=1000, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(IntConv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + padding_mode, + ) + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def _conv_forward(self, input, weight): + if self.padding_mode != "zeros": + return F.conv2d( + F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), + weight, + self.bias, + self.stride, + _pair(0), + self.dilation, + self.groups, + ) + return F.conv2d( + input, + weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 100 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2**self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = self._conv_forward(input, weight) + return output + + def extra_repr(self): + return ( + "in_channels={}, out_channels={}, kernel_size={}, stride={}, " + "padding={}, dilation={}, groups={}, bias={}, quant_noise={}, " + "bits={}, method={}".format( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.dilation, + self.groups, + self.bias is not None, + self.p, + self.bits, + self.method, + ) + ) diff --git a/fairseq/modules/quantization/scalar/modules/qemb.py b/fairseq/modules/quantization/scalar/modules/qemb.py new file mode 100644 index 0000000000000000000000000000000000000000..3b293ac31e9ab683c47e05b763c077ec181e64f7 --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qemb.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..ops import emulate_int + + +class IntEmbedding(nn.Module): + """ + Quantized counterpart of the nn.Embedding module that applies QuantNoise during training. + + Args: + - num_embeddings: number of tokens + - embedding_dim: embedding dimension + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + num_embeddings, + embedding_dim, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + p=0, + update_step=1000, + bits=8, + method="histogram", + ): + super(IntEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert ( + padding_idx < self.num_embeddings + ), "Padding_idx must be within num_embeddings" + elif padding_idx < 0: + assert ( + padding_idx >= -self.num_embeddings + ), "Padding_idx must be within num_embeddings" + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) + self.reset_parameters() + else: + assert list(_weight.shape) == [ + num_embeddings, + embedding_dim, + ], "Shape of weight does not match num_embeddings and embedding_dim" + self.weight = nn.Parameter(_weight) + self.sparse = sparse + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def reset_parameters(self): + nn.init.normal_(self.weight) + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 1000 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2**self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = F.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + return output + + def extra_repr(self): + s = "{num_embeddings}, {embedding_dim}" + if self.padding_idx is not None: + s += ", padding_idx={padding_idx}" + if self.max_norm is not None: + s += ", max_norm={max_norm}" + if self.norm_type != 2: + s += ", norm_type={norm_type}" + if self.scale_grad_by_freq is not False: + s += ", scale_grad_by_freq={scale_grad_by_freq}" + if self.sparse is not False: + s += ", sparse=True" + s += "quant_noise={p}, bits={bits}, method={method}" + return s.format(**self.__dict__) diff --git a/fairseq/modules/quantization/scalar/modules/qlinear.py b/fairseq/modules/quantization/scalar/modules/qlinear.py new file mode 100644 index 0000000000000000000000000000000000000000..78606a25b98b69b42eef72742037616b21116133 --- /dev/null +++ b/fairseq/modules/quantization/scalar/modules/qlinear.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..ops import emulate_int + + +class IntLinear(nn.Module): + """ + Quantized counterpart of the nn.Linear module that applies QuantNoise during training. + + Args: + - in_features: input features + - out_features: output features + - bias: bias or not + - p: amount of noise to inject (0 = no quantization, 1 = quantize all the weights) + - bits: number of bits + - method: choose among {"tensor", "histogram", "channel"} + - update_step: recompute scale and zero_point every update_steps iterations + + Remarks: + - We use the straight-through estimator so that the gradients + back-propagate nicely in the network, this is implemented with + the detach() trick. + - Parameters scale and zero_point are recomputed every update_step + forward pass to reduce the overhead + - At test time, the weights are fully quantized + """ + + def __init__( + self, + in_features, + out_features, + bias=True, + p=0, + update_step=3000, + bits=8, + method="histogram", + ): + super(IntLinear, self).__init__() + self.in_features = int(in_features) + self.out_features = int(out_features) + self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) + self.chosen_bias = bias + if self.chosen_bias: + self.bias = torch.nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + # quantization parameters + self.p = p + self.bits = bits + self.method = method + self.update_step = update_step + self.counter = 0 + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.chosen_bias: + nn.init.constant_(self.bias, 0.0) + return + + def forward(self, input): + # train with QuantNoise and evaluate the fully quantized network + p = self.p if self.training else 1 + + # update parameters every 100 iterations + if self.counter % self.update_step == 0: + self.scale = None + self.zero_point = None + self.counter += 1 + + # quantize weight + weight_quantized, self.scale, self.zero_point = emulate_int( + self.weight.detach(), + bits=self.bits, + method=self.method, + scale=self.scale, + zero_point=self.zero_point, + ) + + # mask to apply noise + mask = torch.zeros_like(self.weight) + mask.bernoulli_(1 - p) + noise = (weight_quantized - self.weight).masked_fill(mask.bool(), 0) + + # using straight-through estimator (STE) + clamp_low = -self.scale * self.zero_point + clamp_high = self.scale * (2**self.bits - 1 - self.zero_point) + weight = ( + torch.clamp(self.weight, clamp_low.item(), clamp_high.item()) + + noise.detach() + ) + + # return output + output = F.linear(input, weight, self.bias) + return output + + def extra_repr(self): + return "in_features={}, out_features={}, bias={}, quant_noise={}, bits={}, method={}".format( + self.in_features, + self.out_features, + self.bias is not None, + self.p, + self.bits, + self.method, + ) diff --git a/fairseq/modules/quantization/scalar/ops.py b/fairseq/modules/quantization/scalar/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e0f9a0c1f83112e554cbbea37aafa7eea27900db --- /dev/null +++ b/fairseq/modules/quantization/scalar/ops.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +try: + import torch.ao.quantization as quantization +except ImportError: + import torch.quantization as quantization + + +def emulate_int(w, bits, method, scale=None, zero_point=None): + q = globals()[f"emulate_int8_{method}"] + return q(w, scale=scale, zero_point=zero_point, bits=bits) + + +def quantize(w, scale, zero_point, bits=8): + # In the default behavior, max_val = 255. + max_val = 2**bits - 1 + return ( + torch.clamp(torch.round(w / scale + zero_point), 0, max_val) - zero_point + ) * scale + + +def emulate_int8_histogram(w, scale=None, zero_point=None, bits=8): + if scale is None: + obs = quantization.observer.HistogramObserver() + obs.to(device=w.device) + _ = obs(w.float()) + scale, zero_point = obs.calculate_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point, bits=bits), scale, zero_point + + +def emulate_int8_channel(w, scale=None, zero_point=None, bits=8): + if scale is None: + obs = quantization.observer.PerChannelMinMaxObserver( + ch_axis=-1, qscheme=torch.per_channel_symmetric + ) + obs.to(device=w.device) + _ = obs(w) + scale, zero_point, ch_axis = obs.get_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point, bits=bits), scale, zero_point + + +def emulate_int8_tensor(w, scale=None, zero_point=None, bits=8): + if scale is None: + obs = quantization.observer.MinMaxObserver() + obs.to(device=w.device) + _ = obs(w) + scale, zero_point = obs.calculate_qparams() + scale = scale.cuda().type_as(w) + zero_point = zero_point.cuda().type_as(w) + return quantize(w, scale, zero_point, bits=bits), scale, zero_point diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d4b1cc255b931098bd63ef277dd4a609106326d5 --- /dev/null +++ b/fairseq/modules/quantization/scalar/utils.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from operator import attrgetter + +import torch.distributed as dist +import torch.nn as nn + +from ..pq.utils import attrsetter, get_layers +from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear + + +MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} + + +def quantize_model_( + model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False +): + """ + Replaces all modules with their scalar quantized counterpart and + registers hooks to quantize the post-ativations of those modules. + + Args: + - model: a nn.Module + - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations) + - bits: number of bits + - update_step: update quantization parameters every update_step steps + """ + # quantize all layers + # remove weights indicates whether the weights extension should be removed, in addition to + # weight_orig and weight extension on names + quantized_layers = get_layers(model, "(.*?)", remove_weights=remove_weights) + + for layer in quantized_layers: + + # book-keeping + is_master_process = (not dist.is_initialized()) or ( + dist.is_initialized() and dist.get_rank() == 0 + ) + + # recover module + module = attrgetter(layer)(model) + if is_master_process: + logging.info( + f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}" + ) + + # quantization params + q_params = { + "p": p, + "update_step": update_step, + "bits": bits, + "method": method, + "counter": 0, + } + + # instantiate the quantized counterpart + if isinstance(module, tuple(MAPPING.keys())): + QuantizedModule = MAPPING[module.__class__] + quantized_module = QuantizedModule.__new__(QuantizedModule) + params = module.__dict__ + params.update(q_params) + quantized_module.__dict__.update(params) + + else: + if is_master_process: + logging.info(f"Module {module} not yet supported for quantization") + continue + + # activation quantization + a_q = ActivationQuantizer(quantized_module, p=0, bits=bits, method=method) + + # replace layer by its quantized counterpart + attrsetter(layer)(model, quantized_module) + + # return name of quantized layers + return quantized_layers diff --git a/fairseq/modules/rotary_positional_embedding.py b/fairseq/modules/rotary_positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..b74028b0117ae94ff8567f45f37e67ae1079ee72 --- /dev/null +++ b/fairseq/modules/rotary_positional_embedding.py @@ -0,0 +1,50 @@ +import torch + + +class RotaryPositionalEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000, precision=torch.half): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ + Paper: https://arxiv.org/pdf/2104.09864.pdf + Args: + dim: Dimension of embedding + base: Base value for exponential + precision: precision to use for numerical values + """ + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.seq_len_cached = 0 + self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim) + self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim) + self.precision = precision + + def forward(self, x, seq_len: int = 0): + """ + Args: + x: Input x with T X B X C + seq_len: Sequence length of input x + """ + if seq_len > self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1)) + self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1)) + return self.cos_cached, self.sin_cached + +# rotary pos emb helpers: +def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat( + (-x2, x1), dim=x1.ndim - 1 + ) # dim=-1 triggers a bug in earlier torch versions + + +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = ( + cos[offset : q.shape[0] + offset, ...], + sin[offset : q.shape[0] + offset, ...], + ) + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ce4131c686fddeb86b4392b0c662c8b8100ff3 --- /dev/null +++ b/fairseq/modules/same_pad.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from torch import nn + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class SamePad2d(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x diff --git a/fairseq/modules/scalar_bias.py b/fairseq/modules/scalar_bias.py new file mode 100644 index 0000000000000000000000000000000000000000..c96247c75914fabb8a2b7ff731bb82b588f72690 --- /dev/null +++ b/fairseq/modules/scalar_bias.py @@ -0,0 +1,31 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch + + +class ScalarBias(torch.autograd.Function): + """ + Adds a vector of scalars, used in self-attention mechanism to allow + the model to optionally attend to this vector instead of the past + """ + + @staticmethod + def forward(ctx, input, dim, bias_init): + size = list(input.size()) + size[dim] += 1 + output = input.new(*size).fill_(bias_init) + output.narrow(dim, 1, size[dim] - 1).copy_(input) + ctx.dim = dim + return output + + @staticmethod + def backward(ctx, grad): + return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None + + +def scalar_bias(input, dim, bias_init=0): + return ScalarBias.apply(input, dim, bias_init) diff --git a/fairseq/modules/sinusoidal_positional_embedding.py b/fairseq/modules/sinusoidal_positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ecd0f2c8bc08cb9fafac8f8e54a3b560227770 --- /dev/null +++ b/fairseq/modules/sinusoidal_positional_embedding.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional + +import torch +import torch.onnx.operators +from fairseq import utils +from torch import nn, Tensor + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx if padding_idx is not None else 0 + self.register_buffer("weights", SinusoidalPositionalEmbedding.get_embedding( + init_size, embedding_dim, padding_idx + ), persistent=False) + self.max_positions = int(1e5) + self.onnx_trace = False + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # Ignore some deprecated keys that were used in older versions + deprecated_keys = ["weights", "_float_tensor"] + for key in deprecated_keys: + if prefix + key in state_dict: + del state_dict[prefix + key] + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + @staticmethod + def get_embedding( + num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None + ): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( + 1 + ) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( + num_embeddings, -1 + ) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward( + self, + input, + incremental_state: Optional[Any] = None, + timestep: Optional[Tensor] = None, + positions: Optional[Any] = None, + ): + """Input is expected to be of size [bsz x seqlen].""" + bspair = torch.onnx.operators.shape_as_tensor(input) + bsz, seq_len = bspair[0], bspair[1] + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + # expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, self.embedding_dim, self.padding_idx + ).to(self.weights) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + if self.onnx_trace: + return ( + self.weights.index_select(index=self.padding_idx + pos, dim=0) + .unsqueeze(1) + .repeat(bsz, 1, 1) + ) + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = utils.make_positions( + input, self.padding_idx, onnx_trace=self.onnx_trace + ) + if self.onnx_trace: + flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) + embedding_shape = torch.cat( + (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) + ) + embeddings = torch.onnx.operators.reshape_from_tensor_shape( + flat_embeddings, embedding_shape + ) + return embeddings + return ( + self.weights.index_select(0, positions.view(-1)) + .view(bsz, seq_len, -1) + .detach() + ) diff --git a/fairseq/modules/sparse_multihead_attention.py b/fairseq/modules/sparse_multihead_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbd9d6785886e319aab0601517e27df733b6f97 --- /dev/null +++ b/fairseq/modules/sparse_multihead_attention.py @@ -0,0 +1,140 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from .multihead_attention import MultiheadAttention + + +class SparseMultiheadAttention(MultiheadAttention): + """Sparse Multi-Headed Attention. + + "Generating Long Sequences with Sparse Transformers". Implements + fixed factorized self attention, where l=stride and c=expressivity. + A(1) includes all words in the stride window and A(2) takes a summary of c + words from the end of each stride window. + If is_bidirectional=False, we do not include any words past the current word, + as in the paper. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + stride=32, + expressivity=8, + is_bidirectional=True, + ): + + super().__init__( + embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn, + self_attention, + encoder_decoder_attention, + ) + + self.is_bidirectional = is_bidirectional + self.stride = stride + self.expressivity = expressivity + assert self.stride > 0 and self.stride >= self.expressivity + + # Used for Ai(2) calculations - beginning of [l-c, l] range + def compute_checkpoint(self, word_index): + if word_index % self.stride == 0 and word_index != 0: + checkpoint_index = word_index - self.expressivity + else: + checkpoint_index = ( + math.floor(word_index / self.stride) * self.stride + + self.stride + - self.expressivity + ) + return checkpoint_index + + # Computes Ai(2) + def compute_subset_summaries(self, absolute_max): + checkpoint_index = self.compute_checkpoint(0) + subset_two = set() + while checkpoint_index <= absolute_max - 1: + summary = set( + range( + checkpoint_index, + min(checkpoint_index + self.expressivity + 1, absolute_max), + ) + ) + subset_two = subset_two.union(summary) + checkpoint_index = self.compute_checkpoint(checkpoint_index + self.stride) + return subset_two + + # Sparse Transformer Fixed Attention Pattern: https://arxiv.org/pdf/1904.10509.pdf + def compute_fixed_attention_subset(self, word_index, tgt_len): + # +1s account for range function; [min, max) -> [min, max] + if not self.is_bidirectional: + absolute_max = word_index + 1 + else: + absolute_max = tgt_len + + # Subset 1 - whole window + rounded_index = ( + math.floor((word_index + self.stride) / self.stride) * self.stride + ) + if word_index % self.stride == 0 and word_index != 0: + subset_one = set( + range(word_index - self.stride, min(absolute_max, word_index + 1)) + ) + else: + subset_one = set( + range( + max(0, rounded_index - self.stride), + min(absolute_max, rounded_index + 1), + ) + ) + + # Subset 2 - summary per window + # If bidirectional, subset 2 is the same for every index + subset_two = set() + if not self.is_bidirectional: + subset_two = self.compute_subset_summaries(absolute_max) + + return subset_one.union(subset_two) + + # Compute sparse mask - if bidirectional, can pre-compute and store + def buffered_sparse_mask(self, tensor, tgt_len, src_len): + assert tgt_len > self.stride + sparse_mask = torch.empty((tgt_len, src_len)).float().fill_(float("-inf")) + + # If bidirectional, subset 2 is the same for every index + subset_summaries = set() + if self.is_bidirectional: + subset_summaries = self.compute_subset_summaries(tgt_len) + + for i in range(tgt_len): + fixed_attention_subset = self.compute_fixed_attention_subset(i, tgt_len) + fixed_attention_subset = fixed_attention_subset.union(subset_summaries) + included_word_indices = torch.LongTensor(list(fixed_attention_subset)) + sparse_mask[i].index_fill_(0, included_word_indices, 0) + return sparse_mask.type_as(tensor) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + sparse_mask = self.buffered_sparse_mask(attn_weights, tgt_len, src_len) + sparse_mask = sparse_mask.unsqueeze(0).expand( + bsz * self.num_heads, tgt_len, src_len + ) + attn_weights += sparse_mask diff --git a/fairseq/modules/sparse_transformer_sentence_encoder.py b/fairseq/modules/sparse_transformer_sentence_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f41ec09327fe80b50d20674e7482794ce45c531c --- /dev/null +++ b/fairseq/modules/sparse_transformer_sentence_encoder.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn +from fairseq.modules import TransformerSentenceEncoder +from fairseq.modules.sparse_transformer_sentence_encoder_layer import ( + SparseTransformerSentenceEncoderLayer, +) + + +class SparseTransformerSentenceEncoder(TransformerSentenceEncoder): + """ + Sparse implementation of the TransformerSentenceEncoder + - see SparseMultiheadAttention + """ + + def __init__( + self, + padding_idx: int, + vocab_size: int, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + max_seq_len: int = 256, + num_segments: int = 2, + use_position_embeddings: bool = True, + offset_positions_by_padding: bool = True, + encoder_normalize_before: bool = False, + apply_bert_init: bool = False, + activation_fn: str = "relu", + learned_pos_embedding: bool = True, + embed_scale: float = None, + freeze_embeddings: bool = False, + n_trans_layers_to_freeze: int = 0, + export: bool = False, + is_bidirectional: bool = True, + stride: int = 32, + expressivity: int = 8, + ) -> None: + + super().__init__( + padding_idx, + vocab_size, + num_encoder_layers, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + max_seq_len, + num_segments, + use_position_embeddings, + offset_positions_by_padding, + encoder_normalize_before, + apply_bert_init, + activation_fn, + learned_pos_embedding, + embed_scale, + freeze_embeddings, + n_trans_layers_to_freeze, + export, + ) + + self.layers = nn.ModuleList( + [ + SparseTransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + is_bidirectional=is_bidirectional, + stride=stride, + expressivity=expressivity, + ) + for _ in range(num_encoder_layers) + ] + ) + + def freeze_module_params(m): + if m is not None: + for p in m.parameters(): + p.requires_grad = False + + for layer in range(n_trans_layers_to_freeze): + freeze_module_params(self.layers[layer]) diff --git a/fairseq/modules/sparse_transformer_sentence_encoder_layer.py b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..d95da59c2471bfa858fd627605196d7f41f9ec12 --- /dev/null +++ b/fairseq/modules/sparse_transformer_sentence_encoder_layer.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.modules import TransformerSentenceEncoderLayer +from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention + + +class SparseTransformerSentenceEncoderLayer(TransformerSentenceEncoderLayer): + """ + Implements a Sprase Transformer Encoder Layer (see SparseMultiheadAttention) + """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + export: bool = False, + is_bidirectional: bool = True, + stride: int = 32, + expressivity: int = 8, + ) -> None: + + super().__init__( + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + ) + + self.self_attn = SparseMultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + add_bias_kv=False, + add_zero_attn=False, + self_attention=True, + is_bidirectional=is_bidirectional, + stride=stride, + expressivity=expressivity, + ) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..19e035dec53fcfc2ed832d9349c98cdd28824e94 --- /dev/null +++ b/fairseq/modules/transformer_layer.py @@ -0,0 +1,562 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import utils +from fairseq.models.transformer import TransformerConfig +from fairseq.modules import LayerNorm, MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise + + +class TransformerEncoderLayerBase(nn.Module): + """Encoder layer block. + + In the original paper each operation (multi-head attention or FFN) is + postprocessed with: `dropout -> add residual -> layernorm`. In the + tensor2tensor code they suggest that learning is more robust when + preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.encoder.normalize_before* to ``True``. + + Args: + cfg (argparse.Namespace): parsed command-line arguments + """ + + def __init__(self, cfg, return_fc=False): + super().__init__() + self.cfg = cfg + self.return_fc = return_fc + self.embed_dim = cfg.encoder.embed_dim + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size + self.self_attn = self.build_self_attention(self.embed_dim, cfg) + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=self.__class__.__name__ + ) + self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout + if activation_dropout_p == 0: + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = cfg.encoder.normalize_before + self.fc1 = self.build_fc1( + self.embed_dim, + cfg.encoder.ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + cfg.encoder.ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise( + nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size + ) + + def _get_fc_rank(self, remove_num: int) -> List[int]: + f1_filter_param = [] + for i in range(self.fc1.out_features): + f1_filter_param.append( + torch.sum(torch.abs(self.fc1.weight[i])) + + torch.sum(torch.abs(self.fc2.weight[:, i])) + + torch.abs(self.fc1.bias[i]) + ) + return sorted( + range(len(f1_filter_param)), key=lambda k: f1_filter_param[k], reverse=False + )[0:remove_num] + + def _prune_fc_layer(self, remove_index: List[int]): + new_fc1_weight = [] + new_fc1_bias = [] + for i in range(self.fc1.out_features): + if i not in remove_index: + new_fc1_weight.append(self.fc1.weight[i]) + new_fc1_bias.append(self.fc1.bias[i]) + + new_fc1_weight = torch.stack(new_fc1_weight).detach() + new_fc1_weight.requires_grad = True + + new_fc1_bias = torch.stack(new_fc1_bias).detach() + new_fc1_bias.requires_grad = True + + self.fc1 = quant_noise( + nn.Linear(self.fc1.in_features, self.fc1.out_features - len(remove_index)), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + self.fc1.weight = torch.nn.Parameter(new_fc1_weight) + self.fc1.bias = torch.nn.Parameter(new_fc1_bias) + + new_fc2_weight = [] + new_fc2_bias = [] + for i in range(self.fc2.in_features): + if i not in remove_index: + new_fc2_weight.append(self.fc2.weight[:, i]) + new_fc2_bias = self.fc2.bias.detach() + + new_fc2_weight = torch.stack(new_fc2_weight, dim=-1).detach() + new_fc2_weight.requires_grad = True + + new_fc2_bias = self.fc2.bias.detach() + new_fc2_bias.requires_grad = True + + self.fc2 = quant_noise( + nn.Linear(self.fc2.in_features - len(remove_index), self.fc2.out_features), + p=self.quant_noise, + block_size=self.quant_noise_block_size, + ) + self.fc2.weight = torch.nn.Parameter(new_fc2_weight) + self.fc2.bias = torch.nn.Parameter(new_fc2_bias) + + def build_self_attention(self, embed_dim, cfg): + return MultiheadAttention( + embed_dim, + cfg.encoder.attention_heads, + dropout=cfg.attention_dropout, + self_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, + ) + + def residual_connection(self, x, residual): + return residual + x + + def upgrade_state_dict_named(self, state_dict, name): + """ + Rename layer norm states from `...layer_norms.0.weight` to + `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to + `...final_layer_norm.weight` + """ + layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layer_norms.{}.{}".format(name, old, m) + if k in state_dict: + state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] + del state_dict[k] + + def forward( + self, + x, + encoder_padding_mask: Optional[Tensor], + attn_mask: Optional[Tensor] = None, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor): binary ByteTensor of shape + `(batch, seq_len)` where padding elements are indicated by ``1``. + attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, + where `tgt_len` is the length of output and `src_len` is the + length of input, though here both are equal to `seq_len`. + `attn_mask[tgt_i, src_j] = 1` means that when calculating the + embedding for `tgt_i`, we exclude (mask out) `src_j`. This is + useful for strided self-attention. + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + # anything in original attn_mask = 1, becomes -1e8 + # anything in original attn_mask = 0, becomes 0 + # Note that we cannot use -inf here, because at some edge cases, + # the attention weight (before softmax) for some padded element in query + # will become -inf, which results in NaN in model parameters + if attn_mask is not None: + attn_mask = attn_mask.masked_fill( + attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 + ) + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask, + need_weights=False, + attn_mask=attn_mask, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + + fc_result = x + + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + + if self.return_fc and not torch.jit.is_scripting(): + return x, fc_result + return x + + +# backward compatible with the legacy argparse format +class TransformerEncoderLayer(TransformerEncoderLayerBase): + def __init__(self, args): + super().__init__(TransformerConfig.from_namespace(args)) + self.args = args + + def build_self_attention(self, embed_dim, args): + return super().build_self_attention( + embed_dim, TransformerConfig.from_namespace(args) + ) + + +class TransformerDecoderLayerBase(nn.Module): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.decoder.normalize_before* to ``True``. + + Args: + args (argparse.Namespace): parsed command-line arguments + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, cfg, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__() + self.embed_dim = cfg.decoder.embed_dim + self.dropout_module = FairseqDropout( + cfg.dropout, module_name=self.__class__.__name__ + ) + self.quant_noise = cfg.quant_noise.pq + self.quant_noise_block_size = cfg.quant_noise.pq_block_size + + self.cross_self_attention = cfg.cross_self_attention + + self.self_attn = self.build_self_attention( + self.embed_dim, + cfg, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.attn_ln = ( + LayerNorm(self.embed_dim) + if utils.safe_getattr(cfg, "scale_attn", False) + else None + ) + self.nh = self.self_attn.num_heads + self.head_dim = self.self_attn.head_dim + scale_heads = utils.safe_getattr(cfg, "scale_heads", False) + self.c_attn = ( + nn.Parameter(torch.ones((self.nh,)), requires_grad=True) + if scale_heads + else None + ) + + self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) + activation_dropout_p = cfg.activation_dropout + if activation_dropout_p == 0: + # for backwards compatibility with models that use cfg.relu_dropout + activation_dropout_p = cfg.relu_dropout or 0 + self.activation_dropout_module = FairseqDropout( + float(activation_dropout_p), module_name=self.__class__.__name__ + ) + self.normalize_before = cfg.decoder.normalize_before + + self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + if no_encoder_attn: + self.encoder_attn = None + self.encoder_attn_layer_norm = None + else: + self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + + self.ffn_layernorm = ( + LayerNorm(cfg.decoder.ffn_embed_dim) + if utils.safe_getattr(cfg, "scale_fc", False) + else None + ) + self.w_resid = ( + nn.Parameter( + torch.ones( + self.embed_dim, + ), + requires_grad=True, + ) + if utils.safe_getattr(cfg, "scale_resids", False) + else None + ) + + self.fc1 = self.build_fc1( + self.embed_dim, + cfg.decoder.ffn_embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + self.fc2 = self.build_fc2( + cfg.decoder.ffn_embed_dim, + self.embed_dim, + self.quant_noise, + self.quant_noise_block_size, + ) + + self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.need_attn = True + + self.onnx_trace = False + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False + ): + return MultiheadAttention( + embed_dim, + cfg.decoder.attention_heads, + dropout=cfg.attention_dropout, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=not cfg.cross_self_attention, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.decoder.xformers_att_config, + ) + + def build_encoder_attention(self, embed_dim, cfg): + return MultiheadAttention( + embed_dim, + cfg.decoder.attention_heads, + kdim=cfg.encoder.embed_dim, + vdim=cfg.encoder.embed_dim, + dropout=cfg.attention_dropout, + encoder_decoder_attention=True, + q_noise=self.quant_noise, + qn_block_size=self.quant_noise_block_size, + xformers_att_config=cfg.encoder.xformers_att_config, + ) + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def residual_connection(self, x, residual): + return residual + x + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + if self.c_attn is not None: + tgt_len, bsz = x.size(0), x.size(1) + x = x.view(tgt_len, bsz, self.nh, self.head_dim) + x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) + x = x.reshape(tgt_len, bsz, self.embed_dim) + if self.attn_ln is not None: + x = self.attn_ln(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + if self.encoder_attn is not None and encoder_out is not None: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x = self.dropout_module(x) + if self.w_resid is not None: + residual = torch.mul(self.w_resid, residual) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, self_attn_state + return x, attn, None + + def make_generation_fast_(self, need_attn: bool = False, **kwargs): + self.need_attn = need_attn + + +# backward compatible with the legacy argparse format +class TransformerDecoderLayer(TransformerDecoderLayerBase): + def __init__( + self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False + ): + super().__init__( + TransformerConfig.from_namespace(args), + no_encoder_attn=no_encoder_attn, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + self.args = args + + def build_self_attention( + self, embed_dim, args, add_bias_kv=False, add_zero_attn=False + ): + return super().build_self_attention( + embed_dim, + TransformerConfig.from_namespace(args), + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + + def build_encoder_attention(self, embed_dim, args): + return super().build_encoder_attention( + embed_dim, + TransformerConfig.from_namespace(args), + ) diff --git a/fairseq/modules/transformer_layer_aug.py b/fairseq/modules/transformer_layer_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb816978a0c44b9be94221d8a886aa96511a67f --- /dev/null +++ b/fairseq/modules/transformer_layer_aug.py @@ -0,0 +1,315 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional + +import torch +from numpy.random import uniform +from torch import Tensor + +from fairseq.modules import LayerNorm +from fairseq.modules.transformer_layer import TransformerDecoderLayerBase + + +class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase): + """Decoder layer block augmented with an additional cross-attention. + + This decoder block is processed with the sequence of the following sub-modules. + self-attention -> cross-attention (first) -> cross-attention (second) -> FFN + + Args: + cfg (argparse.Namespace): parsed command-line arguments + encoder_attn_merge_type (str, optional): the way to combine outputs from + two cross-attention modules. If "sequential" is set, two cross-attention + modules are stacked sequentially. If "parallel" is set, they are processed + in parallel and combined before feeding it to FFN (default: sequential). + dropnet_ratio (float, optional): a probability to drop each cross-attention + module during training (default: 0.0). + """ + + def __init__( + self, + cfg, + add_bias_kv=False, + add_zero_attn=False, + encoder_attn_merge_type="sequential", + dropnet_ratio=0.0, + ): + super().__init__( + cfg, + no_encoder_attn=False, + add_bias_kv=add_bias_kv, + add_zero_attn=False, + ) + self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg) + if encoder_attn_merge_type == "sequential": + self.encoder_attn_layer_norm2 = LayerNorm(self.embed_dim, export=cfg.export) + else: + self.encoder_attn_layer_norm2 = None + + self.encoder_attn_merge_type = encoder_attn_merge_type + self.dropnet_ratio = dropnet_ratio + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + encoder_out_aug: Optional[torch.Tensor] = None, + encoder_padding_mask2: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``. + need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + if self.c_attn is not None: + tgt_len, bsz = x.size(0), x.size(1) + x = x.view(tgt_len, bsz, self.nh, self.head_dim) + x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) + x = x.reshape(tgt_len, bsz, self.embed_dim) + if self.attn_ln is not None: + x = self.attn_ln(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + assert encoder_out is not None + assert encoder_out_aug is not None + + if self.encoder_attn_merge_type == "sequential": + ratios = self.get_dropnet_ratio() + + # first encoder attention + if ratios[0] > 0: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + x = ratios[0] * x + + # second encoder attention + if ratios[1] > 0: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm2(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn2._set_input_buffer(incremental_state, saved_state) + + x, attn2 = self.encoder_attn2( + query=x, + key=encoder_out_aug, + value=encoder_out_aug, + key_padding_mask=encoder_padding_mask2, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm2(x) + x = ratios[1] * x + + elif self.encoder_attn_merge_type == "parallel": + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x1, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x2, attn2 = self.encoder_attn2( + query=x, + key=encoder_out_aug, + value=encoder_out_aug, + key_padding_mask=encoder_padding_mask2, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x1 = self.dropout_module(x1) + x2 = self.dropout_module(x2) + ratios = self.get_dropnet_ratio() + x = ratios[0] * x1 + ratios[1] * x2 + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + else: + raise NotImplementedError(self.encoder_attn_merge_type) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x = self.dropout_module(x) + if self.w_resid is not None: + residual = torch.mul(self.w_resid, residual) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, attn2, self_attn_state + return x, attn, attn2, None + + def get_dropnet_ratio(self): + if self.encoder_attn_merge_type == "sequential": + if self.dropnet_ratio > 0: + frand = float(uniform(0, 1)) + if frand < self.dropnet_ratio and self.training: + return [2, 0] + elif frand > 1 - self.dropnet_ratio and self.training: + return [0, 2] + else: + return [1, 1] + else: + return [1, 1] + + elif self.encoder_attn_merge_type == "parallel": + if self.dropnet_ratio > 0: + frand = float(uniform(0, 1)) + if frand < self.dropnet_ratio and self.training: + return [1, 0] + elif frand > 1 - self.dropnet_ratio and self.training: + return [0, 1] + else: + return [0.5, 0.5] + else: + return [0.5, 0.5] diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2db91ad7b902386efa843f5ade0c0295758fa1 --- /dev/null +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -0,0 +1,291 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from fairseq.modules import ( + FairseqDropout, + LayerDropModuleList, + LayerNorm, + MultiheadAttention, + PositionalEmbedding, + TransformerSentenceEncoderLayer, +) +from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +class TransformerSentenceEncoder(nn.Module): + """ + Implementation for a Bi-directional Transformer based Sentence Encoder used + in BERT/XLM style pre-trained models. + + This first computes the token embedding using the token embedding matrix, + position embeddings (if specified) and segment embeddings + (if specified). After applying the specified number of + TransformerEncoderLayers, it outputs all the internal states of the + encoder as well as the final representation associated with the first + token (usually CLS token). + + Input: + - tokens: B x T matrix representing sentences + - segment_labels: B x T matrix representing segment label for tokens + + Output: + - a tuple of the following: + - a list of internal model states used to compute the + predictions where each tensor has shape T x B x C + - sentence representation associated with first input token + in format B x C. + """ + + def __init__( + self, + padding_idx: int, + vocab_size: int, + num_encoder_layers: int = 6, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + layerdrop: float = 0.0, + max_seq_len: int = 256, + num_segments: int = 2, + use_position_embeddings: bool = True, + offset_positions_by_padding: bool = True, + encoder_normalize_before: bool = False, + apply_bert_init: bool = False, + activation_fn: str = "relu", + learned_pos_embedding: bool = True, + embed_scale: float = None, + freeze_embeddings: bool = False, + n_trans_layers_to_freeze: int = 0, + export: bool = False, + traceable: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + ) -> None: + + super().__init__() + self.padding_idx = padding_idx + self.vocab_size = vocab_size + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.layerdrop = layerdrop + self.max_seq_len = max_seq_len + self.embedding_dim = embedding_dim + self.num_segments = num_segments + self.use_position_embeddings = use_position_embeddings + self.apply_bert_init = apply_bert_init + self.learned_pos_embedding = learned_pos_embedding + self.traceable = traceable + + self.embed_tokens = self.build_embedding( + self.vocab_size, self.embedding_dim, self.padding_idx + ) + self.embed_scale = embed_scale + + if q_noise > 0: + self.quant_noise = apply_quant_noise_( + nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), + q_noise, + qn_block_size, + ) + else: + self.quant_noise = None + + self.segment_embeddings = ( + nn.Embedding(self.num_segments, self.embedding_dim, padding_idx=None) + if self.num_segments > 0 + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + self.max_seq_len, + self.embedding_dim, + padding_idx=(self.padding_idx if offset_positions_by_padding else None), + learned=self.learned_pos_embedding, + ) + if self.use_position_embeddings + else None + ) + + if encoder_normalize_before: + self.emb_layer_norm = LayerNorm(self.embedding_dim, export=export) + else: + self.emb_layer_norm = None + + if self.layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_transformer_sentence_encoder_layer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=self.dropout_module.p, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + for _ in range(num_encoder_layers) + ] + ) + + # Apply initialization of model params after building the model + if self.apply_bert_init: + self.apply(init_bert_params) + + def freeze_module_params(m): + if m is not None: + for p in m.parameters(): + p.requires_grad = False + + if freeze_embeddings: + freeze_module_params(self.embed_tokens) + freeze_module_params(self.segment_embeddings) + freeze_module_params(self.embed_positions) + freeze_module_params(self.emb_layer_norm) + + for layer in range(n_trans_layers_to_freeze): + freeze_module_params(self.layers[layer]) + + def build_embedding(self, vocab_size, embedding_dim, padding_idx): + return nn.Embedding(vocab_size, embedding_dim, padding_idx) + + def build_transformer_sentence_encoder_layer( + self, + embedding_dim, + ffn_embedding_dim, + num_attention_heads, + dropout, + attention_dropout, + activation_dropout, + activation_fn, + export, + q_noise, + qn_block_size, + ): + return TransformerSentenceEncoderLayer( + embedding_dim=embedding_dim, + ffn_embedding_dim=ffn_embedding_dim, + num_attention_heads=num_attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + export=export, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + def forward( + self, + tokens: torch.Tensor, + segment_labels: torch.Tensor = None, + last_state_only: bool = False, + positions: Optional[torch.Tensor] = None, + token_embeddings: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + is_tpu = tokens.device.type == "xla" + + # compute padding mask. This is needed for multi-head attention + padding_mask = tokens.eq(self.padding_idx) + if not self.traceable and not is_tpu and not padding_mask.any(): + padding_mask = None + + if token_embeddings is not None: + x = token_embeddings + else: + x = self.embed_tokens(tokens) + + if self.embed_scale is not None: + x = x * self.embed_scale + + if self.embed_positions is not None: + x = x + self.embed_positions(tokens, positions=positions) + + if self.segment_embeddings is not None and segment_labels is not None: + x = x + self.segment_embeddings(segment_labels) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.emb_layer_norm is not None: + x = self.emb_layer_norm(x) + + x = self.dropout_module(x) + + # account for padding while computing the representation + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + inner_states = [] + if not last_state_only: + inner_states.append(x) + + for layer in self.layers: + x, _ = layer( + x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask + ) + if not last_state_only: + inner_states.append(x) + + sentence_rep = x[0, :, :] + + if last_state_only: + inner_states = [x] + + if self.traceable: + return torch.stack(inner_states), sentence_rep + else: + return inner_states, sentence_rep diff --git a/fairseq/modules/transformer_sentence_encoder_layer.py b/fairseq/modules/transformer_sentence_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..f869c4b2f8fb15f96a292e39bd293df7898a4fce --- /dev/null +++ b/fairseq/modules/transformer_sentence_encoder_layer.py @@ -0,0 +1,139 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +import torch +import torch.nn as nn +from fairseq import utils +from fairseq.modules import LayerNorm, MultiheadAttention +from fairseq.modules.fairseq_dropout import FairseqDropout +from fairseq.modules.quant_noise import quant_noise + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: int = 768, + ffn_embedding_dim: int = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + export: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + init_fn: Callable = None, + ) -> None: + super().__init__() + + if init_fn is not None: + init_fn() + + # Initialize parameters + self.embedding_dim = embedding_dim + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + self.q_noise = q_noise + self.qn_block_size = qn_block_size + + self.dropout_module = FairseqDropout( + dropout, module_name=self.__class__.__name__ + ) + self.activation_dropout_module = FairseqDropout( + activation_dropout, module_name=self.__class__.__name__ + ) + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + self.self_attn = self.build_self_attention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim, export=export) + + self.fc1 = self.build_fc1( + self.embedding_dim, + ffn_embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + self.fc2 = self.build_fc2( + ffn_embedding_dim, + self.embedding_dim, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim, export=export) + + def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def build_self_attention( + self, + embed_dim, + num_attention_heads, + dropout, + self_attention, + q_noise, + qn_block_size, + ): + return MultiheadAttention( + embed_dim, + num_attention_heads, + dropout=dropout, + self_attention=True, + q_noise=q_noise, + qn_block_size=qn_block_size, + ) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + residual = x + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + x = self.dropout_module(x) + x = residual + x + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + x = self.fc2(x) + x = self.dropout_module(x) + x = residual + x + x = self.final_layer_norm(x) + return x, attn diff --git a/fairseq/modules/transpose_last.py b/fairseq/modules/transpose_last.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cca9a4bbdb3f455217380f96a2f2d77eae8630 --- /dev/null +++ b/fairseq/modules/transpose_last.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +transpose last 2 dimensions of the input +""" + +import torch.nn as nn + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) diff --git a/fairseq/modules/unfold.py b/fairseq/modules/unfold.py new file mode 100644 index 0000000000000000000000000000000000000000..bbaafbd6bfe1d206348a3673da8cae8ef9daeca4 --- /dev/null +++ b/fairseq/modules/unfold.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn.functional as F + + +def unfold1d(x, kernel_size: int, padding_l: int, pad_value: float = 0): + """unfold T x B x C to T x B x C x K""" + if kernel_size > 1: + T, B, C = x.size() + x = F.pad( + x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value + ) + x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C)) + else: + x = x.unsqueeze(3) + return x diff --git a/fairseq/modules/vggblock.py b/fairseq/modules/vggblock.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5ee19a34816c7350c21fba7c4907fec8ca7a61 --- /dev/null +++ b/fairseq/modules/vggblock.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +from collections.abc import Iterable +from itertools import repeat + +import torch +import torch.nn as nn + + +def _pair(v): + if isinstance(v, Iterable): + assert len(v) == 2, "len(v) != 2" + return v + return tuple(repeat(v, 2)) + + +def infer_conv_output_dim(conv_op, input_dim, sample_inchannel): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, sample_inchannel, sample_seq_len, input_dim) + # N x C x H x W + # N: sample_bsz, C: sample_inchannel, H: sample_seq_len, W: input_dim + x = conv_op(x) + # N x C x H x W + x = x.transpose(1, 2) + # N x H x C x W + bsz, seq = x.size()[:2] + per_channel_dim = x.size()[3] + # bsz: N, seq: H, CxW the rest + return x.contiguous().view(bsz, seq, -1).size(-1), per_channel_dim + + +class VGGBlock(torch.nn.Module): + """ + VGG motibated cnn module https://arxiv.org/pdf/1409.1556.pdf + + Args: + in_channels: (int) number of input channels (typically 1) + out_channels: (int) number of output channels + conv_kernel_size: convolution channels + pooling_kernel_size: the size of the pooling window to take a max over + num_conv_layers: (int) number of convolution layers + input_dim: (int) input dimension + conv_stride: the stride of the convolving kernel. + Can be a single number or a tuple (sH, sW) Default: 1 + padding: implicit paddings on both sides of the input. + Can be a single number or a tuple (padH, padW). Default: None + layer_norm: (bool) if layer norm is going to be applied. Default: False + + Shape: + Input: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features) + Output: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features) + """ + + def __init__( + self, + in_channels, + out_channels, + conv_kernel_size, + pooling_kernel_size, + num_conv_layers, + input_dim, + conv_stride=1, + padding=None, + layer_norm=False, + ): + assert ( + input_dim is not None + ), "Need input_dim for LayerNorm and infer_conv_output_dim" + super(VGGBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv_kernel_size = _pair(conv_kernel_size) + self.pooling_kernel_size = _pair(pooling_kernel_size) + self.num_conv_layers = num_conv_layers + self.padding = ( + tuple(e // 2 for e in self.conv_kernel_size) + if padding is None + else _pair(padding) + ) + self.conv_stride = _pair(conv_stride) + + self.layers = nn.ModuleList() + for layer in range(num_conv_layers): + conv_op = nn.Conv2d( + in_channels if layer == 0 else out_channels, + out_channels, + self.conv_kernel_size, + stride=self.conv_stride, + padding=self.padding, + ) + self.layers.append(conv_op) + if layer_norm: + conv_output_dim, per_channel_dim = infer_conv_output_dim( + conv_op, input_dim, in_channels if layer == 0 else out_channels + ) + self.layers.append(nn.LayerNorm(per_channel_dim)) + input_dim = per_channel_dim + self.layers.append(nn.ReLU()) + + if self.pooling_kernel_size is not None: + pool_op = nn.MaxPool2d(kernel_size=self.pooling_kernel_size, ceil_mode=True) + self.layers.append(pool_op) + self.total_output_dim, self.output_dim = infer_conv_output_dim( + pool_op, input_dim, out_channels + ) + + def forward(self, x): + for i, _ in enumerate(self.layers): + x = self.layers[i](x) + return x diff --git a/fairseq/nan_detector.py b/fairseq/nan_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0f9110731461f78e85196185303f1b5ea62c91 --- /dev/null +++ b/fairseq/nan_detector.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + + +logger = logging.getLogger(__name__) + + +class NanDetector: + """ + Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name + """ + + def __init__(self, model, forward=True, backward=True): + self.bhooks = [] + self.fhooks = [] + self.forward = forward + self.backward = backward + self.named_parameters = list(model.named_parameters()) + self.reset() + + for name, mod in model.named_modules(): + mod.__module_name = name + self.add_hooks(mod) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + # Dump out all model gnorms to enable better debugging + norm = {} + gradients = {} + for name, param in self.named_parameters: + if param.grad is not None: + grad_norm = torch.norm(param.grad.data.float(), p=2) + norm[name] = param.norm().item() + if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): + gradients[name] = param.grad.data + if len(gradients) > 0: + logger.info("Detected nan/inf grad norm, dumping norms...") + logger.info(f"norms: {norm}") + logger.info(f"gradients: {gradients}") + + self.close() + + def add_hooks(self, module): + if self.forward: + self.fhooks.append(module.register_forward_hook(self.fhook_fn)) + if self.backward: + self.bhooks.append(module.register_backward_hook(self.bhook_fn)) + + def reset(self): + self.has_printed_f = False + self.has_printed_b = False + + def _detect(self, tensor, name, backward): + err = None + if ( + torch.is_floating_point(tensor) + # single value tensors (like the loss) will not provide much info + and tensor.numel() >= 2 + ): + with torch.no_grad(): + if torch.isnan(tensor).any(): + err = "NaN" + elif torch.isinf(tensor).any(): + err = "Inf" + if err is not None: + err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}" + return err + + def _apply(self, module, inp, x, backward): + if torch.is_tensor(x): + if isinstance(inp, tuple) and len(inp) > 0: + inp = inp[0] + err = self._detect(x, module.__module_name, backward) + if err is not None: + if torch.is_tensor(inp) and not backward: + err += ( + f" input max: {inp.max().item()}, input min: {inp.min().item()}" + ) + + has_printed_attr = "has_printed_b" if backward else "has_printed_f" + logger.warning(err) + setattr(self, has_printed_attr, True) + elif isinstance(x, dict): + for v in x.values(): + self._apply(module, inp, v, backward) + elif isinstance(x, list) or isinstance(x, tuple): + for v in x: + self._apply(module, inp, v, backward) + + def fhook_fn(self, module, inp, output): + if not self.has_printed_f: + self._apply(module, inp, output, backward=False) + + def bhook_fn(self, module, inp, output): + if not self.has_printed_b: + self._apply(module, inp, output, backward=True) + + def close(self): + for hook in self.fhooks + self.bhooks: + hook.remove() diff --git a/fairseq/ngram_repeat_block.py b/fairseq/ngram_repeat_block.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb50303116671a47d03528fbe8a808647cbb116 --- /dev/null +++ b/fairseq/ngram_repeat_block.py @@ -0,0 +1,120 @@ +# Originally from Microsoft Corporation. +# Licensed under the MIT License. + +""" Wrapper for ngram_repeat_block cuda extension """ +import math +import warnings +from typing import List + +import torch +from torch import nn + +try: + from fairseq import ngram_repeat_block_cuda + + EXTENSION_BUILT = True +except ImportError: + EXTENSION_BUILT = False + + +def is_cuda_extension_usable() -> bool: + """Check whether ngram_repeat_block_cuda is built properly""" + if not EXTENSION_BUILT or not torch.cuda.is_available(): + return False + bsz = 2 + tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda") + lprobs = torch.rand((8, 12), device="cuda") + try: + outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3) + outputs = outputs + 4 # This line breaks if the extension is built incorrectly. + return True + except RuntimeError: + warnings.warn( + "NGramRepeatBlock extension must be rebuilt." + 'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace' + ) + return False + + +class NGramRepeatBlock(nn.Module): + """Wrapper class for calling ngram_repeat_block cuda extension""" + + def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): + super().__init__() + self.use_extension = is_cuda_extension_usable() if use_extension else False + self.no_repeat_ngram_size = no_repeat_ngram_size + + def reset_parameters(self): + pass + + @torch.jit.unused + def call_cuda_extension( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + return ngram_repeat_block_cuda.forward( + tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size + ) + + def forward( + self, + tokens, + lprobs, + bsz: int, + beam_size: int, + step: int, + ): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + msg = f"expected {bsz *beam_size} got" + assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}" + assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}" + if self.use_extension: + return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step) + + else: + return self._no_repeat_ngram( + tokens, + lprobs, + bsz, + beam_size, + step, + ) + + def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int): + """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf""" + banned_tokens = [ + torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size) + ] + if step + 2 - self.no_repeat_ngram_size >= 0: + cpu_tokens: List[List[int]] = tokens.cpu().tolist() + check_start_pos = step + 2 - self.no_repeat_ngram_size + for bbsz_idx in range(bsz * beam_size): + ngram_to_check = cpu_tokens[bbsz_idx][ + -(self.no_repeat_ngram_size - 1) : + ] + for i in range(check_start_pos): + if ( + ngram_to_check + == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1] + ): + banned_tokens[bbsz_idx].append( + cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1] + ) + for bbsz_idx in range(bsz * beam_size): + lprobs[bbsz_idx][ + torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64) + ] = torch.tensor(-math.inf).to(lprobs) + return lprobs diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be783be896396ff659c0bd173a7acebb8a2d165d --- /dev/null +++ b/fairseq/optim/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.optim.bmuf import FairseqBMUF # noqa +from fairseq.optim.fairseq_optimizer import ( # noqa + FairseqOptimizer, + LegacyFairseqOptimizer, +) +from fairseq.optim.amp_optimizer import AMPOptimizer +from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer +from fairseq.optim.shard import shard_ +from omegaconf import DictConfig + +__all__ = [ + "AMPOptimizer", + "FairseqOptimizer", + "FP16Optimizer", + "MemoryEfficientFP16Optimizer", + "shard_", +] + +( + _build_optimizer, + register_optimizer, + OPTIMIZER_REGISTRY, + OPTIMIZER_DATACLASS_REGISTRY, +) = registry.setup_registry("--optimizer", base_class=FairseqOptimizer, required=True) + + +def build_optimizer(cfg: DictConfig, params, *extra_args, **extra_kwargs): + if all(isinstance(p, dict) for p in params): + params = [t for p in params for t in p.values()] + params = list(filter(lambda p: p.requires_grad, params)) + return _build_optimizer(cfg, params, *extra_args, **extra_kwargs) + + +# automatically import any Python files in the optim/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim." + file_name) diff --git a/fairseq/optim/__pycache__/__init__.cpython-310.pyc b/fairseq/optim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b5660e3aa0cece231cefedaea2d731d89370fb7 Binary files /dev/null and b/fairseq/optim/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/adadelta.cpython-310.pyc b/fairseq/optim/__pycache__/adadelta.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7b71b6ba56289da0618454acded3c9502d7dc9 Binary files /dev/null and b/fairseq/optim/__pycache__/adadelta.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/adafactor.cpython-310.pyc b/fairseq/optim/__pycache__/adafactor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c6c603dc664d526857bc062fcdf6b5a7219067d Binary files /dev/null and b/fairseq/optim/__pycache__/adafactor.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/adagrad.cpython-310.pyc b/fairseq/optim/__pycache__/adagrad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd9fa74c11d6af89cb53824a27ba5087a1922498 Binary files /dev/null and b/fairseq/optim/__pycache__/adagrad.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/adam.cpython-310.pyc b/fairseq/optim/__pycache__/adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c806792aad11055c82e1c16b9b72e32676b2f5 Binary files /dev/null and b/fairseq/optim/__pycache__/adam.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/adamax.cpython-310.pyc b/fairseq/optim/__pycache__/adamax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0e08fc6e2b4580c0747080436ca3952dedbfde2 Binary files /dev/null and b/fairseq/optim/__pycache__/adamax.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc b/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb21e2493418592c85b1e6b104b7443d9fabac44 Binary files /dev/null and b/fairseq/optim/__pycache__/amp_optimizer.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/bmuf.cpython-310.pyc b/fairseq/optim/__pycache__/bmuf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d3b9b041a67b1fb16863a3e8538f8970d3a1884 Binary files /dev/null and b/fairseq/optim/__pycache__/bmuf.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/composite.cpython-310.pyc b/fairseq/optim/__pycache__/composite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..042b9ce18a4aa2f62dd781f5f8ec6beff34f53bb Binary files /dev/null and b/fairseq/optim/__pycache__/composite.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc b/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..188908db674b7acf76fe1204c996ef80d297c069 Binary files /dev/null and b/fairseq/optim/__pycache__/cpu_adam.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc b/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa98e420dee99e6b21c13bdbd0c635bbe08cc5f Binary files /dev/null and b/fairseq/optim/__pycache__/dynamic_loss_scaler.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc b/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a010573b26ccc73980a1eac2a9689c99443b57ff Binary files /dev/null and b/fairseq/optim/__pycache__/fairseq_optimizer.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc b/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27b75697e413aa9f0fcd680dedeebe2356f5c44a Binary files /dev/null and b/fairseq/optim/__pycache__/fp16_optimizer.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc b/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b8ce7113e8a32fc37e748375a743af4b9048c05 Binary files /dev/null and b/fairseq/optim/__pycache__/fused_adam.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc b/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc83e1b3642cf9b40bca1edb2d73c5df7a0179fe Binary files /dev/null and b/fairseq/optim/__pycache__/fused_lamb.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/nag.cpython-310.pyc b/fairseq/optim/__pycache__/nag.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c0a8fd712e65534f0b876d7937a7ac55c188573 Binary files /dev/null and b/fairseq/optim/__pycache__/nag.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/sgd.cpython-310.pyc b/fairseq/optim/__pycache__/sgd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84675bf9e444283d3dc0339262e892840628ddcd Binary files /dev/null and b/fairseq/optim/__pycache__/sgd.cpython-310.pyc differ diff --git a/fairseq/optim/__pycache__/shard.cpython-310.pyc b/fairseq/optim/__pycache__/shard.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec5930b9461740454c6420a0f75da02bb936e1cd Binary files /dev/null and b/fairseq/optim/__pycache__/shard.cpython-310.pyc differ diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..f1a21549770f0904a6a40a42ff7eb52811f1bfbe --- /dev/null +++ b/fairseq/optim/adadelta.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adadelta") +class Adadelta(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', + help='coefficient used for computing a running average of squared gradients') + parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', + help='term added to the denominator to improve numerical stability') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "rho": self.args.adadelta_rho, + "eps": self.args.adadelta_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..042ae926b0718d3ad66b28cddcc0f32293ad88a1 --- /dev/null +++ b/fairseq/optim/adafactor.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adafactor") +class FairseqAdafactor(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adafactor(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adafactor-eps', default='(1e-30, 1e-3)', metavar="E", + help='epsilons for Adafactor optimizer') + parser.add_argument('--clip-threshold', type=float, default=1.0, metavar="C", + help='threshold for clipping update root mean square') + parser.add_argument('--decay-rate', type=float, default=-0.8, metavar="D", + help='decay rate of the second moment estimator') + parser.add_argument('--beta1', type=float, default=None, metavar="B", + help='beta for first moment estimator. Optional') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--scale-parameter', action='store_true', + help='scale learning rate by root mean square of parameter') + parser.add_argument('--relative-step', action='store_true', + help='set learning rate to inverse square root of timestep,' + 'otherwise use external learning rate') + parser.add_argument('--warmup-init', action='store_true', + help='use relative step for warm-up learning rate schedule') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + Note : Convergence issues empirically observed with fp16 on. + Might require search for appropriate configuration. + """ + return { + "lr": self.args.lr[0], + "eps": eval(self.args.adafactor_eps), + "clip_threshold": self.args.clip_threshold, + "decay_rate": self.args.decay_rate, + "beta1": self.args.beta1, + "weight_decay": self.args.weight_decay, + "scale_parameter": self.args.scale_parameter, # defaults to False + "relative_step": self.args.relative_step, # defaults to False + "warmup_init": self.args.warmup_init, + } + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + + This implementation is based on: + `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate + depending on the *scale_parameter*, *relative_step* and + *warmup_init* options. To use a manual (external) learning rate + schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constans for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of + final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square + gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient + (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of + parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + if lr is not None and relative_step: + raise ValueError("Cannot combine manual lr and relative_step options") + if warmup_init and not relative_step: + raise ValueError("warmup_init requires relative_step=True") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + relative_step=relative_step, + warmup_init=warmup_init, + ) + super(Adafactor, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return False + + def _get_lr(self, param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = ( + 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + ) + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + def _get_options(self, param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + def _rms(self, tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = ( + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + .rsqrt_() + .unsqueeze(-1) + ) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros( + grad_shape[:-2] + grad_shape[-1:] + ).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + group["lr"] = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_( + update.mean(dim=-1), alpha=1.0 - beta2t + ) + exp_avg_sq_col.mul_(beta2t).add_( + update.mean(dim=-2), alpha=1.0 - beta2t + ) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_( + (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0) + ) + update.mul_(group["lr"]) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"]) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..4f539541c1c91d8c822f7ce624fa6eabf744f60e --- /dev/null +++ b/fairseq/optim/adagrad.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adagrad") +class Adagrad(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return False diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..678ec7c61763101d7ee62dbbaafc012886ff8a0e --- /dev/null +++ b/fairseq/optim/adam.py @@ -0,0 +1,239 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import Any, List + +import torch +import torch.distributed as dist +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer +from fairseq.optim.fused_adam import get_fused_adam_class +from omegaconf import II, OmegaConf + + +logger = logging.getLogger(__name__) + + +@dataclass +class FairseqAdamConfig(FairseqDataclass): + adam_betas: Any = field( + default=(0.9, 0.999), metadata={"help": "betas for Adam optimizer"} + ) + adam_eps: float = field( + default=1e-8, metadata={"help": "epsilon for Adam optimizer"} + ) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + use_old_adam: bool = field( + default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} + ) + fp16_adam_stats: bool = field( + default=False, metadata={"help": "use FP16 stats (with automatic scaling)"} + ) + # TODO common vars below in parent + tpu: bool = II("common.tpu") + lr: List[float] = II("optimization.lr") + + +@register_optimizer("adam", dataclass=FairseqAdamConfig) +class FairseqAdam(FairseqOptimizer): + """Adam optimizer for fairseq. + + Important note: this optimizer corresponds to the "AdamW" variant of + Adam in its weight decay behavior. As such, it is most closely + analogous to torch.optim.AdamW from PyTorch. + """ + + def __init__(self, cfg: FairseqAdamConfig, params): + super().__init__(cfg) + fused_adam_cls = get_fused_adam_class() + use_fused_adam = ( + not getattr(cfg, "use_old_adam", False) + and fused_adam_cls is not None + and torch.cuda.is_available() + ) + if getattr(cfg, "tpu", False): + if self.cfg.fp16_adam_stats: + raise NotImplementedError("--fp16-adam-stats is only supported on GPU") + # on TPUs we use the Adam defined here, since it + # automatically casts gradients to FP32 + self._optimizer = Adam(params, **self.optimizer_config) + elif use_fused_adam: + logger.info("using FusedAdam") + self._optimizer = fused_adam_cls( + params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config + ) + else: + if self.cfg.fp16_adam_stats: + raise NotImplementedError( + "--fp16-adam-stats is only supported with FusedAdamV1" + ) + self._optimizer = Adam(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas) + if isinstance(self.cfg.adam_betas, str) + else OmegaConf.to_container(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, + } + + def average_params(self): + """Reduce Params is only used during BMUF distributed training.""" + state_dict = self.optimizer.state_dict() + total_gpus = float(dist.get_world_size()) + + for _, value in state_dict["state"].items(): + value["exp_avg"] /= total_gpus + value["exp_avg_sq"] /= total_gpus + dist.all_reduce(value["exp_avg"], op=dist.ReduceOp.SUM) + dist.all_reduce(value["exp_avg_sq"], op=dist.ReduceOp.SUM) + + +class Adam(torch.optim.Optimizer): + r"""Implements Adam algorithm. + + This implementation is modified from torch.optim.Adam based on: + `Fixed Weight Decay Regularization in Adam` + (see https://arxiv.org/abs/1711.05101) + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + ): + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + ) + super(Adam, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group.get("amsgrad", False) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p_data_fp32) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + if amsgrad: + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( + p_data_fp32 + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + beta1, beta2 = group["betas"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group["eps"]) + else: + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..98ff8ad7ad6c12ab5efc53ca76db2f1663be7906 --- /dev/null +++ b/fairseq/optim/adamax.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adamax") +class FairseqAdamax(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = Adamax(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--adamax-betas', default='(0.9, 0.999)', metavar='B', + help='betas for Adam optimizer') + parser.add_argument('--adamax-eps', type=float, default=1e-8, metavar='D', + help='epsilon for Adam optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--no-bias-correction', default=False, action='store_true', + help='disable bias correction') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.adamax_betas), + "eps": self.args.adamax_eps, + "weight_decay": self.args.weight_decay, + "bias_correction": not self.args.no_bias_correction, + } + + +class Adamax(torch.optim.Optimizer): + """Implements Adamax algorithm (a variant of Adam based on infinity norm). + + It has been proposed in `Adam: A Method for Stochastic Optimization`__. + + Compared to the version in PyTorch, this version implements a fix for weight decay. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + bias_correction (bool, optional): enable bias correction (default: True) + + __ https://arxiv.org/abs/1412.6980 + """ + + def __init__( + self, + params, + lr=2e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + bias_correction=True, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + ) + super(Adamax, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("Adamax does not support sparse gradients") + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_inf"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_inf"] = state["exp_inf"].to(p_data_fp32) + + exp_avg, exp_inf = state["exp_avg"], state["exp_inf"] + beta1, beta2 = group["betas"] + eps = group["eps"] + + state["step"] += 1 + + # Update biased first moment estimate. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # Update the exponentially weighted infinity norm. + torch.max( + exp_inf.mul_(beta2), + grad.abs_(), + out=exp_inf, + ) + + step_size = group["lr"] + if group["bias_correction"]: + bias_correction = 1 - beta1 ** state["step"] + step_size /= bias_correction + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/fairseq/optim/amp_optimizer.py b/fairseq/optim/amp_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe57d07f9f7bc87a7ebb86b71ddb739e3496cf7 --- /dev/null +++ b/fairseq/optim/amp_optimizer.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +from fairseq import optim +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +class AMPOptimizer(optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support AMP (automatic mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, **kwargs): + super().__init__(cfg.optimizer) + self.fp32_optimizer = fp32_optimizer + amp_kwargs = {"init_scale": cfg.common.fp16_init_scale} + if getattr(cfg.common, "amp_scale_window", None) is not None: + amp_kwargs["growth_interval"] = cfg.common.amp_init_scale + self._grad_scaler = torch.cuda.amp.GradScaler(**amp_kwargs) + self.min_loss_scale = cfg.common.min_loss_scale + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp32_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp32_optimizer, **kwargs) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + self._grad_scaler.scale(loss).backward() + + def step(self): + self.scaler.step(self.fp32_optimizer) + self.scaler.update() + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + self.scaler.unscale_(self.optimizer) + grad_norm = self.fp32_optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + if not torch.isfinite(grad_norm).all(): + new_loss_scale = self.next_loss_scale + if new_loss_scale <= self.min_loss_scale: + raise FloatingPointError( + ( + "AMP: Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try restarting training or use fp32. {}" + ).format(self.min_loss_scale, new_loss_scale) + ) + else: + logger.info( + "AMP: overflow detected, setting scale to " f"to {new_loss_scale}" + ) + return grad_norm + + @property + def scaler(self): + return self._grad_scaler + + @property + def next_loss_scale(self): + return self.scaler.get_scale() * self.scaler.get_backoff_factor() + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d0e04e86eb894efe59e13a78843d01ca9e651d --- /dev/null +++ b/fairseq/optim/bmuf.py @@ -0,0 +1,200 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import torch +import torch.distributed as dist +from fairseq.dataclass.configs import FairseqBMUFConfig +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.fairseq_optimizer import FairseqOptimizer + + +class FairseqBMUF(FairseqOptimizer): + """ + Implements incremental block distributed data parallelism similar to + https://ieeexplore.ieee.org/document/7472805 + + Paper title: Scalable training of deep learning machines by incremental + block training with intra-block parallel optimization and blockwise + model-update filtering + """ + + def __init__(self, cfg: FairseqBMUFConfig, optimizer): + super().__init__(cfg) + self._optimizer = optimizer + self._num_updates = 0 + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr + self._reset_local_data() + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm + self.initial_state = self._optimizer.state_dict() + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + gen_parser_from_dataclass(parser, FairseqBMUFConfig()) + + @property + def optimizer(self): + return self._optimizer.optimizer + + @property + def optimizer_config(self): + return self._optimizer.optimizer_config + + def get_lr(self): + return self._optimizer.get_lr() + + def set_lr(self, lr): + self._optimizer.set_lr(lr) + + def state_dict(self): + return self._optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + self._optimizer.load_state_dict(state_dict, optimizer_overrides) + self.initial_state = self._optimizer.state_dict() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._optimizer.multiply_grads(c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return self._optimizer.clip_grad_norm(max_norm, aggregate_norm_fn) + + def average_params(self): + self._optimizer.average_params() + + def _block_sync(self): + if self.world_size <= 1: + return + # Update the global model using local models from all GPUs + # (Step-1) Calculate grad between previously synced model and + # currrent local model + if self.block_momentum != 0: + self._calc_grad() + + # (Step-2) Average gradient from all GPUs + self._avg_grad_from_all_gpus() + + # (Step-3) Calculate global momentum and update the global model + if self.block_momentum != 0: + self._update_global_model() + + # (Step-4) Average local optimizer params + if self.average_sync: + self.average_params() + + def _is_warmup_end(self): + # Check whether train iterations is equal to warmup iter + if self.get_num_updates() == self.warmup_iteration: + return True + return False + + def _is_bmuf_iter(self): + # Check whether train iterations is equal to bmuf sync iter + if (self.get_num_updates() > self.warmup_iteration) and ( + self.get_num_updates() % self.sync_iter == 0 + ): + return True + return False + + def _warmup_sync(self, root_rank=0): + if self.world_size <= 1: + return + # Broadcast the local model to all gpus + for param in self.params: + dist.broadcast(param.data, src=root_rank) + + # Update local optimizer state + if self.average_sync: + self._optimizer.average_params() + else: + self._optimizer.load_state_dict(self.initial_state) + + self._reset_local_data() + + def step(self, closure=None): + """Performs a single optimization step.""" + self._optimizer.step(closure) + self.set_num_updates(self.get_num_updates() + 1) + if self._is_warmup_end(): + self._warmup_sync() + elif self._is_bmuf_iter(): + self._block_sync() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self._optimizer.zero_grad() + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + + @torch.no_grad() + def _reset_local_data(self): + # (Step-0) Initialize global momentum parameters and store global copy on each gpu + self.global_params = [torch.zeros_like(p.data) for p in self.params] + self.smoothed_grads = [p.data.new_zeros(p.data.size()) for p in self.params] + self.grads = [p.data.new_zeros(p.data.size()) for p in self.params] + + # saving the global model locally for calculating gradient during bmuf sync + for param, global_param in zip(self.params, self.global_params): + global_param.copy_(param.data) + + @torch.no_grad() + def _calc_grad(self): + # global_params is basically the global copy from the previously finished + # synchronisation. param.data is local parameter after block_sync_freq + # for the local gpu. so grad is difference between previously synced + # model and currrent local model. + for index, (param, global_param) in enumerate( + zip(self.params, self.global_params) + ): + self.grads[index] = global_param - param.data + + def _avg_grad_from_all_gpus(self): + for index, param in enumerate(self.params): + sync_para = param.data if self.block_momentum == 0 else self.grads[index] + sync_para /= float(dist.get_world_size()) + dist.all_reduce(sync_para, op=dist.ReduceOp.SUM) + + @torch.no_grad() + def _update_global_model(self): + for index, (param, global_param, smoothed_grad, grad) in enumerate( + zip( + self.params, + self.global_params, + self.smoothed_grads, + # all gpus would share the same value of smoothed_grad, since it is + # always computed on synchronized gradients. + self.grads, + ) + ): + # global_param is basically last syncrhornized parameter. though + # smoothed_grad is local, all processes will have same value of + # smoothed_grad and hence param is globally synchronized copy. + # smoothed_grad(t) = BM * smoothed_grad(t-1) + BM_lr * grad(t) + smoothed_grad = self.block_momentum * smoothed_grad + self.block_lr * grad + param.data.copy_(global_param - smoothed_grad) + + # A Nesterov momentum here is to do a partial weight update before + # calculating the gradient + if self.use_nbm: + param.data.copy_(param.data - self.block_momentum * smoothed_grad) + + # backup for the next synchronization. + self.smoothed_grads[index] = smoothed_grad + global_param.copy_(param.data) diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef0114ed63dcfbdf2898a257f38453ad2c69748 --- /dev/null +++ b/fairseq/optim/composite.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Any, List, Optional + +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer +from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler +from omegaconf import II, open_dict +import copy + + +logger = logging.getLogger(__name__) + + +@dataclass +class OptimizerAndSchedulerConfig(FairseqDataclass): + optimizer: Any = None + lr_scheduler: Optional[Any] = None + lr: List = II("optimization.lr") + lr_float: Optional[ + float + ] = None # this makes it easier to sweep on learning rate with auto sweepers + + +@dataclass +class CompositeOptimizerConfig(FairseqDataclass): + groups: Dict[str, Any] = field( + default_factory=lambda: {}, + metadata={ + "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " + "Configures a different optimizer and (optionally) lr scheduler for each parameter group" + }, + ) + dynamic_groups: bool = field( + default=False, + metadata={ + "help": "create groups dynamically based on parameters, if set to False, all parameters needs to have group_names" + }, + ) + + +@register_optimizer("composite", dataclass=CompositeOptimizerConfig) +class FairseqCompositeOptimizer(FairseqOptimizer): + + optimizers: Dict[str, FairseqOptimizer] = {} + lr_schedulers: Dict[str, FairseqLRScheduler] = {} + lr_scheduler: FairseqLRScheduler = None + _optimizer: torch.optim.Optimizer + + def __init__(self, cfg: CompositeOptimizerConfig, params): + super().__init__(cfg) + + assert ( + len(params) > 1 + ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" + + def dict_hash(dictionary: Dict[str, Any]) -> str: + import hashlib + import json + + dhash = hashlib.md5() + encoded = json.dumps(dictionary, sort_keys=True).encode() + dhash.update(encoded) + return dhash.hexdigest() + + groupped_params = defaultdict(list) + overrides = defaultdict(dict) + if not cfg.dynamic_groups: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None and bool(override_config): + overrides[group] = override_config + else: + assert ( + override_config == None or override_config == overrides[group] + ), f"For group {group}, different overrides found {override_config} v/s {overrides[group]}" + groupped_params[group].append(p) + + for p, params in groupped_params.items(): + override_config = getattr(params[0], "optim_overrides", None) + if override_config is not None: + for pp in params[1:]: + assert override_config == getattr( + pp, "optim_overrides", None + ), f" {str(override_config)} != {str(getattr(pp, 'optim_overrides', None))}" + else: + for p in params: + group = getattr(p, "param_group", "default") + override_config = getattr(p, "optim_overrides", None) + if override_config is not None: + override_config["group_name"] = group + group_name = dict_hash(override_config) + overrides[group_name] = override_config + else: + group_name = group + groupped_params[group_name].append(p) + + self.optimizers_config = {} + for group, group_params in groupped_params.items(): + p_group = group + if group in overrides and "group_name" in overrides[group]: + p_group = overrides[group]["group_name"] + if group in cfg.groups: + group_cfg = cfg.groups[group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = True + else: + group_cfg = cfg.groups[p_group] + optimizer_config = copy.deepcopy(group_cfg.optimizer) + scheduler_config = copy.deepcopy(group_cfg.lr_scheduler) + explicit_group_present = False + + if getattr(group_cfg, "lr_float", None) is not None: + with open_dict(optimizer_config): + optimizer_config.lr = [group_cfg.lr_float] + + if group in overrides and "optimizer" in overrides[group]: + with open_dict(optimizer_config): + if "lr_scale" in overrides[group]["optimizer"]: + lr_scale = overrides[group]["optimizer"]["lr_scale"] + optimizer_config.lr = [ + lr * lr_scale for lr in optimizer_config.lr + ] + + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for lr" + ) + + if ( + "weight_decay_scale" in overrides[group]["optimizer"] + and "optimizer_config" in optimizer_config + ): + weight_decay_scale = overrides[group]["optimizer"][ + "weight_decay_scale" + ] + optimizer_config.weight_decay = ( + optimizer_config.weight_decay * weight_decay_scale + ) + if explicit_group_present: + logger.info( + f"For group:{group}, config as well as override present for weight_decay" + ) + + with open_dict(scheduler_config): + scheduler_config.lr = optimizer_config.lr + self.optimizers[group] = _build_optimizer(optimizer_config, group_params) + self.optimizers_config[group] = optimizer_config + if scheduler_config is not None: + self.lr_schedulers[group] = build_lr_scheduler( + scheduler_config, self.optimizers[group] + ) + logger.info("Optimizers for different groups are as below") + for group in self.optimizers_config.keys(): + logger.info(f"Group : {group}:{self.optimizers_config[group]}") + if len(self.lr_schedulers) > 0: + assert len(self.lr_schedulers) == len(self.optimizers), ( + f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. " + f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}" + ) + self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers) + + self._optimizer = CompositeOptimizer(self.optimizers) + + @property + def supports_groups(self): + return True + + @property + def param_groups(self): + for opt in self.optimizers.values(): + for group in opt.param_groups: + yield group + + def get_lr(self): + """Return the current learning rate.""" + k = ( + "default" + if "default" in self.optimizers + else next(iter(self.optimizers.keys())) + ) + return self.optimizers[k].param_groups[0]["lr"] + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.optimizers.items()} + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + if k not in self.optimizers: + # skip extra keys like "loss_scale" added by fp16 optimizer + continue + + overrides = ( + optimizer_overrides[k] + if isinstance(optimizer_overrides, dict) and k in optimizer_overrides + else None + ) + self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides) + + +class CompositeOptimizer(torch.optim.Optimizer): + def __init__(self, optimizers: Dict[str, FairseqOptimizer]): + self.optimizers = optimizers + + @property + def supports_memory_efficient_fp16(self): + return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values()) + + @property + def supports_flat_params(self): + return all(o.supports_flat_params for o in self.optimizers.values()) + + def step(self, closure=None, groups=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for k, opt in self.optimizers.items(): + if groups is None or k in groups: + opt.step() + + return loss + + def zero_grad(self): + for opt in self.optimizers.values(): + opt.zero_grad() + + +class CompositeLRScheduler(FairseqLRScheduler): + def __init__(self, lr_schedulers): + super().__init__(None, None) + + self.lr_schedulers = lr_schedulers + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.lr_schedulers.items()} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + self.lr_schedulers[k].load_state_dict(state) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step_begin_epoch(epoch) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()} diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..b218934e717cacab98e04b8832a24352fdd1de1a --- /dev/null +++ b/fairseq/optim/cpu_adam.py @@ -0,0 +1,210 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer +from omegaconf import II, DictConfig + + +try: + import deepspeed + + has_deepspeed = True +except ImportError as e: + has_deepspeed = False + + +def _get_cpu_adam(): + try: + from deepspeed.ops.op_builder import CPUAdamBuilder + + return CPUAdamBuilder().load() + except ImportError: + # fbcode + from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam + + return ds_opt_adam + + +@dataclass +class FairseqCPUAdamConfig(FairseqDataclass): + adam_betas: str = field( + default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"} + ) + adam_eps: float = field( + default=1e-8, metadata={"help": "epsilon for Adam optimizer"} + ) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + fp16_adam_stats: bool = field( + default=False, metadata={"help": "use FP16 stats (with automatic scaling)"} + ) + # TODO common vars below in parent + lr: List[float] = II("optimization.lr") + + +@register_optimizer("cpu_adam", dataclass=FairseqCPUAdamConfig) +class FairseqCPUAdam(FairseqOptimizer): + """Adam optimizer for fairseq, optimized for CPU tensors. + + Important note: this optimizer corresponds to the "AdamW" variant of + Adam in its weight decay behavior. As such, it is most closely + analogous to torch.optim.AdamW from PyTorch. + """ + + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + self._optimizer = CPUAdam(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, + "use_fp16_stats": self.cfg.fp16_adam_stats, + } + + +class CPUAdam(torch.optim.Optimizer): + + optimizer_id = 0 + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + use_fp16_stats=False, + ): + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + self.use_fp16_stats = use_fp16_stats + self.FLOAT16_MAX = 65504.0 + + if not has_deepspeed: + raise ImportError("Please install DeepSpeed: pip install deepspeed") + + self.opt_id = CPUAdam.optimizer_id + CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1 + + self.ds_opt_adam = _get_cpu_adam() + adamw_mode = True + self.ds_opt_adam.create_adam( + self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode + ) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + torch.cuda.synchronize() + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group["params"]): + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state["step"] = 0 + dtype = torch.float16 if self.use_fp16_stats else p.data.dtype + # gradient momentums + state["exp_avg"] = torch.zeros_like( + p.data, dtype=dtype, device="cpu" + ) + # gradient variances + state["exp_avg_sq"] = torch.zeros_like( + p.data, dtype=dtype, device="cpu" + ) + if self.use_fp16_stats: + assert torch.is_floating_point(p.data) + state["exp_avg_scale"] = 1.0 + state["exp_avg_sq_scale"] = 1.0 + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + p_data_bak = p.data # backup of the original data pointer + + p.data = p.data.to(dtype=torch.float32, device="cpu") + p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu") + + if self.use_fp16_stats: + exp_avg = exp_avg.float() * state["exp_avg_scale"] + exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] + + state["step"] += 1 + beta1, beta2 = group["betas"] + + self.ds_opt_adam.adam_update( + self.opt_id, + state["step"], + group["lr"], + beta1, + beta2, + group["eps"], + group["weight_decay"], + group["bias_correction"], + p.data, + p.grad.data, + exp_avg, + exp_avg_sq, + ) + + if p_data_bak.data_ptr() != p.data.data_ptr(): + p_data_bak.copy_(p.data) + p.data = p_data_bak + + if self.use_fp16_stats: + + def inf_norm(t): + return torch.norm(t, float("inf")) + + # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + state["exp_avg_scale"], state["exp_avg_sq_scale"] = ( + 1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX, + 1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX, + ) + state["exp_avg"], state["exp_avg_sq"] = ( + (exp_avg / state["exp_avg_scale"]).half(), + (exp_avg_sq / state["exp_avg_sq_scale"]).half(), + ) + + return loss diff --git a/fairseq/optim/dynamic_loss_scaler.py b/fairseq/optim/dynamic_loss_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..60c47b8db0a4d907be966bf39fa24effa2c825db --- /dev/null +++ b/fairseq/optim/dynamic_loss_scaler.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class DynamicLossScaler(object): + def __init__( + self, + init_scale=2.0**15, + scale_factor=2.0, + scale_window=2000, + tolerance=0.0, + threshold=None, + min_loss_scale=1e-4, + ): + self.loss_scale = init_scale + self.scale_factor = scale_factor + self.scale_window = scale_window + self.tolerance = tolerance + self.threshold = threshold + self._iter = 0 + self._last_overflow_iter = -1 + self._last_rescale_iter = -1 + self._overflows_since_rescale = 0 + self.min_loss_scale = min_loss_scale + + def scale(self, outputs): + return self.loss_scale * outputs + + def update(self): + if (self._iter - self._last_overflow_iter) % self.scale_window == 0: + self.loss_scale *= self.scale_factor + self._last_rescale_iter = self._iter + self._iter += 1 + + def _decrease_loss_scale(self): + self.loss_scale /= self.scale_factor + if self.threshold is not None: + self.loss_scale = max(self.loss_scale, self.threshold) + + def check_overflow(self, grad_norm): + # detect inf and nan + if grad_norm == float("inf") or grad_norm != grad_norm: + # overflow has occured + prev_scale = self.loss_scale + iter_since_rescale = self._iter - self._last_rescale_iter + + self._last_overflow_iter = self._iter + self._overflows_since_rescale += 1 + pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) + if pct_overflow >= self.tolerance: + self._decrease_loss_scale() + self._last_rescale_iter = self._iter + self._overflows_since_rescale = 0 + + if self.loss_scale <= self.min_loss_scale: + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. + self.loss_scale = prev_scale + raise FloatingPointError( + ( + "Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try lowering the learning rate, using gradient clipping or " + "increasing the batch size." + ).format(self.min_loss_scale) + ) + + self._iter += 1 + raise OverflowError("setting loss scale to: " + str(self.loss_scale)) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7c695ee5ce71e1dd44b4deb0990f57858cb38 --- /dev/null +++ b/fairseq/optim/fairseq_optimizer.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.dataclass.utils import gen_parser_from_dataclass +from collections import defaultdict + + +class FairseqOptimizer(object): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + @classmethod + def add_args(cls, parser): + """Add optimizer-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @property + def optimizer(self): + """Return a torch.optim.optimizer.Optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, torch.optim.Optimizer): + raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") + self._optimizer = optimizer + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + raise NotImplementedError + + @property + def params(self): + """Return an iterable of the parameters held by the optimizer.""" + for param_group in self.param_groups: + for p in param_group["params"]: + yield p + + @property + def param_groups(self): + return self.optimizer.param_groups + + def __getstate__(self): + return self._optimizer.__getstate__() + + def get_lr(self): + """Return the current learning rate.""" + return self.param_groups[0]["lr"] + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.param_groups: + param_group["lr"] = lr + + def state_dict(self): + """Return the optimizer's state dict.""" + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + self.optimizer.load_state_dict(state_dict) + + if optimizer_overrides is not None and len(optimizer_overrides) > 0: + # override learning rate, momentum, etc. with latest values + for group in self.param_groups: + group.update(optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" + loss.backward() + + def all_reduce_grads(self, module): + """Manually all-reduce gradients (if required).""" + if hasattr(module, "all_reduce_grads"): + module.all_reduce_grads() + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) + for p in self.params: + if p.grad is not None: + if p.grad.is_sparse: + p.grad.data.mul_(c.to(p.grad.device) if torch.is_tensor(c) else c) + else: + per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append( + p.grad.data + ) + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + def step(self, closure=None, scale=1.0, groups=None): + """Performs a single optimization step.""" + if self.supports_step_with_scale: + if self.supports_groups: + self.optimizer.step(closure, scale=scale, groups=groups) + else: + self.optimizer.step(closure, scale=scale) + else: + if scale != 1.0: + self.multiply_grads(1.0 / scale) + if self.supports_groups: + self.optimizer.step(closure, groups=groups) + else: + self.optimizer.step(closure) + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.params: + p.grad = None + self.optimizer.zero_grad() + + @property + def supports_memory_efficient_fp16(self): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): + return self.optimizer.supports_memory_efficient_fp16 + return False + + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, "supports_step_with_scale"): + return self.optimizer.supports_step_with_scale + return False + + @property + def supports_groups(self): + if hasattr(self.optimizer, "supports_groups"): + return self.optimizer.supports_groups + return False + + @property + def supports_flat_params(self): + """ + Whether the optimizer supports collapsing of the model + parameters/gradients into a single contiguous Tensor. + """ + if hasattr(self.optimizer, "supports_flat_params"): + return self.optimizer.supports_flat_params + return False + + def average_params(self): + pass + + def broadcast_global_state_dict(self, state_dict): + """ + Broadcasts a global state dict to all ranks. + Useful for optimizers that shard state between ranks. + """ + if hasattr(self.optimizer, "broadcast_global_state_dict"): + return self.optimizer.broadcast_global_state_dict(state_dict) + else: + return state_dict + + +class LegacyFairseqOptimizer(FairseqOptimizer): + def __init__(self, args): + self.args = args diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4da342cab8526ca09893cfd1fbee82be955faa --- /dev/null +++ b/fairseq/optim/fp16_optimizer.py @@ -0,0 +1,558 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +from itertools import chain + +import torch +from omegaconf import DictConfig + +from fairseq import optim + +from .dynamic_loss_scaler import DynamicLossScaler + + +class _FP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return torch.is_tensor(self.fp32_params) or ( + isinstance(self.fp32_params, dict) + and all(torch.is_tensor(t) for t in self.fp32_params.values()) + ) + + @classmethod + def build_fp32_params(cls, args, params, flatten=True): + # create FP32 copy of parameters and grads + if flatten: + is_pipeline_parallel = getattr( + args, "pipeline_model_parallel", False + ) and getattr(args, "distributed_no_spawn", False) + total_param_size = sum(p.data.numel() for p in params) + devices = [torch.cuda.current_device()] + if is_pipeline_parallel: + devices = list(set(args.pipeline_devices)) + fp32_params = {} + for device in devices: + if is_pipeline_parallel: + device_param_size = sum( + p.data.numel() for p in params if p.device.index == device + ) + device_params = [p for p in params if p.device.index == device] + else: + device_param_size = total_param_size + device_params = params + fp32_params[device] = ( + device_params[0].new(0).float().new(device_param_size) + ) + offset = 0 + for p in device_params: + numel = p.data.numel() + fp32_params[device][offset : offset + numel].copy_(p.data.view(-1)) + offset += numel + fp32_params[device] = torch.nn.Parameter(fp32_params[device]) + fp32_params[device].grad = fp32_params[device].data.new( + device_param_size + ) + return fp32_params + else: + fp32_params = [] + for p in params: + p32 = torch.nn.Parameter(p.data.float()) + if hasattr(p, "expert"): + p32.expert = True + elif hasattr(p, "base_expert"): + p32.base_expert = True + p32.grad = torch.zeros_like(p32.data) + if hasattr(p, "param_group"): + p32.param_group = p.param_group + if hasattr(p, "optim_overrides"): + p32.optim_overrides = p.optim_overrides + fp32_params.append(p32) + return fp32_params + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.fp32_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self): + if self._needs_sync: + # copy FP16 grads to FP32 + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + if p.requires_grad: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + grad_data = ( + p.grad.data + if p.grad is not None + else p.data.new_zeros(p.data.shape) + ) + numel = grad_data.numel() + self.fp32_params[device].grad.data[ + offset : offset + numel + ].copy_(grad_data.view(-1)) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + if p.grad is not None: + if p32.grad is None: + p32.grad = p.grad.data.float() + else: + p32.grad.data.copy_(p.grad.data) + else: + p32.grad = torch.zeros_like(p.data, dtype=torch.float) + + self._needs_sync = False + + def _sync_fp32_params_to_fp16(self): + # copy FP32 params back into FP16 model + if self.has_flat_params: + devices = list(self.fp32_params.keys()) + device_params_dict = defaultdict(list) + for p in self.fp16_params: + device_params_dict[p.device.index].append(p) + for device in devices: + device_params = device_params_dict[device] + offset = 0 + for p in device_params: + numel = p.data.numel() + p.data.copy_( + self.fp32_params[device] + .data[offset : offset + numel] + .view_as(p.data) + ) + offset += numel + else: + for p, p32 in zip(self.fp16_params, self.fp32_params): + if not p.requires_grad: + continue + p.data.copy_(p32.data) + + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. + torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant ``c``.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + self._sync_fp16_grads_to_fp32() + + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if torch.is_tensor(self._multiply_factor): + self._multiply_factor = self._multiply_factor.to(grad_norm.device) + + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + + self.scaler.check_overflow(grad_norm) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None, groups=None): + """Performs a single optimization step.""" + self._sync_fp16_grads_to_fp32() + + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure, groups=groups) + + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_params_to_fp16() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.fp16_params: + p.grad = None + if self.has_flat_params: + if torch.is_tensor(self.fp32_params): + self.fp32_params.grad.zero_() + elif isinstance(self.fp32_params, dict): + for fp32_params in self.fp32_params.values(): + fp32_params.grad.zero_() + else: + raise RuntimeError("self.fp32_params must be a tensor or dict") + else: + for p32 in self.fp32_params: + if p32.grad is not None: + p32.grad.zero_() + self._needs_sync = False + + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + + +class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + """ + + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) + self.fp16_params = params + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2**14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + cfg (omegaconf.DictConfig): fairseq args + params (iterable): iterable of parameters to optimize + """ + flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): + flatten = False # mixed precision is faster on TPUs without flat grads + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) + if flatten: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) + else: + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) + if flatten and not fp32_optimizer.supports_flat_params: + raise RuntimeError( + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" + ) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params + + +class _MemoryEfficientFP16OptimizerMixin(object): + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in MRO (method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + @property + def has_flat_params(self): + return False + + def state_dict(self): + """Return the optimizer's state dict.""" + state_dict = self.wrapped_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an optimizer state dict. + + In general we should prefer the configuration of the existing optimizer + instance (e.g., learning rate) over that found in the state_dict. This + allows us to resume training from a checkpoint using a new set of + optimizer args. + """ + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + + self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides) + + # Hack: PyTorch automatically casts the optimizer state to match the + # type of the current parameters. But with --memory-efficient-fp16 the + # params are FP16 while the optimizer state is FP32 and we don't want + # to cast. A workaround is to manually copy back the original state + # after the optimizer has been loaded. + if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False): + groups = self.optimizer.param_groups + saved_groups = state_dict["param_groups"] + id_map = { + old_id: p + for old_id, p in zip( + chain(*(g["params"] for g in saved_groups)), + chain(*(g["params"] for g in groups)), + ) + } + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.optimizer.state[param] = v + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves. + + Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this + function additionally dynamically scales the loss to avoid gradient + underflow. + """ + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() + + def _unscale_grads(self): + if ( + # Skip the multiplication if it's a no-op (i.e., if _multiply_factor + # is 1.0). At the same time, we want to avoid the device-to-host + # transfer by comparing it to 1.0. Since _multiply_factor starts as + # a Python float, we roughly assume that if it's a tensor then it's + # probably not =1.0 anymore and we do the multiplication. Otherwise + # we can safely check the value without a D2H transfer. + torch.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): + self.wrapped_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + self._multiply_factor *= c + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm and updates dynamic loss scaler.""" + max_norm = float(max_norm) + grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm( + 0, aggregate_norm_fn + ) + + if self.scaler is not None: + grad_norm_cpu = float(grad_norm) + if grad_norm_cpu > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm_cpu + + # detect overflow and adjust loss scale + self.scaler.check_overflow(grad_norm_cpu) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None, groups=None): + """Performs a single optimization step.""" + if getattr(self, "supports_step_with_scale", False): + # NOTE(msb) optimizer divides by scale factor + self.wrapped_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) + else: + self._unscale_grads() + self.wrapped_optimizer.step(closure, groups=groups) + + if self.scaler is not None: + self.scaler.update() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + self.wrapped_optimizer.zero_grad() + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1.0 + + @property + def supports_flat_params(self): + return self.wrapped_optimizer.supports_flat_params + + +class MemoryEfficientFP16Optimizer( + _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer +): + """ + Wrap an *optimizer* to support FP16 (mixed precision) training. + + Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not + maintain an FP32 copy of the model. We instead expect the optimizer to + convert the gradients to FP32 internally and sync the results back to the + FP16 model params. This significantly reduces memory usage but slightly + increases the time spent in the optimizer. + + Since this wrapper depends on specific functionality in the wrapped + optimizer (i.e., on-the-fly conversion of grads to FP32), only certain + optimizers can be wrapped. This is determined by the + *supports_memory_efficient_fp16* property. + """ + + def __init__( + self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs + ): + if not allow_unsupported and not optimizer.supports_memory_efficient_fp16: + raise ValueError( + "Unsupported optimizer: {}".format(optimizer.__class__.__name__) + ) + + super().__init__(getattr(cfg, "optimizer", None)) + self.wrapped_optimizer = optimizer + + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int( + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2**14 / data_parallel_size / cfg.optimization.update_freq[0] + ) + else: + scale_window = cfg.common.fp16_scale_window + + if not getattr(cfg.common, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=cfg.common.fp16_init_scale, + scale_window=scale_window, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, + ) + else: + # disable loss scaling for bfloat16 + self.scaler = None + + @classmethod + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): + """ + Args: + args (argparse.Namespace): fairseq args + params (iterable): iterable of parameters to optimize + """ + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp16_optimizer, **kwargs) + + @property + def optimizer(self): + return self.wrapped_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.wrapped_optimizer.optimizer = optimizer + + @property + def optimizer_config(self): + return self.wrapped_optimizer.optimizer_config + + @property + def lr_scheduler(self): + return getattr(self.wrapped_optimizer, "lr_scheduler", None) + + def get_lr(self): + return self.wrapped_optimizer.get_lr() + + def set_lr(self, lr): + self.wrapped_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + self.wrapped_optimizer.all_reduce_grads(module) diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..39a2a83694755a3a6d79f797027ad7590fadea5b --- /dev/null +++ b/fairseq/optim/fused_adam.py @@ -0,0 +1,389 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import types + +import torch + + +def get_fused_adam_class(): + """ + Look for the FusedAdam optimizer from apex. We first try to load the + "contrib" interface, which is a bit faster than the main interface, + but is technically deprecated. + """ + try: + # The "deprecated" interface in recent versions of apex is a bit + # faster than the main interface, since we don't use the apex + # optimizer. This can be installed by passing the + # `--deprecated_fused_adam` option when building apex. + global fused_adam_cuda + import importlib + + fused_adam_cuda = importlib.import_module("fused_adam_cuda") + return FusedAdamV1 + except ImportError: + try: + # fallback to the newer interface + from apex.multi_tensor_apply import multi_tensor_applier + from apex.optimizers import FusedAdam as _FusedAdam # noqa + + if multi_tensor_applier.available: + return FusedAdamV2 + except ImportError: + pass + return None + + +class FusedAdamV1(torch.optim.Optimizer): + """ + Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via + ``python setup.py install --cuda_ext --cpp_ext``. + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Compared to the original version in Apex, the fairseq version casts grads + and params to FP32 internally to support ``--memory-efficient-fp16``. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in FusedAdam! + eps_inside_sqrt (boolean, optional): in the 'update parameters' step, + adds eps to the bias-corrected second moment estimate before + evaluating square root instead of adding it to the square root of + second moment estimate as in the original paper. (default: False) + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0.0, + max_grad_norm=0.0, + amsgrad=False, + use_fp16_stats=False, + ): + global fused_adam_cuda + import importlib + + fused_adam_cuda = importlib.import_module("fused_adam_cuda") + + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + defaults = { + "lr": lr, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + "max_grad_norm": max_grad_norm, + } + super().__init__(params, defaults) + self.eps_mode = 0 if eps_inside_sqrt else 1 + + self.use_fp16_stats = use_fp16_stats + self.FLOAT16_MAX = 65504.0 + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + @property + def supports_step_with_scale(self): + return True + + def step(self, closure=None, grads=None, scale=1.0, grad_norms=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grads (list of tensors, optional): weight gradient to use for the + optimizer update. If gradients have type torch.half, parameters + are expected to be in type torch.float. (default: None) + output params (list of tensors, optional): A reduced precision copy + of the updated weights written out in addition to the regular + updated weights. Have to be of same type as gradients. (default: None) + scale (float, optional): factor to divide gradient tensor values + by before applying to weights. (default: 1) + """ + loss = None + if closure is not None: + loss = closure() + + if grads is None: + grads_group = [None] * len(self.param_groups) + # backward compatibility + # assuming a list/generator of parameter means single group + elif isinstance(grads, types.GeneratorType): + grads_group = [grads] + elif type(grads[0]) != list: + grads_group = [grads] + else: + grads_group = grads + + if grad_norms is None: + grad_norms = [None] * len(self.param_groups) + + for group, grads_this_group, grad_norm in zip( + self.param_groups, grads_group, grad_norms + ): + if grads_this_group is None: + grads_this_group = [None] * len(group["params"]) + + # compute combined scale factor for this group + combined_scale = scale + if group.get("max_grad_norm", 0) > 0: + # norm is in fact norm*scale + clip = ((grad_norm / scale) + 1e-6) / group["max_grad_norm"] + if clip > 1: + combined_scale = clip * scale + + bias_correction = 1 if group.get("bias_correction", 1) else 0 + + for p, grad in zip(group["params"], grads_this_group): + # note: p.grad should not ever be set for correct + # operation of mixed precision optimizer that sometimes + # sends None gradients + if p.grad is None and grad is None: + continue + if grad is None: + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + if p.device.type == "cpu": + p_data_fp32 = p.data.cuda(non_blocking=True).float() + out_p = torch.tensor([], dtype=torch.float) + else: + p_data_fp32 = p.data.float() + out_p = p.data + + state = self.state[p] + + # State initialization + dtype = torch.float16 if self.use_fp16_stats else p_data_fp32.dtype + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p_data_fp32, dtype=dtype) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32, dtype=dtype) + if self.use_fp16_stats: + state["exp_avg_scale"] = 1.0 + state["exp_avg_sq_scale"] = 1.0 + else: + device = p_data_fp32.device + state["exp_avg"] = state["exp_avg"].to(device, dtype) + state["exp_avg_sq"] = state["exp_avg_sq"].to(device, dtype) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + if self.use_fp16_stats: + assert exp_avg.dtype == torch.float16 + exp_avg = exp_avg.float() * state["exp_avg_scale"] + exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"] + beta1, beta2 = group["betas"] + + if "step" not in state: + state["step"] = group["step"] + + state["step"] += 1 + + with torch.cuda.device(p_data_fp32.device): + fused_adam_cuda.adam( + p_data_fp32, + out_p, + exp_avg, + exp_avg_sq, + grad, + group["lr"], + beta1, + beta2, + group["eps"], + combined_scale, + state["step"], + self.eps_mode, + bias_correction, + group["weight_decay"], + ) + + if p.device.type == "cpu": + p.data.copy_(p_data_fp32, non_blocking=True) + + if self.use_fp16_stats: + + def inf_norm(t): + return torch.norm(t, float("inf")) + + # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py + state["exp_avg_scale"], state["exp_avg_sq_scale"] = ( + 1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX, + 1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX, + ) + state["exp_avg"], state["exp_avg_sq"] = ( + (exp_avg / state["exp_avg_scale"]).half(), + (exp_avg_sq / state["exp_avg_sq_scale"]).half(), + ) + + return loss + + +try: + from apex.multi_tensor_apply import multi_tensor_applier + from apex.optimizers import FusedAdam + + class FusedAdamV2(FusedAdam): + """ + Compared to the original version in Apex, the fairseq version casts grads + and params to FP32 internally to support ``--memory-efficient-fp16``. + """ + + def __init__(self, *args, use_fp16_stats=False, **kwargs): + if use_fp16_stats: + raise NotImplementedError( + "--fp16-adam-stats is only supported with FusedAdamV1" + ) + super().__init__(*args, **kwargs) + if not hasattr(self, "multi_tensor_adam"): + raise Exception( + "Apex installation is outdated. Please install an updated version of apex." + ) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step( + self, + closure=None, + grads=None, + output_params=None, + scale=None, + grad_norms=None, + ): + """Performs a single optimization step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + # create lists for multi-tensor apply + g_16, p_16, orig_p_16, m_16, v_16 = [], [], [], [], [] + g_32, p_32, m_32, v_32 = [], [], [], [] + + for p in group["params"]: + if p.grad is None: + continue + if p.grad.data.is_sparse: + raise RuntimeError( + "FusedAdam does not support sparse gradients, " + "please consider SparseAdam instead" + ) + + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p.data, dtype=torch.float + ) + else: + state["exp_avg"] = state["exp_avg"].to( + device=p.data.device, dtype=torch.float + ) + state["exp_avg_sq"] = state["exp_avg_sq"].to( + device=p.data.device, dtype=torch.float + ) + + if p.dtype == torch.float16: + g_16.append(p.grad.data.float()) + p_16.append(p.data.float()) + orig_p_16.append(p.data) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) + elif p.dtype == torch.float32: + g_32.append(p.grad.data) + p_32.append(p.data) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) + else: + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + with torch.cuda.device(p.device): + if len(g_16) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_16, p_16, m_16, v_16], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + for orig_p, p in zip(orig_p_16, p_16): + orig_p.copy_(p.data) + if len(g_32) > 0: + multi_tensor_applier( + self.multi_tensor_adam, + self._dummy_overflow_buf, + [g_32, p_32, m_32, v_32], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self.adam_w_mode, + bias_correction, + group["weight_decay"], + ) + + return loss + +except ImportError: + pass diff --git a/fairseq/optim/fused_lamb.py b/fairseq/optim/fused_lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f2bdb0c6c65f7758509b6d4d2f2c48cb6e8b4f --- /dev/null +++ b/fairseq/optim/fused_lamb.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("lamb") +class FairseqLAMB(LegacyFairseqOptimizer): + """LAMB optimizer.""" + + def __init__(self, args, params): + super().__init__(args) + try: + from apex.optimizers import FusedLAMB + + self._optimizer = FusedLAMB(params, **self.optimizer_config) + except ImportError: + raise ImportError("Please install apex to use LAMB optimizer") + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--lamb-betas', default='(0.9, 0.999)', metavar='B', + help='betas for LAMB optimizer') + parser.add_argument('--lamb-eps', type=float, default=1e-8, metavar='D', + help='epsilon for LAMB optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "betas": eval(self.args.lamb_betas), + "eps": self.args.lamb_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return False diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3dbc023aa4a6f7bfb8403b8204d71ca432f79c --- /dev/null +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from fairseq import registry +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa + FairseqLRScheduler, + LegacyFairseqLRScheduler, +) +from omegaconf import DictConfig + + +( + build_lr_scheduler_, + register_lr_scheduler, + LR_SCHEDULER_REGISTRY, + LR_SCHEDULER_DATACLASS_REGISTRY, +) = registry.setup_registry( + "--lr-scheduler", base_class=FairseqLRScheduler, default="fixed" +) + + +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) + + +# automatically import any Python files in the optim/lr_scheduler/ directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("fairseq.optim.lr_scheduler." + file_name) diff --git a/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a162d5cba486cb661f6c299d601bb53c27c1c219 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c69b8053e03381df6ec5bd2f3e43b5dbfb4b776 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/cosine_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e949a1addc6c7bb10a9209aac14602e5a653bd Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/fairseq_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abe202fb50b668888313338b6e24c1cc78d9709e Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/fixed_schedule.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7defa840388b78b494e3d7b00643d9c64353e2d1 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/inverse_square_root_schedule.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/manual_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/manual_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c71bb7d96b13477ed39bea7f0f18550e866c56f8 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/manual_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/pass_through.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/pass_through.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68ac53b1a8a6eb068ca7830299c7d65a6232c22b Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/pass_through.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a0ea73e02c80d447a5921a5abaec5689613a56d Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/polynomial_decay_schedule.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02ecdce1306cbb9f9078c1cb1efffe40c7f2433a Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/reduce_lr_on_plateau.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ae78fbe91307857f8f2e2e6e240d16625f8900 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/step_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85dbe96d6345142a461884cadff56dfd12050c20 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/tri_stage_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc b/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e19a6e6a579de2ac119e26ad1717d800529a020 Binary files /dev/null and b/fairseq/optim/lr_scheduler/__pycache__/triangular_lr_scheduler.cpython-310.pyc differ diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcaea25d493c3f33bab6b9aa65d50450ca446e3 --- /dev/null +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -0,0 +1,146 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class CosineLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, + ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) + t_mult: float = field( + default=1.0, metadata={"help": "factor to grow the length of each period"} + ) + lr_period_updates: float = field( + default=-1, metadata={"help": "initial number of updates per period"} + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + # This is not required, but is for convenience in inferring lr_period_updates + max_update: int = II("optimization.max_update") + + +@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig) +class CosineLRSchedule(FairseqLRScheduler): + """Assign LR based on a cyclical schedule that follows the cosine function. + + See https://arxiv.org/pdf/1608.03983.pdf for details. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + max learning rate (``--lr``). + + During warmup:: + + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(t_curr / t_i)) + + where ``t_curr`` is current percentage of updates within the current period + range and ``t_i`` is the current period range, which is scaled by ``t_mul`` + after every iteration. + """ + + def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with cosine." + f" Consider --lr-scheduler=fixed instead. ({cfg.lr})" + ) + + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + if self.max_lr < cfg.min_lr: + cfg.min_lr = self.max_lr + + warmup_end_lr = self.max_lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = cfg.min_lr + + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates + + if self.period <= 0: + assert ( + cfg.max_update > 0 + ), "Either --max_update or --lr-period-updates must be set" + self.period = cfg.max_update - cfg.warmup_updates + + if cfg.warmup_updates > 0: + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + else: + self.lr_step = 1 + + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + if self.t_mult != 1: + i = math.floor( + math.log( + 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult + ) + ) + t_i = self.t_mult**i * self.period + t_curr = ( + curr_updates + - (1 - self.t_mult**i) / (1 - self.t_mult) * self.period + ) + else: + i = math.floor(curr_updates / self.period) + t_i = self.period + t_curr = curr_updates - (self.period * i) + + lr_shrink = self.lr_shrink**i + min_lr = self.cfg.min_lr * lr_shrink + max_lr = self.max_lr * lr_shrink + + self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( + 1 + math.cos(math.pi * t_curr / t_i) + ) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..6c12fa56b825e81bcc3fc7a97d206777418260ef --- /dev/null +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim import FairseqOptimizer + + +class FairseqLRScheduler(object): + def __init__(self, cfg, optimizer): + super().__init__() + if optimizer is not None and not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.cfg = cfg + self.optimizer = optimizer + self.best = None + + @classmethod + def add_args(cls, parser): + """Add arguments to the parser for this LR scheduler.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {"best": self.best} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.best = state_dict["best"] + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + pass + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + if val_loss is not None: + if self.best is None: + self.best = val_loss + else: + self.best = min(self.best, val_loss) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.get_lr() + + +class LegacyFairseqLRScheduler(FairseqLRScheduler): + def __init__(self, args: Namespace, optimizer): + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError("optimizer must be an instance of FairseqOptimizer") + self.args = args + self.optimizer = optimizer + self.best = None diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e7e14b7e72b1151f7d7f19094430bbab64f8f0 --- /dev/null +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class FixedLRScheduleConfig(FairseqDataclass): + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + lr_shrink: float = field( + default=0.1, + metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig) +class FixedLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, cfg: FixedLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates + else: + self.warmup_factor = 1 + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch - 1, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = lrs[-1] * self.cfg.lr_shrink ** ( + epoch + 1 - self.cfg.force_anneal + ) + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates) + self.optimizer.set_lr(self.warmup_factor * self.lr) + else: + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..987c905a23d50342dd7e809e0eddb5a6df2ebe90 --- /dev/null +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class InverseSquareRootLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=4000, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig) +class InverseSquareRootSchedule(FairseqLRScheduler): + """Decay the LR based on the inverse square root of the update number. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). Thereafter we decay proportional to the number of + updates, with a decay factor set to align with the configured learning rate. + + During warmup:: + + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + decay_factor = cfg.lr * sqrt(cfg.warmup_updates) + lr = decay_factor / sqrt(update_num) + """ + + def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with inverse_sqrt." + " Consider --lr-scheduler=fixed instead." + ) + warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + + # then, decay prop. to the inverse square root of the update number + self.decay_factor = warmup_end_lr * cfg.warmup_updates**0.5 + + # initial learning rate + self.lr = cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + else: + self.lr = self.decay_factor * num_updates**-0.5 + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..57edc256fd98e1671502ba18def54eae82a4a3ab --- /dev/null +++ b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import LegacyFairseqLRScheduler, register_lr_scheduler +import logging +import ast + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_lr_scheduler("manual") +class ManualSchedule(LegacyFairseqLRScheduler): + """Decay the LR on a manual schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + self.epoch2lr = self.parse_manuallr_args(args.epoch2lr) + self.update2lr = self.parse_manuallr_args(args.update2lr) + logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr)) + logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr)) + + if 1 in self.epoch2lr: + self.lr = self.epoch2lr[1] + elif 1 in self.update2lr: + self.lr = self.update2lr[1] + else: + self.lr = args.lr[0] + self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. + + def parse_manuallr_args(self, lr_args_str): + lr_dict = ast.literal_eval(lr_args_str.replace(" ", "")) + if not isinstance(lr_dict, dict): + raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") + + lr_args = {} + logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict)) + for key, val in lr_dict.items(): + if "," in key: + for k in key.split(","): + lr_args[int(k)] = float(val) + elif "-" in key: + s = int(key.split("-")[0]) + e = int(key.split("-")[1]) + for k in range(s, e + 1, 1): + lr_args[k] = float(val) + else: + lr_args[int(key)] = float(val) + + return lr_args + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + "--epoch2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each epoch manually", + ) + parser.add_argument( + "--update2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each update manually", + ) + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + manual_keys = [k for k in self.epoch2lr if k <= epoch] + if manual_keys: + manual_lr = self.epoch2lr[max(manual_keys)] + else: + logger.warning( + "@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format( + epoch, + list(self.epoch2lr.items())[ + : min(10, len(self.epoch2lr.keys()) - 1) + ], + ) + ) + manual_lr = self.optimizer.get_lr() + return manual_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + manual_keys = [k for k in self.update2lr if k <= num_updates] + if manual_keys: + manual_lr = self.update2lr[max(manual_keys)] + else: + logger.warning( + "epoch={} does not exist in manual lr input update2lr={}...".format( + num_updates, + list(self.update2lr.items())[ + : min(10, len(self.update2lr.keys()) - 1) + ], + ) + ) + manual_lr = self.optimizer.get_lr() + + self.optimizer.set_lr(manual_lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/pass_through.py b/fairseq/optim/lr_scheduler/pass_through.py new file mode 100644 index 0000000000000000000000000000000000000000..2f93db328c1de9b268e8ee1c0c1cad558fd089aa --- /dev/null +++ b/fairseq/optim/lr_scheduler/pass_through.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PassThroughScheduleConfig(FairseqDataclass): + pass + + +@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig) +class PassThroughScheduleSchedule(FairseqLRScheduler): + """Delegate lr scheduling to the optimizer.""" + + def __init__(self, cfg: PassThroughScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + assert ( + hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None + ), "Pass-through schedule can only be used with optimizers with their own schedulers" + + def state_dict(self): + return self.optimizer.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.lr_scheduler.load_state_dict(state_dict) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + return self.optimizer.lr_scheduler.step_begin_epoch(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.lr_scheduler.step_update(num_updates) diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..b8109a7c1e79cd057c355504d07bac5615c02ea9 --- /dev/null +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PolynomialDecayLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + end_learning_rate: float = field( + default=0.0, + metadata={"help": "learning rate to decay to"}, + ) + power: float = field( + default=1.0, + metadata={"help": "decay exponent"}, + ) + total_num_update: float = field( + default=II("optimization.max_update"), + metadata={"help": "total number of updates over which to decay learning rate"}, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig) +class PolynomialDecayLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + assert cfg.total_num_update > 0 + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates + else: + self.warmup_factor = 1 + self.end_learning_rate = cfg.end_learning_rate + self.total_num_update = cfg.total_num_update + self.power = cfg.power + self.optimizer.set_lr(self.warmup_factor * self.lr) + + def get_next_lr(self, epoch): + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = self.optimizer.get_lr() + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates: + self.warmup_factor = num_updates / float(self.cfg.warmup_updates) + lr = self.warmup_factor * self.lr + elif num_updates >= self.total_num_update: + lr = self.end_learning_rate + else: + warmup = self.cfg.warmup_updates + lr_range = self.lr - self.end_learning_rate + pct_remaining = 1 - (num_updates - warmup) / ( + self.total_num_update - warmup + ) + lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate + self.optimizer.set_lr(lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py new file mode 100644 index 0000000000000000000000000000000000000000..5ee9c1be4a59ad3d072412827ab4e9b62dc7434e --- /dev/null +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import List + +import torch.optim.lr_scheduler +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass): + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + lr_threshold: float = field( + default=1e-4, + metadata={ + "help": ( + "threshold for measuring the new optimum, to only focus on " + "significant changes" + ) + }, + ) + lr_patience: int = field( + default=0, + metadata={ + "help": ( + "number of epochs with no improvement after which learning rate will " + "be reduced" + ) + }, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II( + "checkpoint.maximize_best_checkpoint_metric" + ) + + +@register_lr_scheduler( + "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig +) +class ReduceLROnPlateauLRSchedule(FairseqLRScheduler): + """ + Decay the LR by a factor every time the validation loss plateaus. + Also comes with optional warmup phase, where we linearly increase + the learning rate from some initial learning rate + (``--warmup-init-lr``) until the configured learning rate + (``--lr``). Thereafter the lr is adjusted according to original + reduce_on_plateau scheme. + + During warmup:: + + lrs = torch.linspace( + cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates + ) + lr = lrs[update_num] + """ + + def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." + " Consider --lr-scheduler=fixed instead." + ) + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer.optimizer, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, + ) + warmup_end_lr = cfg.lr[0] + # if no warm up, sets initial lr to be cfg.lr[0] + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first cfg.warmup_updates + if cfg.warmup_updates > 0: + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates + + # this flag is either set from arg when no warm up, or set by + # step_update() when warmup finishes + self.warmup_end = True if cfg.warmup_updates <= 0 else False + + # initial learning rate + # this self.lr is used only during init and/or warm up period + self.lr = warmup_end_lr if self.warmup_end else cfg.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def state_dict(self): + """Return the LR scheduler state dict.""" + return { + "best": self.lr_scheduler.best, + "last_epoch": self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.lr_scheduler.best = state_dict["best"] + if "last_epoch" in state_dict: + self.lr_scheduler.last_epoch = state_dict["last_epoch"] + + def step(self, epoch, val_loss=None): + """ + Update the learning rate at the end of the given epoch if warmup + finishes otherwise no update of lr on epoch boundaries + """ + if val_loss is not None and self.warmup_end is True: + self.lr_scheduler.step(val_loss) + else: + self.lr_scheduler.last_epoch = epoch + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """ + Update the learning rate after each update.""" + # if there is warmup + if self.cfg.warmup_updates > 0: + if num_updates <= self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step + self.optimizer.set_lr(self.lr) + else: + if self.warmup_end is False: + self.warmup_end = True + # else do nothing + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/step_lr_scheduler.py b/fairseq/optim/lr_scheduler/step_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..db99d4eee84475d7eaef625e4c72c71de847db42 --- /dev/null +++ b/fairseq/optim/lr_scheduler/step_lr_scheduler.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class StepLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, + ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) + lr_deacy_period: int = field(default=25000, metadata={"help": "decay period"}) + lr_decay: float = field(default=0.5, metadata={"help": "decay factor"}) + + +@register_lr_scheduler("step", dataclass=StepLRScheduleConfig) +class StepLRSchedule(FairseqLRScheduler): + """Decay learning rate every k updates by a fixed factor""" + + def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer): + super().__init__(cfg, fairseq_optimizer) + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + self.min_lr = cfg.min_lr + self.lr_deacy_period = cfg.lr_deacy_period + self.lr_decay = cfg.lr_decay + self.warmup_updates = cfg.warmup_updates + self.warmup_init_lr = ( + cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr + ) + + assert self.lr_deacy_period > 0 + assert self.lr_decay <= 1 + assert self.min_lr >= 0 + assert self.max_lr > self.min_lr + + if cfg.warmup_updates > 0: + # linearly warmup for the first cfg.warmup_updates + self.warmup_lr_step = ( + self.max_lr - self.warmup_init_lr + ) / self.warmup_updates + else: + self.warmup_lr_step = 1 + + # initial learning rate + self.lr = self.warmup_init_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.cfg.warmup_updates: + self.lr = self.warmup_init_lr + num_updates * self.warmup_lr_step + else: + curr_updates = num_updates - self.cfg.warmup_updates + lr_mult = self.lr_decay ** (curr_updates // self.lr_deacy_period) + self.lr = max(self.max_lr * lr_mult, self.min_lr) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5547c39b14f62acbd4f4b9ab3abfb3009c0e6d --- /dev/null +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import Optional, List, Tuple +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriStageLRScheduleConfig(FairseqDataclass): + warmup_steps: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + hold_steps: int = field( + default=0, + metadata={"help": "steps in hold stage"}, + ) + decay_steps: int = field( + default=0, + metadata={"help": "steps in decay stages"}, + ) + phase_ratio: Optional[Tuple[float, float, float]] = field( + default=None, + metadata={ + "help": ( + "if set, automatically sets warmup/hold/decay steps to the ratio " + "specified here from max_updates. the ratios must add up to 1.0" + ) + }, + ) + init_lr_scale: float = field( + default=0.01, + metadata={"help": "initial learning rate scale during warmup phase"}, + ) + final_lr_scale: float = field( + default=0.01, + metadata={"help": "final learning rate scale"}, + ) + max_update: float = II("optimization.max_update") + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) +class TriStageLRSchedule(FairseqLRScheduler): + """Tristage learning rate schedulr + + Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf + + Similar to inverse_squre_root scheduler, but tri_stage learning rate employs + three stages LR scheduling: + + - warmup stage, starting from `lr` * `init_lr_scale`, linearly + increased to `lr` in `warmup_steps` iterations + + - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps` + iterations + + - decay stage, after hold stage, decay LR exponetially to + `lr` * `final_lr_scale` in `decay_steps`; + after that LR is keep as `final_lr_scale` * `lr` + + During warmup:: + + init_lr = cfg.init_lr_scale * cfg.lr + lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps) + lr = lrs[update_num] + + During hold:: + + lr = cfg.lr + + During decay:: + + decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps + lr = cfg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) + + After that:: + + lr = cfg.lr * cfg.final_lr_scale + """ + + def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with tri-stage lr." + " Consider --lr-scheduler=fixed instead." + ) + + # calculate LR at each point + self.peak_lr = cfg.lr[0] + self.init_lr = cfg.init_lr_scale * cfg.lr[0] + self.final_lr = cfg.final_lr_scale * cfg.lr[0] + + if cfg.phase_ratio is not None: + assert cfg.max_update > 0 + assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1" + self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) + self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) + self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2]) + else: + self.warmup_steps = cfg.warmup_steps + self.hold_steps = cfg.hold_steps + self.decay_steps = cfg.decay_steps + + assert ( + self.warmup_steps + self.hold_steps + self.decay_steps > 0 + ), "please specify steps or phase_ratio" + + self.warmup_rate = ( + (self.peak_lr - self.init_lr) / self.warmup_steps + if self.warmup_steps != 0 + else 0 + ) + self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps + + # initial learning rate + self.lr = self.init_lr + self.optimizer.set_lr(self.lr) + + def _decide_stage(self, update_step): + """ + return stage, and the corresponding steps within the current stage + """ + if update_step < self.warmup_steps: + # warmup state + return 0, update_step + + offset = self.warmup_steps + + if update_step < offset + self.hold_steps: + # hold stage + return 1, update_step - offset + + offset += self.hold_steps + + if update_step <= offset + self.decay_steps: + # decay stage + return 2, update_step - offset + + offset += self.decay_steps + + # still here ? constant lr stage + return 3, update_step - offset + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + stage, steps_in_stage = self._decide_stage(num_updates) + if stage == 0: + self.lr = self.init_lr + self.warmup_rate * steps_in_stage + elif stage == 1: + self.lr = self.peak_lr + elif stage == 2: + self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage) + elif stage == 3: + self.lr = self.final_lr + else: + raise ValueError("Undefined stage") + + self.optimizer.set_lr(self.lr) + + return self.lr diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2a32bd10f213e4d8408e529ac20fd4cac7a91704 --- /dev/null +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List + +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriangularLRScheduleConfig(FairseqDataclass): + max_lr: float = field( + default="???", metadata={"help": "max learning rate, must be more than cfg.lr"} + ) + lr_period_updates: float = field( + default=5000, + metadata={"help": "initial number of updates per period (cycle length)"}, + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + shrink_min: bool = field( + default=False, metadata={"help": "if set, also shrinks min lr"} + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig) +class TriangularLRSchedule(FairseqLRScheduler): + """Assign LR based on a triangular cyclical schedule. + + See https://arxiv.org/pdf/1506.01186.pdf for details. + """ + + def __init__(self, cfg: TriangularLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with triangular." + " Consider --lr-scheduler=fixed instead." + ) + + lr = cfg.lr[0] + + assert cfg.max_lr > lr, "max_lr must be more than lr" + self.min_lr = lr + self.max_lr = cfg.max_lr + self.stepsize = cfg.lr_period_updates // 2 + self.lr_shrink = cfg.lr_shrink + self.shrink_min = cfg.shrink_min + + # initial learning rate + self.lr = self.min_lr + self.optimizer.set_lr(self.lr) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + cycle = math.floor(num_updates / (2 * self.stepsize)) + + lr_shrink = self.lr_shrink**cycle + max_lr = self.max_lr * lr_shrink + if self.shrink_min: + min_lr = self.min_lr * lr_shrink + else: + min_lr = self.min_lr + + x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) + self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py new file mode 100644 index 0000000000000000000000000000000000000000..c30a6c0fb1e8d5dc7edd5b53ba15a6acd46ecbff --- /dev/null +++ b/fairseq/optim/nag.py @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Collection +from dataclasses import dataclass, field +from typing import List + +import torch +from fairseq.dataclass import FairseqDataclass +from omegaconf import II, DictConfig +from torch.optim.optimizer import Optimizer, required + +from . import FairseqOptimizer, register_optimizer + + +@dataclass +class FairseqNAGConfig(FairseqDataclass): + momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) + weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) + # TODO common vars in parent class + lr: List[float] = II("optimization.lr") + + +@register_optimizer("nag", dataclass=FairseqNAGConfig) +class FairseqNAG(FairseqOptimizer): + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) + self._optimizer = NAG(params, **self.optimizer_config) + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, + } + + +class NAG(Optimizer): + def __init__(self, params, lr=required, momentum=0, weight_decay=0): + defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay) + super(NAG, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group["weight_decay"] + momentum = group["momentum"] + lr = group["lr"] + lr_old = group.get("lr_old", lr) + lr_correct = lr / lr_old if lr_old > 0 else lr + + for p in group["params"]: + if p.grad is None: + continue + + p_data_fp32 = p.data + if p_data_fp32.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + d_p = p.grad.data.float() + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros_like(d_p) + else: + param_state["momentum_buffer"] = param_state["momentum_buffer"].to( + d_p + ) + + buf = param_state["momentum_buffer"] + + if weight_decay != 0: + p_data_fp32.mul_(1 - lr * weight_decay) + p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct) + p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr) + + buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + group["lr_old"] = lr + + return loss diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..8e34fb99a18fff12ab76be5894a84cbbb2f48176 --- /dev/null +++ b/fairseq/optim/sgd.py @@ -0,0 +1,43 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.optim + +from . import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("sgd") +class SGD(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = torch.optim.SGD(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--momentum', default=0.0, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + return True diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7f2eb9e5de6086fe2435d432bde7521ebb8155 --- /dev/null +++ b/fairseq/optim/shard.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict + +from fairseq.distributed import utils + + +try: + from fairscale.optim import OSS + + _has_fairscale = True +except ImportError: + _has_fairscale = False + + +def shard_(optimizer, group): + if not _has_fairscale: + raise ImportError( + "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" + ) + + class FairseqOSS(OSS): + @property + def disable_mem_eff_fp16_loading_hack(self): + return True + + def __getattr__(self, name): + if name.startswith("supports") and hasattr(self.optim, name): + return getattr(self.optim, name) + raise AttributeError( + "'FairseqOSS' object has no attribute {0!r}".format(name) + ) + + def broadcast_global_state_dict( + self, state_dict: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Broadcasts the entire state_dict to all other ranks + each rank is responsible to load their own partition of data + """ + return utils.broadcast_object( + state_dict, + src_rank=0, + group=self.group, + ) + + torch_optimizer = optimizer.optimizer + optim_cls = type(torch_optimizer) + + optimizer.optimizer = FairseqOSS( + torch_optimizer.param_groups, + optim_cls, + group=group, + **optimizer.optimizer_config + ) diff --git a/fairseq/options.py b/fairseq/options.py new file mode 100644 index 0000000000000000000000000000000000000000..920591635a05aa4aca728321a47fed6f3c28e504 --- /dev/null +++ b/fairseq/options.py @@ -0,0 +1,413 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +from fairseq import utils +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + OptimizationConfig, + EMAConfig, +) +from fairseq.dataclass.utils import gen_parser_from_dataclass + +# this import is for backward compatibility +from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list # noqa + + +def get_preprocessing_parser(default_task="translation"): + parser = get_parser("Preprocessing", default_task) + add_preprocess_args(parser) + return parser + + +def get_training_parser(default_task="translation"): + parser = get_parser("Trainer", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser) + add_model_args(parser) + add_optimization_args(parser) + add_checkpoint_args(parser) + add_ema_args(parser) + return parser + + +def get_generation_parser(interactive=False, default_task="translation"): + parser = get_parser("Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_generation_args(parser) + add_checkpoint_args(parser) + if interactive: + add_interactive_args(parser) + return parser + + +def get_speech_generation_parser(default_task="text_to_speech"): + parser = get_parser("Speech Generation", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_speech_generation_args(parser) + return parser + + +def get_interactive_generation_parser(default_task="translation"): + return get_generation_parser(interactive=True, default_task=default_task) + + +def get_eval_lm_parser(default_task="language_modeling"): + parser = get_parser("Evaluate Language Model", default_task) + add_dataset_args(parser, gen=True) + add_distributed_training_args(parser, default_world_size=1) + add_eval_lm_args(parser) + return parser + + +def get_validation_parser(default_task=None): + parser = get_parser("Validation", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser, default_world_size=1) + group = parser.add_argument_group("Evaluation") + gen_parser_from_dataclass(group, CommonEvalConfig()) + return parser + + +def parse_args_and_arch( + parser: argparse.ArgumentParser, + input_args: List[str] = None, + parse_known: bool = False, + suppress_defaults: bool = False, + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None, +): + """ + Args: + parser (ArgumentParser): the parser + input_args (List[str]): strings to parse, defaults to sys.argv + parse_known (bool): only parse known arguments, similar to + `ArgumentParser.parse_known_args` + suppress_defaults (bool): parse while ignoring all default values + modify_parser (Optional[Callable[[ArgumentParser], None]]): + function to modify the parser, e.g., to set default values + """ + if suppress_defaults: + # Parse args without any default values. This requires us to parse + # twice, once to identify all the necessary task/model args, and a second + # time with all defaults set to None. + args = parse_args_and_arch( + parser, + input_args=input_args, + parse_known=parse_known, + suppress_defaults=False, + ) + suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser]) + suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()}) + args = suppressed_parser.parse_args(input_args) + return argparse.Namespace( + **{k: v for k, v in vars(args).items() if v is not None} + ) + + from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY + + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args(input_args) + utils.import_user_module(usr_args) + + if modify_parser is not None: + modify_parser(parser) + + # The parser doesn't know about model/criterion/optimizer-specific args, so + # we parse twice. First we parse the model/criterion/optimizer, then we + # parse a second time after adding the *-specific arguments. + # If input_args is given, we will parse those args instead of sys.argv. + args, _ = parser.parse_known_args(input_args) + + # Add model-specific args to parser. + if hasattr(args, "arch"): + model_specific_group = parser.add_argument_group( + "Model-specific configuration", + # Only include attributes which are explicitly given as command-line + # arguments or which have default values. + argument_default=argparse.SUPPRESS, + ) + if args.arch in ARCH_MODEL_REGISTRY: + ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + elif args.arch in MODEL_REGISTRY: + MODEL_REGISTRY[args.arch].add_args(model_specific_group) + else: + raise RuntimeError() + + if hasattr(args, "task"): + from fairseq.tasks import TASK_REGISTRY + + TASK_REGISTRY[args.task].add_args(parser) + if getattr(args, "use_bmuf", False): + # hack to support extra args for block distributed data parallelism + from fairseq.optim.bmuf import FairseqBMUF + + FairseqBMUF.add_args(parser) + + # Add *-specific args to parser. + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + choice = getattr(args, registry_name, None) + if choice is not None: + cls = REGISTRY["registry"][choice] + if hasattr(cls, "add_args"): + cls.add_args(parser) + elif hasattr(cls, "__dataclass"): + gen_parser_from_dataclass(parser, cls.__dataclass()) + + # Modify the parser a second time, since defaults may have been reset + if modify_parser is not None: + modify_parser(parser) + + # Parse a second time. + if parse_known: + args, extra = parser.parse_known_args(input_args) + else: + args = parser.parse_args(input_args) + extra = None + # Post-process args. + if ( + hasattr(args, "batch_size_valid") and args.batch_size_valid is None + ) or not hasattr(args, "batch_size_valid"): + args.batch_size_valid = args.batch_size + if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None: + args.max_tokens_valid = args.max_tokens + if getattr(args, "memory_efficient_fp16", False): + args.fp16 = True + if getattr(args, "memory_efficient_bf16", False): + args.bf16 = True + args.tpu = getattr(args, "tpu", False) + args.bf16 = getattr(args, "bf16", False) + if args.bf16: + args.tpu = True + if args.tpu and args.fp16: + raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs") + + if getattr(args, "seed", None) is None: + args.seed = 1 # default seed for training + args.no_seed_provided = True + else: + args.no_seed_provided = False + + if getattr(args, "update_epoch_batch_itr", None) is None: + if hasattr(args, "grouped_shuffling"): + args.update_epoch_batch_itr = args.grouped_shuffling + else: + args.grouped_shuffling = False + args.update_epoch_batch_itr = False + + # Apply architecture configuration. + if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: + ARCH_CONFIG_REGISTRY[args.arch](args) + + if parse_known: + return args, extra + else: + return args + + +def get_parser(desc, default_task="translation"): + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args() + utils.import_user_module(usr_args) + + parser = argparse.ArgumentParser(allow_abbrev=False) + gen_parser_from_dataclass(parser, CommonConfig()) + + from fairseq.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + parser.add_argument( + "--" + registry_name.replace("_", "-"), + default=REGISTRY["default"], + choices=REGISTRY["registry"].keys(), + ) + + # Task definitions can be found under fairseq/tasks/ + from fairseq.tasks import TASK_REGISTRY + + parser.add_argument( + "--task", + metavar="TASK", + default=default_task, + choices=TASK_REGISTRY.keys(), + help="task", + ) + # fmt: on + return parser + + +def add_preprocess_args(parser): + group = parser.add_argument_group("Preprocessing") + # fmt: off + group.add_argument("-s", "--source-lang", default=None, metavar="SRC", + help="source language") + group.add_argument("-t", "--target-lang", default=None, metavar="TARGET", + help="target language") + group.add_argument("--trainpref", metavar="FP", default=None, + help="train file prefix (also used to build dictionaries)") + group.add_argument("--validpref", metavar="FP", default=None, + help="comma separated, valid file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--testpref", metavar="FP", default=None, + help="comma separated, test file prefixes " + "(words missing from train set are replaced with )") + group.add_argument("--align-suffix", metavar="FP", default=None, + help="alignment file suffix") + group.add_argument("--destdir", metavar="DIR", default="data-bin", + help="destination dir") + group.add_argument("--thresholdtgt", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--thresholdsrc", metavar="N", default=0, type=int, + help="map words appearing less than threshold times to unknown") + group.add_argument("--tgtdict", metavar="FP", + help="reuse given target dictionary") + group.add_argument("--srcdict", metavar="FP", + help="reuse given source dictionary") + group.add_argument("--nwordstgt", metavar="N", default=-1, type=int, + help="number of target words to retain") + group.add_argument("--nwordssrc", metavar="N", default=-1, type=int, + help="number of source words to retain") + group.add_argument("--alignfile", metavar="ALIGN", default=None, + help="an alignment file (optional)") + parser.add_argument('--dataset-impl', metavar='FORMAT', default='mmap', + choices=get_available_dataset_impl(), + help='output dataset implementation') + group.add_argument("--joined-dictionary", action="store_true", + help="Generate joined dictionary") + group.add_argument("--only-source", action="store_true", + help="Only process the source language") + group.add_argument("--padding-factor", metavar="N", default=8, type=int, + help="Pad dictionary size to be multiple of N") + group.add_argument("--workers", metavar="N", default=1, type=int, + help="number of parallel workers") + group.add_argument("--dict-only", action='store_true', + help="if true, only builds a dictionary and then exits") + # fmt: on + return parser + + +def add_dataset_args(parser, train=False, gen=False): + group = parser.add_argument_group("dataset_data_loading") + gen_parser_from_dataclass(group, DatasetConfig()) + # fmt: on + return group + + +def add_distributed_training_args(parser, default_world_size=None): + group = parser.add_argument_group("distributed_training") + if default_world_size is None: + default_world_size = max(1, torch.cuda.device_count()) + gen_parser_from_dataclass( + group, DistributedTrainingConfig(distributed_world_size=default_world_size) + ) + return group + + +def add_optimization_args(parser): + group = parser.add_argument_group("optimization") + # fmt: off + gen_parser_from_dataclass(group, OptimizationConfig()) + # fmt: on + return group + + +def add_checkpoint_args(parser): + group = parser.add_argument_group("checkpoint") + # fmt: off + gen_parser_from_dataclass(group, CheckpointConfig()) + # fmt: on + return group + + +def add_common_eval_args(group): + gen_parser_from_dataclass(group, CommonEvalConfig()) + + +def add_eval_lm_args(parser): + group = parser.add_argument_group("LM Evaluation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, EvalLMConfig()) + + +def add_generation_args(parser): + group = parser.add_argument_group("Generation") + add_common_eval_args(group) + gen_parser_from_dataclass(group, GenerationConfig()) + return group + + +def add_speech_generation_args(parser): + group = parser.add_argument_group("Speech Generation") + add_common_eval_args(group) # NOTE: remove_bpe is not needed + # fmt: off + group.add_argument('--eos_prob_threshold', default=0.5, type=float, + help='terminate when eos probability exceeds this') + # fmt: on + return group + + +def add_interactive_args(parser): + group = parser.add_argument_group("Interactive") + gen_parser_from_dataclass(group, InteractiveConfig()) + + +def add_model_args(parser): + group = parser.add_argument_group("Model configuration") + # fmt: off + + # Model definitions can be found under fairseq/models/ + # + # The model architecture can be specified in several ways. + # In increasing order of priority: + # 1) model defaults (lowest priority) + # 2) --arch argument + # 3) --encoder/decoder-* arguments (highest priority) + from fairseq.models import ARCH_MODEL_REGISTRY + group.add_argument('--arch', '-a', metavar='ARCH', + choices=ARCH_MODEL_REGISTRY.keys(), + help='model architecture') + # fmt: on + return group + + +def get_args( + data: Union[str, Path], + task: str = "translation", + arch: str = "transformer", + **overrides +): + parser = get_training_parser(task) + args = parse_args_and_arch(parser, [str(data), "--task", task, "--arch", arch]) + + for k, v in overrides.items(): + setattr(args, k, v) + + return args + + +def add_ema_args(parser): + group = parser.add_argument_group("EMA configuration") + gen_parser_from_dataclass(group, EMAConfig()) diff --git a/fairseq/pdb.py b/fairseq/pdb.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba6ef0d336b30717cfdde94e1b838cfe2bfeb20 --- /dev/null +++ b/fairseq/pdb.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import multiprocessing +import os +import pdb +import sys + + +__all__ = ["set_trace"] + + +_stdin = [None] +_stdin_lock = multiprocessing.Lock() +try: + _stdin_fd = sys.stdin.fileno() +except Exception: + _stdin_fd = None + + +class MultiprocessingPdb(pdb.Pdb): + """A Pdb wrapper that works in a multiprocessing environment. + + Usage: `from fairseq import pdb; pdb.set_trace()` + """ + + def __init__(self): + pdb.Pdb.__init__(self, nosigint=True) + + def _cmdloop(self): + stdin_bak = sys.stdin + with _stdin_lock: + try: + if _stdin_fd is not None: + if not _stdin[0]: + _stdin[0] = os.fdopen(_stdin_fd) + sys.stdin = _stdin[0] + self.cmdloop() + finally: + sys.stdin = stdin_bak + + +def set_trace(): + pdb = MultiprocessingPdb() + pdb.set_trace(sys._getframe().f_back) diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11fc414c852b199b80a569bf024272535929abcc --- /dev/null +++ b/fairseq/quantization_utils.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 + if quant_noise_scalar > 0: + # quantize_model edits the model in place + scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) + return model + + +class Quantizer(object): + def __init__(self, config_path, max_epoch, max_update): + try: + import yaml + except ImportError: + raise ImportError("Please install yaml with: pip install yaml") + + # parse config + if config_path: + with open(config_path) as config_file: + config = quantization_options.parse_config_yaml( + yaml.safe_load(config_file) + ) + else: + config = quantization_options.parse_config_yaml({}) + + self.n_centroids_config = config["n_centroids"] + self.block_sizes_config = config["block_sizes"] + self.layers_to_quantize = config["layers_to_quantize"] + + # We assume that training will run for a fixed number of epochs + # (or updates) and that we should train for equal durations + # between iterations of PQ. + num_iterations = len(self.layers_to_quantize) + if max_epoch > 0: + assert max_epoch % num_iterations == 0, ( + "for iterative PQ, --max-epoch (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_epoch, num_iterations) + ) + self.epoch_schedule = max_epoch // num_iterations + else: + self.epoch_schedule = None + if max_update > 0: + assert max_update % num_iterations == 0, ( + "for iterative PQ, --max-update (={}) must be evenly divisible by " + "len(layers_to_quantize) (={})".format(max_update, num_iterations) + ) + self.update_schedule = max_update // num_iterations + else: + self.update_schedule = None + assert (self.epoch_schedule is not None) ^ ( + self.update_schedule is not None + ), "for iterative PQ, cannot specify both --max-update and --max-epoch" + + # 0 is a special value for quantization step, which will force + # the first call to begin_epoch() to call step() + self.quantization_step = 0 + + def set_trainer(self, trainer): + self.trainer = trainer + self.size_tracker = pq.SizeTracker(self.trainer.get_model()) + + def step(self): + """Move to the next stage of quantization.""" + if self.quantization_step >= len(self.layers_to_quantize): + # Maybe we just finished the last training step or we loaded + # a checkpoint for an iterative PQ model which previously + # finished training. Either way, don't quantize again. + return + + logger.info( + "quantizing model (step={}; layers_to_quantize[step]={})".format( + self.quantization_step, self.layers_to_quantize[self.quantization_step] + ) + ) + quantized_layers = pq.quantize_model_( + self.trainer.get_model(), + self.size_tracker, + self.layers_to_quantize, + self.block_sizes_config, + self.n_centroids_config, + step=self.quantization_step, + ) + logger.info("quantized layers: {}".format(quantized_layers)) + logger.info(self.size_tracker) + + self.quantization_step += 1 + + # reintialize the Trainer since model parameters have changed + self.trainer.reinitialize() + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch (epochs start at 1).""" + if ( + ( + self.epoch_schedule is not None + and epoch > 0 + and (epoch - 1) % self.epoch_schedule == 0 + ) + # we always step once in the beginning, even if using + # update-based quantization + or self.quantization_step == 0 + ): + self.step() + + def step_update(self, num_updates): + """Called at the end of each step.""" + if ( + self.update_schedule is not None + and num_updates > 0 + and num_updates % self.update_schedule == 0 + ): + self.step() + + def state_dict(self): + return { + "n_centroids_config": self.n_centroids_config, + "block_sizes_config": self.block_sizes_config, + "layers_to_quantize": self.layers_to_quantize, + "epoch_schedule": self.epoch_schedule, + "update_schedule": self.update_schedule, + "quantization_step": self.quantization_step, + } + + def load_state_dict(self, state_dict): + self.n_centroids_config = state_dict["n_centroids_config"] + self.block_sizes_config = state_dict["block_sizes_config"] + self.layers_to_quantize = state_dict["layers_to_quantize"] + self.epoch_schedule = state_dict["epoch_schedule"] + self.update_schedule = state_dict["update_schedule"] + self.quantization_step = state_dict["quantization_step"] diff --git a/fairseq/registry.py b/fairseq/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..904ffcd60253c069f466a0b7ba0aaa2136c78c82 --- /dev/null +++ b/fairseq/registry.py @@ -0,0 +1,104 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from argparse import Namespace + +from typing import Union +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore +from omegaconf import DictConfig + +REGISTRIES = {} + + +def setup_registry(registry_name: str, base_class=None, default=None, required=False): + assert registry_name.startswith("--") + registry_name = registry_name[2:].replace("-", "_") + + REGISTRY = {} + REGISTRY_CLASS_NAMES = set() + DATACLASS_REGISTRY = {} + + # maintain a registry of all registries + if registry_name in REGISTRIES: + return # registry already exists + REGISTRIES[registry_name] = { + "registry": REGISTRY, + "default": default, + "dataclass_registry": DATACLASS_REGISTRY, + } + + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + + if choice and choice in DATACLASS_REGISTRY: + from_checkpoint = extra_kwargs.get("from_checkpoint", False) + dc = DATACLASS_REGISTRY[choice] + cfg = merge_with_parent(dc(), cfg, remove_missing=from_checkpoint) + elif isinstance(cfg, str): + choice = cfg + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice]() + else: + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = DATACLASS_REGISTRY[choice].from_namespace(cfg) + + if choice is None: + if required: + raise ValueError("{} is required!".format(registry_name)) + return None + + cls = REGISTRY[choice] + if hasattr(cls, "build_" + registry_name): + builder = getattr(cls, "build_" + registry_name) + else: + builder = cls + + if "from_checkpoint" in extra_kwargs: + del extra_kwargs["from_checkpoint"] + + return builder(cfg, *extra_args, **extra_kwargs) + + def register_x(name, dataclass=None): + def register_x_cls(cls): + if name in REGISTRY: + raise ValueError( + "Cannot register duplicate {} ({})".format(registry_name, name) + ) + if cls.__name__ in REGISTRY_CLASS_NAMES: + raise ValueError( + "Cannot register {} with duplicate class name ({})".format( + registry_name, cls.__name__ + ) + ) + if base_class is not None and not issubclass(cls, base_class): + raise ValueError( + "{} must extend {}".format(cls.__name__, base_class.__name__) + ) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group=registry_name, node=node, provider="fairseq") + + REGISTRY[name] = cls + + return cls + + return register_x_cls + + return build_x, register_x, REGISTRY, DATACLASS_REGISTRY diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58f2f563e493327394dff1265030d18f0814b5a2 --- /dev/null +++ b/fairseq/scoring/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import importlib +import os +from abc import ABC, abstractmethod + +from fairseq import registry +from omegaconf import DictConfig + + +class BaseScorer(ABC): + def __init__(self, cfg): + self.cfg = cfg + self.ref = [] + self.pred = [] + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + @abstractmethod + def score(self) -> float: + pass + + @abstractmethod + def result_string(self) -> str: + pass + + +_build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( + "--scoring", default="bleu" +) + + +def build_scorer(choice, tgt_dict): + _choice = choice._name if isinstance(choice, DictConfig) else choice + + if _choice == "bleu": + from fairseq.scoring import bleu + + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) + + +# automatically import any Python files in the current directory +for file in sorted(os.listdir(os.path.dirname(__file__))): + if file.endswith(".py") and not file.startswith("_"): + module = file[: file.find(".py")] + importlib.import_module("fairseq.scoring." + module) diff --git a/fairseq/scoring/__pycache__/__init__.cpython-310.pyc b/fairseq/scoring/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19f35cbccea6e0f6d4167a00c74d4209440e0b14 Binary files /dev/null and b/fairseq/scoring/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc b/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb626e8f0cf14537f0a2c68cbfb154be5d364e0 Binary files /dev/null and b/fairseq/scoring/__pycache__/bertscore.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/bleu.cpython-310.pyc b/fairseq/scoring/__pycache__/bleu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1522512db11f334c8fe590b0740458a0d7029fcd Binary files /dev/null and b/fairseq/scoring/__pycache__/bleu.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/chrf.cpython-310.pyc b/fairseq/scoring/__pycache__/chrf.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ac35929bacd4d98bf3195d3afef3c9709bb0df5 Binary files /dev/null and b/fairseq/scoring/__pycache__/chrf.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/meteor.cpython-310.pyc b/fairseq/scoring/__pycache__/meteor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45481d9490e0fb56947c037dd0c74d37a96637b Binary files /dev/null and b/fairseq/scoring/__pycache__/meteor.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc b/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ef6bebb57fa8174049c68780f374a32f4af5f61 Binary files /dev/null and b/fairseq/scoring/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/fairseq/scoring/__pycache__/wer.cpython-310.pyc b/fairseq/scoring/__pycache__/wer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad6572b878282327a0ea6fa82f3b88134a6fea36 Binary files /dev/null and b/fairseq/scoring/__pycache__/wer.cpython-310.pyc differ diff --git a/fairseq/scoring/bertscore.py b/fairseq/scoring/bertscore.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5a8450d34a7f76ba08748f61f5d84fcf5570ee --- /dev/null +++ b/fairseq/scoring/bertscore.py @@ -0,0 +1,44 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +import numpy as np + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class BertScoreScorerConfig(FairseqDataclass): + bert_score_lang: str = field(default="en", metadata={"help": "BERTScore language"}) + + +@register_scorer("bert_score", dataclass=BertScoreScorerConfig) +class BertScoreScorer(BaseScorer): + def __init__(self, cfg): + super(BertScoreScorer, self).__init__(cfg) + try: + import bert_score as _bert_score + except ImportError: + raise ImportError("Please install BERTScore: pip install bert-score") + + self.cfg = cfg + self._bert_score = _bert_score + self.scores = None + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + _, _, self.scores = self._bert_score.score( + self.pred, self.ref, lang=self.cfg.bert_score_lang + ) + self.scores = self.scores.numpy() + return np.mean(self.scores) + + def result_string(self, order=4): + return f"BERTScore: {self.score():.4f}" diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py new file mode 100644 index 0000000000000000000000000000000000000000..e55bd2f393e279d318e08d680c41cc20e8d49931 --- /dev/null +++ b/fairseq/scoring/bleu.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import ctypes +import math +import sys +from dataclasses import dataclass, field + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +class BleuStat(ctypes.Structure): + _fields_ = [ + ("reflen", ctypes.c_size_t), + ("predlen", ctypes.c_size_t), + ("match1", ctypes.c_size_t), + ("count1", ctypes.c_size_t), + ("match2", ctypes.c_size_t), + ("count2", ctypes.c_size_t), + ("match3", ctypes.c_size_t), + ("count3", ctypes.c_size_t), + ("match4", ctypes.c_size_t), + ("count4", ctypes.c_size_t), + ] + + +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) +class SacrebleuScorer(BaseScorer): + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) + import sacrebleu + + self.sacrebleu = sacrebleu + self.tokenizer = EvaluationTokenizer( + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, + ) + + def add_string(self, ref, pred): + self.ref.append(self.tokenizer.tokenize(ref)) + self.pred.append(self.tokenizer.tokenize(pred)) + + def _score(self, order=4): + if order != 4: + raise NotImplementedError + # tokenization and lowercasing are performed by self.tokenizer instead. + return self.sacrebleu.corpus_bleu(self.pred, [self.ref], tokenize="none") + + def score(self, order=4): + return self._score(order).score + + def result_string(self, order=4): + return self._score(order).format() + + +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) +class Scorer(object): + def __init__(self, cfg): + self.stat = BleuStat() + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk + + try: + from fairseq import libbleu + except ImportError as e: + sys.stderr.write( + "ERROR: missing libbleu.so. run `pip install --editable .`\n" + ) + raise e + + self.C = ctypes.cdll.LoadLibrary(libbleu.__file__) + + self.reset() + + def reset(self, one_init=False): + if one_init: + self.C.bleu_one_init(ctypes.byref(self.stat)) + else: + self.C.bleu_zero_init(ctypes.byref(self.stat)) + + def add(self, ref, pred): + if not isinstance(ref, torch.IntTensor): + raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref))) + if not isinstance(pred, torch.IntTensor): + raise TypeError("pred must be a torch.IntTensor(got {})".format(type(pred))) + + # don't match unknown words + rref = ref.clone() + assert not rref.lt(0).any() + rref[rref.eq(self.unk)] = -999 + + rref = rref.contiguous().view(-1) + pred = pred.contiguous().view(-1) + + self.C.bleu_add( + ctypes.byref(self.stat), + ctypes.c_size_t(rref.size(0)), + ctypes.c_void_p(rref.data_ptr()), + ctypes.c_size_t(pred.size(0)), + ctypes.c_void_p(pred.data_ptr()), + ctypes.c_int(self.pad), + ctypes.c_int(self.eos), + ) + + def score(self, order=4): + psum = sum( + math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order] + ) + return self.brevity() * math.exp(psum / order) * 100 + + def precision(self): + def ratio(a, b): + return a / b if b > 0 else 0 + + return [ + ratio(self.stat.match1, self.stat.count1), + ratio(self.stat.match2, self.stat.count2), + ratio(self.stat.match3, self.stat.count3), + ratio(self.stat.match4, self.stat.count4), + ] + + def brevity(self): + r = self.stat.reflen / self.stat.predlen + return min(1, math.exp(1 - r)) + + def result_string(self, order=4): + assert order <= 4, "BLEU scores for order > 4 aren't supported" + fmt = "BLEU{} = {:2.2f}, {:2.1f}" + for _ in range(1, order): + fmt += "/{:2.1f}" + fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})" + bleup = [p * 100 for p in self.precision()[:order]] + return fmt.format( + order, + self.score(order=order), + *bleup, + self.brevity(), + self.stat.predlen / self.stat.reflen, + self.stat.predlen, + self.stat.reflen + ) diff --git a/fairseq/scoring/chrf.py b/fairseq/scoring/chrf.py new file mode 100644 index 0000000000000000000000000000000000000000..5df5a1c011243fe2e836c38a5f8459aeb824f0e7 --- /dev/null +++ b/fairseq/scoring/chrf.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class ChrFScorerConfig(FairseqDataclass): + pass + + +@register_scorer("chrf", dataclass=ChrFScorerConfig) +class ChrFScorer(BaseScorer): + def __init__(self, args): + super(ChrFScorer, self).__init__(args) + import sacrebleu + + self.sacrebleu = sacrebleu + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + return self.result_string(order).score + + def result_string(self, order=4): + if order != 4: + raise NotImplementedError + return self.sacrebleu.corpus_chrf(self.pred, [self.ref]).format() diff --git a/fairseq/scoring/meteor.py b/fairseq/scoring/meteor.py new file mode 100644 index 0000000000000000000000000000000000000000..32719956fe32cc25ab071aa7ff6abf84c6533fe8 --- /dev/null +++ b/fairseq/scoring/meteor.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer + + +@dataclass +class MeteorScorerConfig(FairseqDataclass): + pass + + +@register_scorer("meteor", dataclass=MeteorScorerConfig) +class MeteorScorer(BaseScorer): + def __init__(self, args): + super(MeteorScorer, self).__init__(args) + try: + import nltk + except ImportError: + raise ImportError("Please install nltk to use METEOR scorer") + + self.nltk = nltk + self.scores = [] + + def add_string(self, ref, pred): + self.ref.append(ref) + self.pred.append(pred) + + def score(self, order=4): + self.scores = [ + self.nltk.translate.meteor_score.single_meteor_score(r, p) + for r, p in zip(self.ref, self.pred) + ] + return np.mean(self.scores) + + def result_string(self, order=4): + return f"METEOR: {self.score():.4f}" diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0cedd5099b9ddc7278205e1a4f2aa18359a14bf --- /dev/null +++ b/fairseq/scoring/tokenizer.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unicodedata + +import sacrebleu as sb + +from fairseq.dataclass import ChoiceEnum + +SACREBLEU_V2_ABOVE = int(sb.__version__[0]) >= 2 + + +class EvaluationTokenizer(object): + """A generic evaluation-time tokenizer, which leverages built-in tokenizers + in sacreBLEU (https://github.com/mjpost/sacrebleu). It additionally provides + lowercasing, punctuation removal and character tokenization, which are + applied after sacreBLEU tokenization. + + Args: + tokenizer_type (str): the type of sacreBLEU tokenizer to apply. + lowercase (bool): lowercase the text. + punctuation_removal (bool): remove punctuation (based on unicode + category) from text. + character_tokenization (bool): tokenize the text to characters. + """ + + SPACE = chr(32) + SPACE_ESCAPE = chr(9601) + _ALL_TOKENIZER_TYPES = ( + sb.BLEU.TOKENIZERS + if SACREBLEU_V2_ABOVE + else ["none", "13a", "intl", "zh", "ja-mecab"] + ) + ALL_TOKENIZER_TYPES = ChoiceEnum(_ALL_TOKENIZER_TYPES) + + def __init__( + self, + tokenizer_type: str = "13a", + lowercase: bool = False, + punctuation_removal: bool = False, + character_tokenization: bool = False, + ): + + assert ( + tokenizer_type in self._ALL_TOKENIZER_TYPES + ), f"{tokenizer_type}, {self._ALL_TOKENIZER_TYPES}" + self.lowercase = lowercase + self.punctuation_removal = punctuation_removal + self.character_tokenization = character_tokenization + if SACREBLEU_V2_ABOVE: + self.tokenizer = sb.BLEU(tokenize=str(tokenizer_type)).tokenizer + else: + self.tokenizer = sb.tokenizers.TOKENIZERS[tokenizer_type]() + + @classmethod + def remove_punctuation(cls, sent: str): + """Remove punctuation based on Unicode category.""" + return cls.SPACE.join( + t + for t in sent.split(cls.SPACE) + if not all(unicodedata.category(c)[0] == "P" for c in t) + ) + + def tokenize(self, sent: str): + tokenized = self.tokenizer(sent) + + if self.punctuation_removal: + tokenized = self.remove_punctuation(tokenized) + + if self.character_tokenization: + tokenized = self.SPACE.join( + list(tokenized.replace(self.SPACE, self.SPACE_ESCAPE)) + ) + + if self.lowercase: + tokenized = tokenized.lower() + + return tokenized diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py new file mode 100644 index 0000000000000000000000000000000000000000..633dc47c247691c4c9e36cbdbab7d7cb74b38452 --- /dev/null +++ b/fairseq/scoring/wer.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass +from fairseq.scoring import BaseScorer, register_scorer +from fairseq.scoring.tokenizer import EvaluationTokenizer + + +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) +class WerScorer(BaseScorer): + def __init__(self, cfg): + super().__init__(cfg) + self.reset() + try: + import editdistance as ed + except ImportError: + raise ImportError("Please install editdistance to use WER scorer") + self.ed = ed + self.tokenizer = EvaluationTokenizer( + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, + ) + + def reset(self): + self.distance = 0 + self.ref_length = 0 + + def add_string(self, ref, pred): + ref_items = self.tokenizer.tokenize(ref).split() + pred_items = self.tokenizer.tokenize(pred).split() + self.distance += self.ed.eval(ref_items, pred_items) + self.ref_length += len(ref_items) + + def result_string(self): + return f"WER: {self.score():.2f}" + + def score(self): + return 100.0 * self.distance / self.ref_length if self.ref_length > 0 else 0 diff --git a/fairseq/search.py b/fairseq/search.py new file mode 100644 index 0000000000000000000000000000000000000000..c7378bbb514342cd3f9f56c8514d0fa5cb351316 --- /dev/null +++ b/fairseq/search.py @@ -0,0 +1,892 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from typing import List, Optional + +import torch +import torch.nn as nn +from fairseq.token_generation_constraints import ( + ConstraintState, + OrderedConstraintState, + UnorderedConstraintState, +) +from torch import Tensor + + +class Search(nn.Module): + def __init__(self, tgt_dict): + super().__init__() + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.src_lengths = torch.tensor(-1) + self.supports_constraints = False + self.stop_on_max_len = False + + def step( + self, step, lprobs, scores, prev_output_tokens=None, original_batch_idxs=None + ): + """Take a single search step. + + Args: + step: the current search step, starting at 0 + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + scores: (bsz x input_beam_size x step) + the historical model scores of each hypothesis up to this point + prev_output_tokens: (bsz x step) + the previously generated oputput tokens + original_batch_idxs: (bsz) + the tensor with the batch indices, in the range [0, bsz) + this is useful in case there has been applied a re-ordering + and we need to know the orignal indices + + Return: A tuple of (scores, indices, beams) where: + scores: (bsz x output_beam_size) + the scores of the chosen elements; output_beam_size can be + larger than input_beam_size, e.g., we may return + 2*input_beam_size to account for EOS + indices: (bsz x output_beam_size) + the indices of the chosen elements + beams: (bsz x output_beam_size) + the hypothesis ids of the chosen elements, in the range [0, input_beam_size) + """ + raise NotImplementedError + + @torch.jit.export + def set_src_lengths(self, src_lengths): + self.src_lengths = src_lengths + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + """Initialize constraint states for constrained decoding (if supported). + + Args: + batch_constraints: (torch.Tensor, optional) + the list of constraints, in packed form + beam_size: (int) + the beam size + Returns: + *encoder_out* rearranged according to *new_order* + """ + pass + + def prune_sentences(self, batch_idxs: Tensor): + """ + Removes constraint states for completed sentences (if supported). + This is called from sequence_generator._generate() when sentences are + deleted from the batch. + + Args: + batch_idxs: Indices of *sentences* whose constraint state should be *kept*. + """ + pass + + def update_constraints(self, active_hypos: Tensor): + """ + Updates the constraint states by selecting the beam items that are retained. + This is called at each time step of sequence_generator._generate() when + the set of 2 * {beam_size} candidate hypotheses are reduced to the beam size. + + Args: + active_hypos: (batch size, beam size) + list of integers denoting, for each sentence, which beam candidate items + should be kept. + """ + pass + + +class BeamSearch(Search): + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.constraint_states = None + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + candidate_multiple: int = 2, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best `candidate_muliple`(default 2) x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + candidate_multiple * beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + # Project back into relative indices and beams + beams_buf = torch.div(indices_buf, vocab_size, rounding_mode="trunc") + indices_buf = indices_buf.fmod(vocab_size) + + # At this point, beams_buf and indices_buf are single-dim and contain relative indices + return scores_buf, indices_buf, beams_buf + + +class PrefixConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, prefix_allowed_tokens_fn): + super().__init__(tgt_dict) + self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self.stop_on_max_len = True + + @torch.jit.export + def apply_mask(self, x, prev_output_tokens, original_batch_idxs): + beam_size = x.shape[0] // original_batch_idxs.shape[0] + original_batch_idxs = ( + original_batch_idxs.unsqueeze(-1).repeat((1, beam_size)).flatten().tolist() + ) + + mask = torch.full_like(x, -math.inf) + for sent_i, (sent, batch_i) in enumerate( + zip(prev_output_tokens, original_batch_idxs) + ): + mask[sent_i, :, self.prefix_allowed_tokens_fn(batch_i, sent)] = 0 + + return mask + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Tensor, + prev_output_tokens: Tensor, + original_batch_idxs: Tensor, + ): + bsz, beam_size, vocab_size = lprobs.size() + + lprobs += self.apply_mask( + lprobs.view(bsz * beam_size, 1, vocab_size), + prev_output_tokens, + original_batch_idxs, + ).view(bsz, beam_size, vocab_size) + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(bsz, -1), + k=min( + # Take the best beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + ) + scores_buf = top_prediction[0] + indices_buf = top_prediction[1] + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + return scores_buf, indices_buf, beams_buf + + +class LexicallyConstrainedBeamSearch(Search): + """Implements lexically constrained beam search as described in + + Fast Lexically Constrained Decoding with Dynamic Beam + Allocation for Neural Machine Translation. Post & Vilar, + NAACL 2018. https://www.aclweb.org/anthology/N18-1119/ + + and + + Improved Lexically Constrained Decoding for Translation and + Monolingual Rewriting. Hu et al, NAACL + 2019. https://www.aclweb.org/anthology/N19-1090/ + + This is accomplished by maintaining, for each beam hypothesis, a + ConstraintState object (see constraints.py) that tracks which + constraints have been generated and using this information to + shape the beam for each input sentence. + """ + + def __init__(self, tgt_dict, representation): + super().__init__(tgt_dict) + self.representation = representation + self.vocab_size = len(tgt_dict) + self.num_cands = 0 + self.supports_constraints = True + + @torch.jit.export + def init_constraints(self, batch_constraints: Optional[Tensor], beam_size: int): + self.constraint_states = [] + for constraint_tensor in batch_constraints: + if self.representation == "ordered": + constraint_state = OrderedConstraintState.create(constraint_tensor) + elif self.representation == "unordered": + constraint_state = UnorderedConstraintState.create(constraint_tensor) + + self.constraint_states.append([constraint_state for i in range(beam_size)]) + + @torch.jit.export + def prune_sentences(self, batch_idxs: Tensor): + self.constraint_states = [ + self.constraint_states[i] for i in batch_idxs.tolist() + ] + + @torch.jit.export + def update_constraints(self, active_hypos: Tensor): + if self.constraint_states: + batch_size = active_hypos.size(0) + for sentid in range(batch_size): + self.constraint_states[sentid] = [ + self.constraint_states[sentid][i] for i in active_hypos[sentid] + ] + + @torch.jit.export + def step( + self, + step: int, + lprobs: Tensor, + scores: Optional[Tensor], + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + """ + A constrained step builds a large candidates list from the following: + - the top 2 * {beam_size} items over the whole beam + - for each item in the beam + - the top {each_k} (default 1) + - all next constraints + We then compute the constrained state of each beam item, and assign + stripe codes: 0 to the best in each bank, 1 to the 2nd-best, and so + on. We then sort by (stripe, score), and truncate the list at + 2 * beam size. + + Args: + step: the decoder step + lprobs: (batch size, beam size, target vocab) + the target-vocab distributions for each item in the beam. + Retrun: A tuple of (scores, indices, beams, constraints) where: + scores: (batch, output beam size) + the scores of the chosen elements + indices: (batch, output beam size) + the target vocab indices of the chosen elements + beams: (batch, output beam size) + the 0-indexed hypothesis ids of the chosen elements + constraints: (batch, output beam size) + the new constraint states + """ + each_k = 1 + device = lprobs.device + + batch_size, beam_size, vocab_size = lprobs.size() + + self.num_cands = min( + # Just take the k-best. We'll get another k from the 1-best from each + # row, plus more from the constraints + beam_size * 2, + lprobs.view(batch_size, -1).size(1) - 1, # -1 so we never select pad + ) + + # STEP 0: Preliminary. Prevent EOS for unfinished hyps across all batch items + constraint_states = self.constraint_states + if constraint_states and step > 0: + not_finished_indices = [] + for sentno, sent_constraints in enumerate(constraint_states): + for beamno, state in enumerate(sent_constraints): + index = sentno * beam_size + beamno + if not state.finished: + not_finished_indices.append(index) + not_finished_indices = torch.tensor(not_finished_indices) + if not_finished_indices.numel() > 0: + lprobs.view(batch_size * beam_size, -1)[ + not_finished_indices, self.eos + ] = -math.inf + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam entry for each batch item + lprobs = lprobs[:, ::beam_size, :].contiguous() + else: + # make probs contain cumulative scores for each hypothesis + assert scores is not None + lprobs = lprobs + scores[:, :, step - 1].unsqueeze(-1) + + top_prediction = torch.topk( + lprobs.view(batch_size, -1), + self.num_cands, + ) + scores_buf, indices_buf = top_prediction + # Project back into relative indices and beams + beams_buf = indices_buf // vocab_size + indices_buf = indices_buf.fmod(vocab_size) + + # Short circuit if there are no constraints in this batch + if not constraint_states: + return scores_buf, indices_buf, beams_buf + + # STEP 1: get top-1 from each hypothesis across all sentences in the batch + if step > 0: + top_scores, top_indices = torch.topk( + lprobs.view(batch_size * beam_size, -1), + k=each_k, + dim=1, + ) + top_scores = top_scores.view(batch_size, -1) + top_indices = top_indices.view(batch_size, -1) + scores_buf = torch.cat((scores_buf, top_scores), dim=1) + indices_buf = torch.cat((indices_buf, top_indices), dim=1) + new_beams = torch.arange(0, beam_size, device=device).repeat(batch_size, 1) + beams_buf = torch.cat((beams_buf, new_beams), dim=1) + + # Now, process sentences in the batch one by one. + new_scores_buf = torch.zeros((batch_size, 2 * beam_size), device=device) + new_indices_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + new_beams_buf = torch.zeros((batch_size, 2 * beam_size), device=device).long() + for sentno, states in enumerate(constraint_states): + scores, indices, beams, new_states = self.step_sentence( + step, + sentno, + lprobs[sentno], + constraint_states[sentno], + beams_buf[sentno].clone(), + indices_buf[sentno].clone(), + scores_buf[sentno].clone(), + ) + new_scores_buf[sentno] = scores + new_indices_buf[sentno] = indices + new_beams_buf[sentno] = beams + self.constraint_states[sentno] = new_states + + return new_scores_buf, new_indices_buf, new_beams_buf + + @torch.jit.export + def step_sentence( + self, + step: int, + sentno: int, + lprobs: Tensor, + constraint_states: List[List[ConstraintState]], + beams_buf: Tensor, + indices_buf: Tensor, + scores_buf: Tensor, + ): + """Does per-sentence processing. Adds all constraints for each + hypothesis to the list of candidates; then removes duplicates, + sorts, and dynamically stripes across the banks. All tensor inputs + are collapsed to those pertaining to a single input sentence. + """ + device = lprobs.device + + # STEP 2: Add all constraints for each beam item + for beamno, state in enumerate(constraint_states): + next_tokens = torch.tensor(list(state.next_tokens()), device=device).long() + if next_tokens.numel() != 0: + indices_buf = torch.cat((indices_buf, next_tokens)) + next_beams = ( + torch.tensor(beamno, device=device) + .repeat(next_tokens.size(0)) + .long() + ) + beams_buf = torch.cat((beams_buf, next_beams)) + next_values = lprobs[beamno].take(next_tokens.view(-1)) + scores_buf = torch.cat((scores_buf, next_values)) + + # At the 0th time step, there is just one beam item + if step == 0: + break + + # STEP 3: Compute the "bank" for each candidate. This is the + # number of constraints it's generated. We need this so that + # we can do round-robin allocation of the beam across these + # banks. If C is the number of constraints, we select the best + # item in bank C, then the best in bank C-1, etc, followed by + # the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so + # on, until the maximum beam size. We accomplish this by + # creating a sort key and striping across the banks. + + # Compute the new states for all candidates + cands_size = indices_buf.size(0) + constraint_states = [ + constraint_states[beams_buf[i]].advance(indices_buf[i]) + for i in range(cands_size) + ] + + banks = torch.tensor([state.bank for state in constraint_states], device=device) + + # STEP 4: Sort + num_constraint_tokens = len(state.tokens) + + # Sort by keys (bank, score) (i.e., sort banks together, and scores + # within banks). AFAIK pytorch doesn't support either stable sort or + # multi-key sorting, so we have to hack this. + MAX_SCORE = -100 + sort_key = (num_constraint_tokens - banks) * MAX_SCORE + scores_buf + sort_values, sort_indices = sort_key.sort(dim=0, descending=True) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + banks = banks[sort_indices] + + # Sort the constraints to follow suit + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 5: Remove duplicates. The topk calls (overall and + # per-row) plus the per-row generation of constraints will + # produce duplicates. Here we remove them. + + def roll(t): + """Rolls a 1d tensor left by 1. + + [0, 1, 2, 3, 4] becomes [4, 0, 1, 2, 3] + """ + return torch.cat((t[-1].unsqueeze(0), t[0:-1]), dim=0) + + # We map candidates (beam, token_id) to a single dimension. + # This is then shifted by 1. We can then easily identify + # duplicates and create a mask that identifies unique + # extensions. + uniques_mask = beams_buf * (self.vocab_size + 1) + indices_buf + uniques_mask = roll(uniques_mask) != uniques_mask + + # Use the mask to pare down the data structures + scores_buf = torch.masked_select(scores_buf, uniques_mask) + indices_buf = torch.masked_select(indices_buf, uniques_mask) + beams_buf = torch.masked_select(beams_buf, uniques_mask) + banks = torch.masked_select(banks, uniques_mask) + i = 1 + for mask in uniques_mask[1:]: + if not mask: + constraint_states.pop(i) + i += mask + + # STEP 6: Assign IDs round-robin across banks, sort, and + # truncate. Now that the candidates are sorted by (bank, + # score) and uniqed, we dynamically allocate the {beam_size} + # beam by striping across the candidates. These stripes will + # be used as sort keys to do round-robin selection. This is + # accomplished in a single pass with offsets. Sorting by + # highest-banks (furthest-along hypotheses) first ensures + # progress through the constraints. + # + # e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0 + # OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1 + # NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7 + # = 0 5 10 1 6 11 13 2 7 12 3 8 + # + # Sorting by this then gives the following banks: + # + # 3 2 1 0 3 2 1 0 3 2 1 2 + # + # We'll take the top {beam_size} of these. + stripe_offsets = [offset * (len(banks) + 1) for offset in range(len(banks) + 1)] + stripes = torch.zeros_like(banks) + cur_bank_count = -1 + cur_bank = banks[0] + for i, bank in enumerate(banks): + if bank != cur_bank: + cur_bank_count = 0 + cur_bank = bank + else: + cur_bank_count += 1 + stripes[i] = num_constraint_tokens - bank + stripe_offsets[cur_bank_count] + + # STEP 7: Sort by the stripes values + sort_values, sort_indices = stripes.sort(dim=0) + scores_buf = scores_buf[sort_indices] + indices_buf = indices_buf[sort_indices] + beams_buf = beams_buf[sort_indices] + constraint_states = [constraint_states[i] for i in sort_indices] + + # STEP 8: Truncate to the candidates size! + scores_buf = scores_buf[: self.num_cands] + indices_buf = indices_buf[: self.num_cands] + beams_buf = beams_buf[: self.num_cands] + + return scores_buf, indices_buf, beams_buf, constraint_states + + +class LengthConstrainedBeamSearch(Search): + def __init__(self, tgt_dict, min_len_a, min_len_b, max_len_a, max_len_b): + super().__init__(tgt_dict) + self.min_len_a = min_len_a + self.min_len_b = min_len_b + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.beam = BeamSearch(tgt_dict) + self.needs_src_lengths = True + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + min_lens = self.min_len_a * self.src_lengths + self.min_len_b + max_lens = self.max_len_a * self.src_lengths + self.max_len_b + lprobs[step < min_lens, :, self.eos] = -math.inf + lprobs[step >= max_lens, :, self.eos] = 0 + return self.beam.step(step, lprobs, scores) + + +class DiverseBeamSearch(Search): + """Diverse Beam Search. + + See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models" for details. + + We implement cumulative diversity penalty here as default, optionally provide Hamming diversity described + in the original paper, and a way to interpolate between the two through diversity_discount. + + Take the example below for illustration of cumulative diversity implemented. + A) I like dogs. + B) I like ____. + C) There are ___. + And we are at step=2, trying to fill in the blank: + + Hamming diversity: + Penalty for B from A is 1 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + + Cumulative diversity (default): + Penalty for B from A is 3 for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + B and C differ because B matches with A for "I" and "like" at respective steps incurring 2 cumulative penalty. + + Using divesrity_discount to interpolate between the two: + if diverstiy_discount = 0.5, then + Penalty for B from A is 1.75 (1 + 0.5 + 0.25) for "dogs" and 0 for any other words like "cats". + Penalty for C from A is 1 for "dogs" and 0 for any other words like "cats". + "I" and "like" matched for B and A at step 0 and 1 respectively. Since "I" is two steps away and "like" is one step away, they are discounted by (0.5)^2 and 0.5 respectively. + When diversity_discount = 0, we recover Hammning diversity and when diversity_discount = 1, we recover cumulative diversity. + + NB: During beam search for each diversity group, `candidate_mutiple` is set to 1 rather than BeamSearch default(2). + This is to ensure we have final `beam_size` candidates so that no diversity groups would be dropped during final token selection in sequence generation. + For full backwards compatibility, use diversity_discount=0 and candidate_multiple=2. + + """ + + def __init__( + self, + tgt_dict, + num_groups, + diversity_strength, + diversity_discount=1.0, + candidate_multiple=1, + ): + super().__init__(tgt_dict) + self.num_groups = num_groups + self.diversity_strength = -diversity_strength + self.beam = BeamSearch(tgt_dict) + self.diversity_discount = diversity_discount + self.candidate_multiple = candidate_multiple + + # Float tensor to keep track of overlap between groups. + # Each token shared at the same step between two groups is counted as one. + # Then token counts are discounted by `diversity_discount` for every next timestep. + # Once initialized, dimension is batch_size * num_groups * num_groups. + self.group_overlap = torch.empty(0) + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + if beam_size % self.num_groups != 0: + raise ValueError( + "DiverseBeamSearch requires --beam to be divisible by the number of groups" + ) + + # initialize diversity penalty + diversity_buf = torch.zeros(lprobs[:, 0, :].size()).to(lprobs) + + scores_G, beams_G = [], [] + + # pre-allocating tensor for indices for all groups + indices_G_stacked = torch.empty( + bsz, + int(beam_size / self.num_groups) * self.candidate_multiple, + self.num_groups, + dtype=torch.long, + device=lprobs.device, + ) + + for g in range(self.num_groups): + lprobs_g = lprobs[:, g :: self.num_groups, :] + scores_g = scores[:, g :: self.num_groups, :] if step > 0 else None + + diversity_buf.zero_() + # apply diversity penalty + if g > 0: + indices_ = indices_G_stacked[:, :, :g] + if step > 0: + penalty_val = 1 + self.group_overlap[original_batch_idxs, g, :g] + penalty_val = penalty_val.unsqueeze(1) + else: + penalty_val = torch.ones(bsz, 1, 1) + diversity_buf.scatter_add_( + 1, + indices_.reshape(bsz, -1), + penalty_val.expand(indices_.size()) + .reshape(bsz, -1) + .to(diversity_buf), + ) + + lprobs_g = torch.add( + lprobs_g, + other=diversity_buf.unsqueeze(1), + alpha=self.diversity_strength, + ) + else: + lprobs_g = lprobs_g.contiguous() + + scores_buf, indices_buf, beams_buf = self.beam.step( + step, lprobs_g, scores_g, candidate_multiple=self.candidate_multiple + ) + beams_buf.mul_(self.num_groups).add_(g) + + scores_G.append(scores_buf.clone()) + beams_G.append(beams_buf.clone()) + + indices_G_stacked[:, :, g] = indices_buf + + # interleave results from different groups + scores_buf = torch.stack(scores_G, dim=2).view(bsz, -1) + indices_buf = indices_G_stacked.view(bsz, -1) + beams_buf = torch.stack(beams_G, dim=2).view(bsz, -1) + # find num of overlapped tokens for each group pair + # then discount it for next timestamp + overlap = self.diversity_discount * torch.sum( + indices_G_stacked.unsqueeze(2).eq(indices_G_stacked.unsqueeze(3)), dim=1 + ) + if step == 0: + self.group_overlap = overlap + else: + self.group_overlap[original_batch_idxs] = ( + self.group_overlap[original_batch_idxs] * self.diversity_discount + + overlap + ) + + return scores_buf, indices_buf, beams_buf + + +class Sampling(Search): + sampling_topk: int + sampling_topp: float + + def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0): + super().__init__(tgt_dict) + self.sampling_topk = sampling_topk + self.sampling_topp = sampling_topp + + def _sample_topp(self, lprobs): + """Sample among the smallest set of elements whose cumulative probability mass exceeds p. + + See `"The Curious Case of Neural Text Degeneration" + (Holtzman et al., 2019) `_. + + Args: + lprobs: (bsz x input_beam_size x vocab_size) + the model's log-probabilities over the vocabulary at the current step + + Return: A tuple of (trimed_probs, truncated_indices) where: + trimed_probs: (bsz x input_beam_size x ?) + the model's probabilities over the elements selected to sample from. The + width of the third dimension is determined by top-P. + truncated_indices: (bsz x input_beam_size x ?) + the indices of the chosen elements. + """ + probs = lprobs.exp_() + + # sort the last dimension (vocab dimension) in descending order + sorted_probs, sorted_indices = probs.sort(descending=True) + + # compute a mask to indicate the words to be included in the top-P set. + cumsum_probs = sorted_probs.cumsum(dim=2) + mask = cumsum_probs.lt(self.sampling_topp) + + # note that mask was computed by 'lt'. One more word needs to be included + # so that the cumulative probability mass can exceed p. + cumsum_mask = mask.cumsum(dim=2) + last_included = cumsum_mask[:, :, -1:] + last_included.clamp_(0, mask.size()[2] - 1) + mask = mask.scatter_(2, last_included, 1) + + # truncate unnecessary dims. + max_dim = last_included.max() + truncated_mask = mask[:, :, : max_dim + 1] + truncated_probs = sorted_probs[:, :, : max_dim + 1] + truncated_indices = sorted_indices[:, :, : max_dim + 1] + + # trim the words that are not in top-P by setting their probabilities + # to 0, so that they would not be sampled later. + trim_mask = ~truncated_mask + trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) + return trimed_probs, truncated_indices + + @torch.jit.export + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + lprobs = lprobs[:, ::beam_size, :].contiguous() + + if self.sampling_topp > 0: + # only sample from the smallest set of words whose cumulative probability mass exceeds p + probs, top_indices = self._sample_topp(lprobs) + elif self.sampling_topk > 0: + # only sample from top-k candidates + lprobs, top_indices = lprobs.topk(self.sampling_topk) + probs = lprobs.exp_() + else: + probs = lprobs.exp_() + + # dummy data to be consistent with true branch for type check + top_indices = torch.empty(0).to(probs) + # sample + if step == 0: + indices_buf = torch.multinomial( + probs.view(bsz, -1), + beam_size, + replacement=True, + ).view(bsz, beam_size) + else: + indices_buf = torch.multinomial( + probs.view(bsz * beam_size, -1), + 1, + replacement=True, + ).view(bsz, beam_size) + + if step == 0: + # expand to beam size + probs = probs.expand(bsz, beam_size, -1) + + # gather scores + scores_buf = torch.gather(probs, dim=2, index=indices_buf.unsqueeze(-1)) + scores_buf = scores_buf.log_().view(bsz, -1) + + # remap indices if using top-k or top-P sampling + if self.sampling_topk > 0 or self.sampling_topp > 0: + indices_buf = torch.gather( + top_indices.expand(bsz, beam_size, -1), + dim=2, + index=indices_buf.unsqueeze(-1), + ).squeeze(2) + + if step == 0: + beams_buf = indices_buf.new_zeros(bsz, beam_size) + else: + beams_buf = torch.arange(0, beam_size).to(indices_buf).repeat(bsz, 1) + # make scores cumulative + scores_buf.add_( + torch.gather(scores[:, :, step - 1], dim=1, index=beams_buf) + ) + + return scores_buf, indices_buf, beams_buf + + +class DiverseSiblingsSearch(Search): + """ + Beam search with diverse siblings. + + See "A Simple, Fast Diverse Decoding Algorithm for Neural Generation" for details. + https://arxiv.org/abs/1611.08562 + + 1/ Calculate hypotheses for each beam + 2/ Intra-sibling ordering + 3/ Rewrite scores + 4/ Choose top K hypotheses + + if diversity_rate == 0 is equivalent to BeamSearch + """ + + def __init__(self, tgt_dict, diversity_rate): + super().__init__(tgt_dict) + self.diversity_rate = diversity_rate + self.beam = BeamSearch(tgt_dict) + + def step( + self, + step: int, + lprobs, + scores, + prev_output_tokens: Optional[Tensor] = None, + original_batch_idxs: Optional[Tensor] = None, + ): + bsz, beam_size, vocab_size = lprobs.size() + k = min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ) + s_list: List[Tensor] + i_list: List[Tensor] + s_list = [torch.empty(0).to(lprobs) for i in range(beam_size)] + i_list = [torch.LongTensor().to(device=lprobs.device) for i in range(beam_size)] + sibling_score = torch.arange(1, k + 1).to(lprobs) * self.diversity_rate + + if step == 0: + return self.beam.step(step, lprobs, scores) + lprobs.add_(scores[:, :, step - 1].unsqueeze(-1)) + + # 1/ Calculate hypotheses for each beam + for i in range(beam_size): + torch.topk(lprobs[:, i, :].view(bsz, -1), k, out=(s_list[i], i_list[i])) + i_list[i].fmod_(vocab_size) + + # 2/ Intra-sibling ordering by default from topk + 3/ Rewrite scores + s_list[i].sub_(sibling_score) + + # 4/ Choose top K hypotheses + indices = torch.stack(i_list, dim=1).view(bsz, -1) + + final_scores = torch.empty(0).to(lprobs) + final_indices = torch.LongTensor().to(device=lprobs.device) + final_beams = torch.LongTensor().to(device=lprobs.device) + (final_scores, final_indices) = torch.topk( + torch.stack(s_list, dim=1).view(bsz, -1), + k, + ) + + final_beams = final_indices // k + + for i in range(bsz): + final_indices[i] = indices[i][final_indices[i]] + + return final_scores, final_indices, final_beams diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..78db504e6ce75ac31980e71cfdbf436e07739025 --- /dev/null +++ b/fairseq/sequence_generator.py @@ -0,0 +1,1020 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import sys +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from fairseq.ngram_repeat_block import NGramRepeatBlock + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + tokens_to_suppress=(), + ): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + self.token_indices_to_suppress: Optional[Tensor] = None + token_indices_to_suppress = [] + for token_string in tokens_to_suppress: + token_index = tgt_dict.index(token_string) + assert token_index != self.unk + token_indices_to_suppress.append(token_index) + if len(token_indices_to_suppress) > 0: + self.token_indices_to_suppress = torch.Tensor( + token_indices_to_suppress + ).long() + + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.model.set_decoder_beam_size(self.beam_size) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate( + self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs + ) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + # if src_lengths exists in net_input (speech_to_text dataset case), then use it + if "src_lengths" in net_input: + src_lengths = net_input["src_lengths"] + else: + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)) + .long() + .sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception( + "expected src_tokens or source in net input. input keys: " + + str(net_input.keys()) + ) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + with torch.autograd.profiler.record_function( + "EnsembleModel: forward_decoder" + ): + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + else: + if step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + if self.token_indices_to_suppress is not None: + lprobs[:, self.token_indices_to_suppress] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode="trunc") + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append( + { + "tokens": tokens_clone[i], + "score": eos_scores[i], + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size + ): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. + """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min( + [ + m.max_decoder_positions() + for m in self.models + if hasattr(m, "max_decoder_positions") + ] + + [sys.maxsize] + ) + + def set_decoder_beam_size(self, beam_size): + """Set beam size for efficient beamable enc-dec attention.""" + if beam_size > 1: + for model in self.models: + if hasattr(model, "set_beam_size"): + model.set_beam_size(beam_size) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. + for i in range(bsz * beam_size): + alignment = self.extract_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/fairseq/sequence_scorer.py b/fairseq/sequence_scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..411d4df4445ef8dd3f1907ad56f9de6943d1fed8 --- /dev/null +++ b/fairseq/sequence_scorer.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +import torch +from fairseq import utils + + +class SequenceScorer(object): + """Scores the target for a given source sentence.""" + + def __init__( + self, + tgt_dict, + softmax_batch=None, + compute_alignment=False, + eos=None, + symbols_to_strip_from_output=None, + ): + self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() if eos is None else eos + self.softmax_batch = softmax_batch or sys.maxsize + assert self.softmax_batch > 0 + self.compute_alignment = compute_alignment + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + """Score a batch of translations.""" + net_input = sample["net_input"] + + def batch_for_softmax(dec_out, target): + # assumes decoder_out[0] is the only thing needed (may not be correct for future models!) + first, rest = dec_out[0], dec_out[1:] + bsz, tsz, dim = first.shape + if bsz * tsz < self.softmax_batch: + yield dec_out, target, True + else: + flat = first.contiguous().view(1, -1, dim) + flat_tgt = target.contiguous().view(flat.shape[:-1]) + s = 0 + while s < flat.size(1): + e = s + self.softmax_batch + yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False + s = e + + def gather_target_probs(probs, target): + probs = probs.gather( + dim=2, + index=target.unsqueeze(-1), + ) + return probs + + orig_target = sample["target"] + + # compute scores for each model in the ensemble + avg_probs = None + avg_attn = None + for model in models: + model.eval() + decoder_out = model(**net_input) + attn = decoder_out[1] if len(decoder_out) > 1 else None + if type(attn) is dict: + attn = attn.get("attn", None) + + batched = batch_for_softmax(decoder_out, orig_target) + probs, idx = None, 0 + for bd, tgt, is_single in batched: + sample["target"] = tgt + curr_prob = model.get_normalized_probs( + bd, log_probs=len(models) == 1, sample=sample + ).data + if is_single: + probs = gather_target_probs(curr_prob, orig_target) + else: + if probs is None: + probs = curr_prob.new(orig_target.numel()) + step = curr_prob.size(0) * curr_prob.size(1) + end = step + idx + tgt_probs = gather_target_probs( + curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt + ) + probs[idx:end] = tgt_probs.view(-1) + idx = end + sample["target"] = orig_target + + probs = probs.view(sample["target"].shape) + + if avg_probs is None: + avg_probs = probs + else: + avg_probs.add_(probs) + if attn is not None: + if torch.is_tensor(attn): + attn = attn.data + else: + attn = attn[0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(models) > 1: + avg_probs.div_(len(models)) + avg_probs.log_() + if avg_attn is not None: + avg_attn.div_(len(models)) + + bsz = avg_probs.size(0) + hypos = [] + start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz + for i in range(bsz): + # remove padding from ref + ref = ( + utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad) + if sample["target"] is not None + else None + ) + tgt_len = ref.numel() + avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len] + score_i = avg_probs_i.sum() / tgt_len + if avg_attn is not None: + avg_attn_i = avg_attn[i] + if self.compute_alignment: + alignment = utils.extract_hard_alignment( + avg_attn_i, + sample["net_input"]["src_tokens"][i], + sample["target"][i], + self.pad, + self.eos, + ) + else: + alignment = None + else: + avg_attn_i = alignment = None + hypos.append( + [ + { + "tokens": ref, + "score": score_i, + "attention": avg_attn_i, + "alignment": alignment, + "positional_scores": avg_probs_i, + } + ] + ) + return hypos diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2cc8b5e86377d74515e477313eee3864a01d812 --- /dev/null +++ b/fairseq/speech_generator.py @@ -0,0 +1,427 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig + + +class SpeechGenerator(object): + def __init__(self, model, vocoder, data_cfg: S2TDataConfig): + self.model = model + self.vocoder = vocoder + stats_npz_path = data_cfg.global_cmvn_stats_npz + self.gcmvn_stats = None + if stats_npz_path is not None: + self.gcmvn_stats = np.load(stats_npz_path) + + def gcmvn_denormalize(self, x): + # x: B x T x C + if self.gcmvn_stats is None: + return x + mean = torch.from_numpy(self.gcmvn_stats["mean"]).to(x) + std = torch.from_numpy(self.gcmvn_stats["std"]).to(x) + assert len(x.shape) == 3 and mean.shape[0] == std.shape[0] == x.shape[2] + x = x * std.view(1, 1, -1).expand_as(x) + return x + mean.view(1, 1, -1).expand_as(x) + + def get_waveform(self, feat): + # T x C -> T + return None if self.vocoder is None else self.vocoder(feat).squeeze(0) + + +class AutoRegressiveSpeechGenerator(SpeechGenerator): + def __init__( + self, + model, + vocoder, + data_cfg, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + ): + super().__init__(model, vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size()[:2] + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, + encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs, + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) + eos_prob.append(cur_eos_prob) + + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra["feature_out"] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class MultiDecoderSpeechGenerator(SpeechGenerator): + def __init__( + self, + models, + args, + vocoder, + data_cfg, + tgt_dict_mt, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + eos_mt=None, + symbols_to_strip_from_output=None, + ): + super().__init__(models[0], vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + self.tgt_dict_mt = tgt_dict_mt + self.eos_mt = eos_mt + + from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator + from fairseq import search + + self.text_generator = SequenceGenerator( + models, + tgt_dict_mt, + beam_size=max(1, getattr(args, "beam", 5)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search.BeamSearch(tgt_dict_mt), + eos=eos_mt, + symbols_to_strip_from_output=symbols_to_strip_from_output, + ) + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size()[:2] + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + + prefix_tokens = None + constraints = None + bos_token = None + + mt_decoder = getattr(model, f"{model.mt_task_name}_decoder") + + # 1. MT decoder + finalized_mt = self.text_generator.generate_decoder( + [encoder_out], + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + aux_task_name=model.mt_task_name, + ) + + # extract decoder output corresponding to the best hypothesis + max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt]) + prev_output_tokens_mt = ( + src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len) + .fill_(mt_decoder.padding_idx) + .int() + ) # B x T + for i, hypo in enumerate(finalized_mt): + i_beam = 0 + tmp = hypo[i_beam]["tokens"].int() # hyp + eos + prev_output_tokens_mt[i, 0] = self.text_generator.eos + if tmp[-1] == self.text_generator.eos: + tmp = tmp[:-1] + prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp + + text = "".join([self.tgt_dict_mt[c] for c in tmp]) + text = text.replace("_", " ") + text = text.replace("▁", " ") + text = text.replace("", " ") + text = text.replace("", "") + text = text.replace("", "") + if len(text) > 0 and text[0] == " ": + text = text[1:] + sample_id = sample["id"].tolist()[i] + print("{} (None-{})".format(text, sample_id)) + + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + features_only=True, + ) + x = mt_decoder_out[0].transpose(0, 1) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if getattr(model, "synthesizer_encoder", None) is not None: + synthesizer_encoder_out = model.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + synthesizer_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask] + if mt_decoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + # 3. TTS decoder + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, + encoder_out=synthesizer_encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs, + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) + eos_prob.append(cur_eos_prob) + + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra["feature_out"] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class NonAutoregressiveSpeechGenerator(SpeechGenerator): + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + bsz, max_src_len = sample["net_input"]["src_tokens"].size() + n_frames_per_step = model.encoder.n_frames_per_step + out_dim = model.encoder.out_dim + raw_dim = out_dim // n_frames_per_step + + feat, feat_post, out_lens, log_dur_out, _, _ = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + incremental_state=None, + target_lengths=sample["target_lengths"], + speaker=sample["speaker"], + ) + if feat_post is not None: + feat = feat_post + + feat = feat.view(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0) + + def get_dur_plot_data(d): + r = [] + for i, dd in enumerate(d): + r += [i + 1] * dd.item() + return r + + out_lens = out_lens * n_frames_per_step + finalized = [ + { + "feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), + "waveform": self.get_waveform( + feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]) + ), + "attn": feat.new_tensor(get_dur_plot_data(dur_out[b])), + } + for b, l in zip(range(bsz), out_lens) + ] + + if has_targ: + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator): + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lens = sample["net_input"]["src_lengths"] + prev_out_tokens = sample["net_input"]["prev_output_tokens"] + tgt_lens = sample["target_lengths"] + n_frames_per_step = model.decoder.n_frames_per_step + raw_dim = model.decoder.out_dim // n_frames_per_step + bsz = src_tokens.shape[0] + + feat, eos_prob, extra = model( + src_tokens, + src_lens, + prev_out_tokens, + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], + ) + + attn = extra["attn"] # B x T_s x T_t + alignment = attn.max(dim=1)[1] + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + tgt_lens = sample["target_lengths"] * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :tgt_len], + "eos_prob": eos_prob[b, :tgt_len], + "attn": attn[b, :, :tgt_len], + "alignment": alignment[b, :tgt_len], + "waveform": self.get_waveform(feat[b, :tgt_len]), + } + for b, tgt_len in zip(range(bsz), tgt_lens) + ] + + if has_targ: + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6da1f001f095ad60557440ae39068f469b08ce1e --- /dev/null +++ b/fairseq/tasks/__init__.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import merge_with_parent +from hydra.core.config_store import ConfigStore + +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa + + +# register dataclass +TASK_DATACLASS_REGISTRY = {} +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + + +def setup_task(cfg: FairseqDataclass, **kwargs): + task = None + task_name = getattr(cfg, "task", None) + + if isinstance(task_name, str): + # legacy tasks + task = TASK_REGISTRY[task_name] + if task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = dc.from_namespace(cfg) + else: + task_name = getattr(cfg, "_name", None) + + if task_name and task_name in TASK_DATACLASS_REGISTRY: + remove_missing = "from_checkpoint" in kwargs and kwargs["from_checkpoint"] + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = merge_with_parent(dc(), cfg, remove_missing=remove_missing) + task = TASK_REGISTRY[task_name] + + assert ( + task is not None + ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}" + + return task.setup_task(cfg, **kwargs) + + +def register_task(name, dataclass=None): + """ + New tasks can be added to fairseq with the + :func:`~fairseq.tasks.register_task` function decorator. + + For example:: + + @register_task('classification') + class ClassificationTask(FairseqTask): + (...) + + .. note:: + + All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` + interface. + + Args: + name (str): the name of the task + """ + + def register_task_cls(cls): + if name in TASK_REGISTRY: + return TASK_REGISTRY[name] + + if not issubclass(cls, FairseqTask): + raise ValueError( + "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) + ) + if cls.__name__ in TASK_CLASS_NAMES: + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) + TASK_REGISTRY[name] = cls + TASK_CLASS_NAMES.add(cls.__name__) + + if dataclass is not None and not issubclass(dataclass, FairseqDataclass): + raise ValueError( + "Dataclass {} must extend FairseqDataclass".format(dataclass) + ) + + cls.__dataclass = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass + + cs = ConfigStore.instance() + node = dataclass() + node._name = name + cs.store(name=name, group="task", node=node, provider="fairseq") + + return cls + + return register_task_cls + + +def get_task(name): + return TASK_REGISTRY[name] + + +def import_tasks(tasks_dir, namespace): + for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + importlib.import_module(namespace + "." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group( + "Additional command-line arguments" + ) + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser + + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +import_tasks(tasks_dir, "fairseq.tasks") diff --git a/fairseq/tasks/__pycache__/__init__.cpython-310.pyc b/fairseq/tasks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f835631b5870af85a9575772902e7d16c56a66f Binary files /dev/null and b/fairseq/tasks/__pycache__/__init__.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/audio_classification.cpython-310.pyc b/fairseq/tasks/__pycache__/audio_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae6b1c1243e106c76ba36b872f0a150cab755663 Binary files /dev/null and b/fairseq/tasks/__pycache__/audio_classification.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/audio_finetuning.cpython-310.pyc b/fairseq/tasks/__pycache__/audio_finetuning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd0391c0b18aa3121d3e8cfe37339b72c406b645 Binary files /dev/null and b/fairseq/tasks/__pycache__/audio_finetuning.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/audio_pretraining.cpython-310.pyc b/fairseq/tasks/__pycache__/audio_pretraining.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64eb93a61d64c1665e810f363b55bc69eb171e3e Binary files /dev/null and b/fairseq/tasks/__pycache__/audio_pretraining.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-310.pyc b/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b4149c3a66543fe4051828fb02f85f8104fe898 Binary files /dev/null and b/fairseq/tasks/__pycache__/cross_lingual_lm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/denoising.cpython-310.pyc b/fairseq/tasks/__pycache__/denoising.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90212d97f14ed9f7c799c01eac1a830274e9b0fb Binary files /dev/null and b/fairseq/tasks/__pycache__/denoising.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/fairseq_task.cpython-310.pyc b/fairseq/tasks/__pycache__/fairseq_task.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef92148cab4499be6afc09fb4014f7fadf907e5c Binary files /dev/null and b/fairseq/tasks/__pycache__/fairseq_task.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/frm_text_to_speech.cpython-310.pyc b/fairseq/tasks/__pycache__/frm_text_to_speech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea47a92045c28e1628694d895f84f9b158d8391 Binary files /dev/null and b/fairseq/tasks/__pycache__/frm_text_to_speech.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/hubert_pretraining.cpython-310.pyc b/fairseq/tasks/__pycache__/hubert_pretraining.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7c2f328bb3dff79a48a66437b22065116ee88c Binary files /dev/null and b/fairseq/tasks/__pycache__/hubert_pretraining.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/language_modeling.cpython-310.pyc b/fairseq/tasks/__pycache__/language_modeling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d05c180dbfbcd8f99f8f3b183994ca029e09bf5 Binary files /dev/null and b/fairseq/tasks/__pycache__/language_modeling.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-310.pyc b/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fb1d20fe6838b9f548420d037fa2009b19e25ae Binary files /dev/null and b/fairseq/tasks/__pycache__/legacy_masked_lm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/masked_lm.cpython-310.pyc b/fairseq/tasks/__pycache__/masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae5b060f116da6293caf3b1f131be420a8230a6b Binary files /dev/null and b/fairseq/tasks/__pycache__/masked_lm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/multilingual_denoising.cpython-310.pyc b/fairseq/tasks/__pycache__/multilingual_denoising.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de01779ae6398ee32730c2b84b80f85537b3b72e Binary files /dev/null and b/fairseq/tasks/__pycache__/multilingual_denoising.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/multilingual_language_modeling.cpython-310.pyc b/fairseq/tasks/__pycache__/multilingual_language_modeling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa3bab1f5cdded39aee7187901aa8a3ac2f0d14b Binary files /dev/null and b/fairseq/tasks/__pycache__/multilingual_language_modeling.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/multilingual_masked_lm.cpython-310.pyc b/fairseq/tasks/__pycache__/multilingual_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e508c28cc6b1ce9cfc233cdbdde28bcafd9b91e3 Binary files /dev/null and b/fairseq/tasks/__pycache__/multilingual_masked_lm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/multilingual_translation.cpython-310.pyc b/fairseq/tasks/__pycache__/multilingual_translation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3225a35407f7f5f9ad51143c48d96a663ed06b Binary files /dev/null and b/fairseq/tasks/__pycache__/multilingual_translation.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/nlu_finetuning.cpython-310.pyc b/fairseq/tasks/__pycache__/nlu_finetuning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0440b5e02a37896b53cc3ccee396e350a629dad Binary files /dev/null and b/fairseq/tasks/__pycache__/nlu_finetuning.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/online_backtranslation.cpython-310.pyc b/fairseq/tasks/__pycache__/online_backtranslation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a40cd6b378fcde88f0b8cf0e9feea0e9d62b3d4 Binary files /dev/null and b/fairseq/tasks/__pycache__/online_backtranslation.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/semisupervised_translation.cpython-310.pyc b/fairseq/tasks/__pycache__/semisupervised_translation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c3c13ca2fc9c55535daadc6b1cb79e4326117ce Binary files /dev/null and b/fairseq/tasks/__pycache__/semisupervised_translation.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/sentence_prediction.cpython-310.pyc b/fairseq/tasks/__pycache__/sentence_prediction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5661dbbdb794471fb3aeb61844e1dbe098e0162c Binary files /dev/null and b/fairseq/tasks/__pycache__/sentence_prediction.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/sentence_prediction_adapters.cpython-310.pyc b/fairseq/tasks/__pycache__/sentence_prediction_adapters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f92cd3821b3c48faa582bed7b1a8b2f45b12b4f6 Binary files /dev/null and b/fairseq/tasks/__pycache__/sentence_prediction_adapters.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/sentence_ranking.cpython-310.pyc b/fairseq/tasks/__pycache__/sentence_ranking.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b562a40c9124b692d473e0c68e37404ee5cc79b0 Binary files /dev/null and b/fairseq/tasks/__pycache__/sentence_ranking.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/simultaneous_translation.cpython-310.pyc b/fairseq/tasks/__pycache__/simultaneous_translation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..951dfa3e9655f4c3c0a72c7a16e94d89c38d7322 Binary files /dev/null and b/fairseq/tasks/__pycache__/simultaneous_translation.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/span_masked_lm.cpython-310.pyc b/fairseq/tasks/__pycache__/span_masked_lm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61b16f2fe0fbf3f8ca49407e32cc3eaf30867eba Binary files /dev/null and b/fairseq/tasks/__pycache__/span_masked_lm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/speech_dlm_task.cpython-310.pyc b/fairseq/tasks/__pycache__/speech_dlm_task.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..831e788b8bce3ba7a4dbfeed43f8c935385fb126 Binary files /dev/null and b/fairseq/tasks/__pycache__/speech_dlm_task.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/speech_to_speech.cpython-310.pyc b/fairseq/tasks/__pycache__/speech_to_speech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2699172740ba4ad20e52bdf7e438bec9bf739320 Binary files /dev/null and b/fairseq/tasks/__pycache__/speech_to_speech.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/speech_to_text.cpython-310.pyc b/fairseq/tasks/__pycache__/speech_to_text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f126d5b9722e2fb103bcca7d26a6f406126317b4 Binary files /dev/null and b/fairseq/tasks/__pycache__/speech_to_text.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/speech_ulm_task.cpython-310.pyc b/fairseq/tasks/__pycache__/speech_ulm_task.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e35de844891fba0ad0c2bfb599b864dec7173f47 Binary files /dev/null and b/fairseq/tasks/__pycache__/speech_ulm_task.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/text_to_speech.cpython-310.pyc b/fairseq/tasks/__pycache__/text_to_speech.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49aace989f76dccc9f0368fb973ef6f3db355239 Binary files /dev/null and b/fairseq/tasks/__pycache__/text_to_speech.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/translation.cpython-310.pyc b/fairseq/tasks/__pycache__/translation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61fc502429e35ec9f4f9bf77673853e9bb16aa72 Binary files /dev/null and b/fairseq/tasks/__pycache__/translation.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/translation_from_pretrained_bart.cpython-310.pyc b/fairseq/tasks/__pycache__/translation_from_pretrained_bart.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e95cd4d63b44e10ba978365cb8a718694c110d16 Binary files /dev/null and b/fairseq/tasks/__pycache__/translation_from_pretrained_bart.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc b/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9543760c4cfa751dacb6705f3755b61f07cdd391 Binary files /dev/null and b/fairseq/tasks/__pycache__/translation_from_pretrained_xlm.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc b/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0550d2f9f6d8a54156be799e473592615ba38a18 Binary files /dev/null and b/fairseq/tasks/__pycache__/translation_lev.cpython-310.pyc differ diff --git a/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc b/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3f307fe858ae556c099c0294990de619c2d8b8 Binary files /dev/null and b/fairseq/tasks/__pycache__/translation_multi_simple_epoch.cpython-310.pyc differ diff --git a/fairseq/tasks/audio_classification.py b/fairseq/tasks/audio_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..4c21d23b696f7bf358145db32943074cf1b01626 --- /dev/null +++ b/fairseq/tasks/audio_classification.py @@ -0,0 +1,269 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +from collections import OrderedDict +import itertools +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from omegaconf import II, MISSING +from sklearn import metrics as sklearn_metrics + +from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask +from fairseq.tasks.audio_finetuning import label_len_fn, LabelEncoder + +from .. import utils +from ..logging import metrics +from . import FairseqTask, register_task + +logger = logging.getLogger(__name__) + +@dataclass +class AudioClassificationConfig(AudioPretrainingConfig): + target_dictionary: Optional[str] = field( + default=None, metadata={"help": "override default dictionary location"} + ) + + +@register_task("audio_classification", dataclass=AudioClassificationConfig) +class AudioClassificationTask(AudioPretrainingTask): + """Task for audio classification tasks.""" + + cfg: AudioClassificationConfig + + def __init__( + self, + cfg: AudioClassificationConfig, + ): + super().__init__(cfg) + self.state.add_factory("target_dictionary", self.load_target_dictionary) + logging.info(f"=== Number of labels = {len(self.target_dictionary)}") + + def load_target_dictionary(self): + if self.cfg.labels: + target_dictionary = self.cfg.data + if self.cfg.target_dictionary: # override dict + target_dictionary = self.cfg.target_dictionary + dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt") + logger.info("Using dict_path : {}".format(dict_path)) + return Dictionary.load(dict_path, add_special_symbols=False) + return None + + def load_dataset( + self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs + ): + super().load_dataset(split, task_cfg, **kwargs) + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + if task_cfg.multi_corpus_keys is None: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=False, + # text_compression_level=text_compression_level, + ) + else: + target_dataset_map = OrderedDict() + + multi_corpus_keys = [ + k.strip() for k in task_cfg.multi_corpus_keys.split(",") + ] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [ + float(val.strip()) + for val in task_cfg.multi_corpus_sampling_weights.split(",") + ] + data_weights = [] + for key, file_name in data_keys: + k = key.strip() + label_path = os.path.join( + data_path, f"{file_name.strip()}.{task_cfg.labels}" + ) + skipped_indices = getattr( + self.dataset_map[split][k], "skipped_indices", set() + ) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.dataset_map[split][k]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.dataset_map[split][k])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + # TODO: Remove duplication of code from the if block above + target_dataset_map[k] = AddTargetDataset( + self.dataset_map[split][k], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=False, + # text_compression_level=text_compression_level, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + if len(target_dataset_map) == 1: + self.datasets[split] = list(target_dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset( + target_dataset_map, + distribution=data_weights, + seed=0, + sort_indices=True, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def train_step(self, sample, model, *args, **kwargs): + sample["target"] = sample["target"].to(dtype=torch.long) + loss, sample_size, logging_output = super().train_step( + sample, model, *args, **kwargs + ) + self._log_metrics(sample, model, logging_output) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + sample["target"] = sample["target"].to(dtype=torch.long) + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + self._log_metrics(sample, model, logging_output) + return loss, sample_size, logging_output + + def _log_metrics(self, sample, model, logging_output): + metrics = self._inference_with_metrics( + sample, + model, + ) + """ + logging_output["_precision"] = metrics["precision"] + logging_output["_recall"] = metrics["recall"] + logging_output["_f1"] = metrics["f1"] + logging_output["_eer"] = metrics["eer"] + logging_output["_accuracy"] = metrics["accuracy"] + """ + logging_output["_correct"] = metrics["correct"] + logging_output["_total"] = metrics["total"] + + def _inference_with_metrics(self, sample, model): + def _compute_eer(target_list, lprobs): + # from scipy.optimize import brentq + # from scipy.interpolate import interp1d + + y_one_hot = np.eye(len(self.state.target_dictionary))[target_list] + fpr, tpr, thresholds = sklearn_metrics.roc_curve( + y_one_hot.ravel(), lprobs.ravel() + ) + # Revisit the interpolation approach. + # eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0) + + fnr = 1 - tpr + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] + + return eer + + with torch.no_grad(): + net_output = model(**sample["net_input"]) + lprobs = ( + model.get_normalized_probs(net_output, log_probs=True).cpu().detach() + ) + target_list = sample["target"][:, 0].detach().cpu() + predicted_list = torch.argmax(lprobs, 1).detach().cpu() # B,C->B + + metrics = { + "correct": torch.sum(target_list == predicted_list).item(), + "total": len(target_list), + } + return metrics + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + zero = torch.scalar_tensor(0.0) + correct, total = 0, 0 + for log in logging_outputs: + correct += log.get("_correct", zero) + total += log.get("_total", zero) + metrics.log_scalar("_correct", correct) + metrics.log_scalar("_total", total) + + if total > 0: + def _fn_accuracy(meters): + if meters["_total"].sum > 0: + return utils.item(meters["_correct"].sum / meters["_total"].sum) + return float("nan") + + metrics.log_derived("accuracy", _fn_accuracy) + """ + prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0 + for log in logging_outputs: + prec_sum += log.get("_precision", zero).item() + recall_sum += log.get("_recall", zero).item() + f1_sum += log.get("_f1", zero).item() + acc_sum += log.get("_accuracy", zero).item() + eer_sum += log.get("_eer", zero).item() + + metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs)) + metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs)) + metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs)) + metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs)) + metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs)) + """ \ No newline at end of file diff --git a/fairseq/tasks/audio_finetuning.py b/fairseq/tasks/audio_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..d79553cb86e830b12ed1b1affcb537cc4aec3630 --- /dev/null +++ b/fairseq/tasks/audio_finetuning.py @@ -0,0 +1,404 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any, OrderedDict + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class AudioFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: str = field( + default="{}", metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={ + "help": "generation args for BLUE scoring, e.g., " + '\'{"beam": 4, "lenpen": 0.6}\'' + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + rebuild_batches: bool = True + target_dictionary: Optional[str] = field( + default=None, + metadata={ + "help": "override default dictionary location" + } + ) + +@register_task("audio_finetuning", dataclass=AudioFinetuningConfig) +class AudioFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: AudioFinetuningConfig + + def __init__( + self, + cfg: AudioFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + target_dictionary = self.cfg.data + if self.cfg.target_dictionary: # override dict + target_dictionary = self.cfg.target_dictionary + dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt") + logger.info('Using dict_path : {}'.format(dict_path)) + return Dictionary.load(dict_path) + return None + + def load_dataset( + self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs + ): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + if task_cfg.multi_corpus_keys is None: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + else: + + target_dataset_map = OrderedDict() + + multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")] + data_weights = [] + for key, file_name in data_keys: + k = key.strip() + label_path = os.path.join(data_path, f"{file_name.strip()}.{task_cfg.labels}") + skipped_indices = getattr(self.dataset_map[split][k], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.dataset_map[split][k]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.dataset_map[split][k])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + # TODO: Remove duplication of code from the if block above + target_dataset_map[k] = AddTargetDataset( + self.dataset_map[split][k], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + if len(target_dataset_map) == 1: + self.datasets[split] = list(target_dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset(target_dataset_map, distribution=data_weights, seed=0, sort_indices=True) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer and self.cfg.autoregressive: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = metrics.sys_len + logging_output["_bleu_ref_len"] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + if self.cfg.eval_wer and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False)) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + is_ref=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("H-{} {}".format(sample["id"][0], hyps[0])) + logger.info("T-{} {}".format(sample["id"][0], refs[0])) + + eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a" + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = [f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs)) + + import sacrebleu + + metrics.log_derived( + "bleu", + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + smooth_method="exp", + ).score, + ) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..3e91303b6964627426f8b4791cdd2fc4f8676e3e --- /dev/null +++ b/fairseq/tasks/audio_pretraining.py @@ -0,0 +1,253 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, OrderedDict +from fairseq.data.multi_corpus_dataset import MultiCorpusDataset +from omegaconf import MISSING, II, OmegaConf + +from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset +from fairseq.dataclass import FairseqDataclass, ChoiceEnum +from fairseq.data.text_compressor import TextCompressionLevel + +from . import FairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@dataclass +class AudioMaskingConfig: + feature_encoder_spec: str = II("model.modalities.audio.feature_encoder_spec") + mask_prob: float = II("model.modalities.audio.mask_prob") + mask_prob_adjust: float = II("model.modalities.audio.mask_prob_adjust") + mask_length: int = II("model.modalities.audio.mask_length") + inverse_mask: bool = II("model.modalities.audio.inverse_mask") + mask_dropout: float = II("model.modalities.audio.mask_dropout") + clone_batch: int = II("model.clone_batch") + expand_adjacent: bool = False + non_overlapping: bool = False + + +@dataclass +class AudioPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + labels: Optional[str] = field( + default=None, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, + ) + multi_corpus_keys: Optional[str] = field( + default=None, + metadata={"help": "Comma separated names for loading multi corpus datasets"}) + multi_corpus_sampling_weights: Optional[str] = field( + default=None, + metadata={"help": "Comma separated string of sampling weights corresponding to the multi_corpus_keys"}) + binarized_dataset: bool = field( + default=False, + metadata={ + "help": "if true, loads binarized dataset (useful for very large datasets). " + "See examples/wav2vec/scripts/binarize_manifest.sh" + }, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to skip small examples"} + ) + num_batch_buckets: int = field( + default=0, + metadata={"help": "number of buckets"}, + ) + tpu: bool = II("common.tpu") + text_compression_level: ChoiceEnum([x.name for x in TextCompressionLevel]) = field( + default="none", + metadata={ + "help": "compression level for texts (e.g. audio filenames, " + "target texts): none/low/high (default: none). " + }, + ) + + rebuild_batches: bool = True + precompute_mask_config: Optional[AudioMaskingConfig] = None + + post_save_script: Optional[str] = None + + subsample: float = 1 + seed: int = II("common.seed") + + +@register_task("audio_pretraining", dataclass=AudioPretrainingConfig) +class AudioPretrainingTask(FairseqTask): + """ """ + + cfg: AudioPretrainingConfig + + @classmethod + def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg + + # upgrade old task + if isinstance(task_cfg, Namespace): + if not hasattr(task_cfg, "autoregressive"): + task_cfg.autoregressive = not task_cfg.criterion == "ctc" + + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + + compute_mask = getattr(task_cfg, "precompute_mask_config", None) is not None + mask_args = {} + if compute_mask: + mask_args = task_cfg.precompute_mask_config + + if getattr(task_cfg, "binarized_dataset", False): + self.datasets[split] = BinarizedAudioDataset( + data_path, + split=split, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + compute_mask=compute_mask, + **mask_args, + ) + else: + if task_cfg.multi_corpus_keys is None: + manifest_path = os.path.join(data_path, "{}.tsv".format(split)) + + self.datasets[split] = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + text_compression_level=text_compression_level, + compute_mask=compute_mask, + **mask_args, + ) + else: + dataset_map = OrderedDict() + self.dataset_map = {} + multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")] + corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)} + data_keys = [k.split(":") for k in split.split(",")] + + multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")] + data_weights = [] + + for key, file_name in data_keys: + + k = key.strip() + manifest_path = os.path.join(data_path, "{}.tsv".format(file_name.strip())) + + # TODO: Remove duplication of code from the if block above + dataset_map[k] = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate), + max_sample_size=self.cfg.max_sample_size, + min_sample_size=self.cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu), + text_compression_level=text_compression_level, + compute_mask=compute_mask, + corpus_key=corpus_idx_map[k], + **mask_args, + ) + + data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]]) + + self.dataset_map[split] = dataset_map + + if len(dataset_map) == 1: + self.datasets[split] = list(dataset_map.values())[0] + else: + self.datasets[split] = MultiCorpusDataset(dataset_map, distribution=data_weights, seed=0, sort_indices=True) + + if getattr(task_cfg, "subsample", 1) < 1: + self.datasets[split] = SubsampleDataset( + self.datasets[split], + task_cfg.subsample, + shuffle=True, + seed=task_cfg.seed, + ) + + if self.cfg.tpu and task_cfg.inferred_w2v_config.mask_channel_prob == 0.0: + logger.info( + "Pretraining on TPUs may suffer convergence " + "issues when training with `mask_channel_prob` value of " + "0. You may want to set this to a low value close to 0." + ) + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize + + def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False): + model = super().build_model(model_cfg, from_checkpoint) + + actualized_cfg = getattr(model, "cfg", None) + if actualized_cfg is not None: + # if "w2v_args" in actualized_cfg: + if hasattr(actualized_cfg, "w2v_args"): + model_cfg.w2v_args = actualized_cfg.w2v_args + + return model + + def post_save(self, cp_path, num_updates): + if self.cfg.post_save_script is not None: + logger.info(f"launching {self.cfg.post_save_script}") + import os.path as osp + from fairseq.file_io import PathManager + + eval_cp_path = osp.join( + osp.dirname(cp_path), f"checkpoint_eval_{num_updates}.pt" + ) + + print(cp_path, eval_cp_path, osp.dirname(cp_path)) + + assert PathManager.copy( + cp_path, eval_cp_path, overwrite=True + ), f"Failed to copy {cp_path} to {eval_cp_path}" + + import subprocess + import shlex + + subprocess.call(shlex.split(f"{self.cfg.post_save_script} {eval_cp_path}")) diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8fe7e2de181e41bd0e6a2bf96948ee78de5ae8 --- /dev/null +++ b/fairseq/tasks/cross_lingual_lm.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +from collections import OrderedDict + +import numpy as np +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, TokenBlockDataset, data_utils +from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("cross_lingual_lm") +class CrossLingualLMTask(LegacyFairseqTask): + """ + Task for training cross-lingual language models. + + For more details look at: https://arxiv.org/pdf/1901.07291.pdf + + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" " per sample", + ) + parser.add_argument( + "--monolingual-langs", + default="en", + type=str, + help="comma separated list of languages for which we" + " want to train XLM on", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="shuffle each monolingual dataset while" " training", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + self.distributed_world_size = args.distributed_world_size + self.langs2id = self._lang_to_id(args.monolingual_langs) + + def _lang_to_id(self, languages: str): + """ + Build a map from languages to ids. These ids are used as segment labels + for cross-lingual LM training. + """ + lang2id = {} + langs = [l.strip() for l in languages.split(",")] + for id, lang in enumerate(langs): + lang2id[lang] = id + return lang2id + + @classmethod + def load_dictionary(cls, filename): + return MaskedLMDictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + d = MaskedLMDictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @property + def target_dictionary(self): + return self.dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + dictionary = MaskedLMDictionary.load(os.path.join(args.data, "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def _load_single_lang_dataset(self, split, epoch): + loaded_datasets = [] + + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + path = os.path.join(data_path, split_k) + + ds = data_utils.load_indexed_dataset( + path, self.dictionary, self.args.dataset_impl + ) + if ds is None: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + # Since we append each block with the classification_token, + # we need to effectively create blocks of length + # tokens_per_sample-1 + loaded_datasets.append( + TokenBlockDataset( + ds, + ds.sizes, + self.args.tokens_per_sample - 1, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + ) + ) + + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) + + if len(loaded_datasets) == 1: + dataset = loaded_datasets[0] + sizes = dataset.sizes + else: + dataset = ConcatDataset(loaded_datasets) + sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) + + return dataset, sizes + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset_map = OrderedDict() + + for lang in self.langs2id.keys(): + # Datasets are expected to be in "split.lang" format (Eg: train.en) + language_split = "{}.{}".format(split, lang) + + block_dataset, sizes = self._load_single_lang_dataset( + split=language_split, epoch=epoch + ) + + dataset_map[lang] = MaskedLMDataset( + dataset=block_dataset, + sizes=sizes, + vocab=self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.dictionary.mask(), + classif_token_idx=self.dictionary.eos(), + sep_token_idx=self.dictionary.eos(), + shuffle=getattr(self.args, "shuffle", False), + has_pairs=False, + segment_id=self.langs2id[lang], + seed=self.seed, + ) + + self.datasets[split] = MultiCorpusSampledDataset(dataset_map) + logger.info( + "{} {} {} examples".format( + utils.split_paths(self.args.data)[epoch - 1], + split, + len(self.datasets[split]), + ) + ) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..57b824d5812535d2a2f83292bad4f707e14ec618 --- /dev/null +++ b/fairseq/tasks/denoising.py @@ -0,0 +1,296 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Optional + +import numpy as np +from omegaconf import II, MISSING + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + DenoisingDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from ..data.indexed_dataset import get_available_dataset_impl + +logger = logging.getLogger(__name__) + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +MASK_LENGTH_CHOICES = ChoiceEnum(["subword", "word", "span-poisson"]) + + +@dataclass +class DenoisingConfig(FairseqDataclass): + data: str = field( + default=MISSING, + metadata={"help": "path to data directory"}, + ) + bpe: Optional[str] = field( + default=None, + metadata={"help": "TODO"}, + ) + tokens_per_sample: int = field( + default=512, + metadata={ + "help": "max number of total tokens over all segments " + "per sample for dataset" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="complete_doc", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + replace_length: int = field( + default=0, + metadata={"help": "TODO, should only allow -1, 0 and 1"}, + ) + mask: float = field( + default=0.0, + metadata={"help": "fraction of words/subwords that will be masked"}, + ) + mask_random: float = field( + default=0.0, + metadata={"help": "instead of using [MASK], use random token this often"}, + ) + insert: float = field( + default=0.0, + metadata={"help": "insert this percentage of additional random tokens"}, + ) + permute: float = field( + default=0.0, + metadata={"help": "take this proportion of subwords and permute them"}, + ) + rotate: float = field( + default=0.5, + metadata={"help": "rotate this proportion of inputs"}, + ) + poisson_lambda: float = field( + default=3.0, + metadata={"help": "randomly shuffle sentences for this proportion of inputs"}, + ) + shuffle_instance: float = field( + default=0.0, + metadata={"help": "shuffle this proportion of sentences in all inputs"}, + ) + mask_length: MASK_LENGTH_CHOICES = field( + default="subword", + metadata={"help": "mask length to choose"}, + ) + permute_sentences: int = field( + default=-1, + metadata={ + "help": "when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)" + }, + ) + seed: int = II("common.seed") + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + max_source_positions: int = field( + default=1024, + metadata={"help": "max number of tokens in the source sequence"}, + ) + max_target_positions: int = field( + default=1024, + metadata={"help": "max number of tokens in the target sequence"}, + ) + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + + +@register_task("denoising", dataclass=DenoisingConfig) +class DenoisingTask(FairseqTask): + """ + Denoising task for applying sequence to sequence denoising. (ie. BART) + """ + + cfg: DenoisingConfig + + def __init__(self, cfg, dictionary): + super().__init__(cfg) + self.dictionary = dictionary + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + + @classmethod + def setup_task(cls, cfg: DenoisingConfig, **kwargs): + """Setup the task.""" + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle_instance"): + cfg.shuffle_instance = False + return cls(cfg, dictionary) + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = StripTokenDataset(dataset, self.dictionary.eos()) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, + # one less for and one for + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) + return dataset + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + mask_whole_words = ( + get_whole_word_mask(self.cfg.bpe, self.source_dictionary) + if self.cfg.mask_length != "subword" + else None + ) + + self.datasets[split] = DenoisingDataset( + dataset, + dataset.sizes, + self.dictionary, + self.mask_idx, + mask_whole_words, + shuffle=self.cfg.shuffle_instance, + seed=self.cfg.seed, + mask=self.cfg.mask, + mask_random=self.cfg.mask_random, + insert=self.cfg.insert, + rotate=self.cfg.rotate, + permute_sentences=self.cfg.permute_sentences, + bpe=self.cfg.bpe, + replace_length=self.cfg.replace_length, + mask_length=self.cfg.mask_length, + poisson_lambda=self.cfg.poisson_lambda, + ) + logger.info( + "Split: {0}, Loaded {1} samples of denoising_dataset".format( + split, + len(self.datasets[split]), + ) + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We assume that the input begins with a + bos symbol (``) and ends with an eos symbol (``). + """ + pad = self.source_dictionary.pad() + eos = self.source_dictionary.eos() + src_dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=self.cfg.tokens_per_sample - 2, # for and + pad=pad, + eos=eos, + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + prev_output_tokens = PrependTokenDataset( + StripTokenDataset(src_dataset, eos), eos + ) + src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + "prev_output_tokens": PadDataset( + prev_output_tokens, pad_idx=pad, left_pad=False + ), + }, + "target": src_dataset, + }, + sizes=[np.array(src_lengths)], + ) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.dictionary diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py new file mode 100644 index 0000000000000000000000000000000000000000..e39d1d684882b3269e1fb97ebb6db61ef1a97df9 --- /dev/null +++ b/fairseq/tasks/fairseq_task.py @@ -0,0 +1,708 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import warnings +from argparse import Namespace +from typing import Any, Callable, Dict, List + +import torch +from fairseq import search, tokenizer, utils +from fairseq.logging import metrics +from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.optim.amp_optimizer import AMPOptimizer +from omegaconf import DictConfig + + +logger = logging.getLogger(__name__) + + +class StatefulContainer(object): + def __init__(self): + self._state = dict() + self._factories = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + +class FairseqTask(object): + """ + Tasks store dictionaries and provide helpers for loading/iterating over + Datasets, initializing the Model/Criterion and calculating the loss. + + Tasks have limited statefulness. In particular, state that needs to be + saved to/loaded from checkpoints needs to be stored in the `self.state` + :class:`StatefulContainer` object. For example:: + + self.state.add_factory("dictionary", self.load_dictionary) + print(self.state.dictionary) # calls self.load_dictionary() + + This is necessary so that when loading checkpoints, we can properly + recreate the task state after initializing the task instance. + """ + + @classmethod + def add_args(cls, parser): + """Add task-specific arguments to the parser.""" + dc = getattr(cls, "__dataclass", None) + if dc is not None: + gen_parser_from_dataclass(parser, dc()) + + @staticmethod + def logging_outputs_can_be_summed(criterion) -> bool: + """ + Whether the logging outputs returned by `train_step` and `valid_step` can + be summed across workers prior to calling `aggregate_logging_outputs`. + Setting this to True will improves distributed training speed. + """ + return criterion.logging_outputs_can_be_summed() + + def __init__(self, cfg: FairseqDataclass, **kwargs): + self.cfg = cfg + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + return Dictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + """Build the dictionary + + Args: + filenames (list): list of filenames + workers (int): number of concurrent workers + threshold (int): defines the minimum word count + nwords (int): defines the total number of words in the final dictionary, + including special symbols + padding_factor (int): can be used to pad the dictionary size to be a + multiple of 8, which is important on some hardware (e.g., Nvidia + Tensor Cores). + """ + d = Dictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @classmethod + def setup_task(cls, cfg: DictConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (omegaconf.DictConfig): parsed command-line arguments + """ + return cls(cfg, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.cfg, "data", "") + + def load_dataset( + self, + split: str, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs, + ): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets + """ + raise NotImplementedError + + def dataset(self, split): + """ + Return a loaded dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + + Returns: + a :class:`~fairseq.data.FairseqDataset` corresponding to *split* + """ + from fairseq.data import FairseqDataset + + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + if not isinstance(self.datasets[split], FairseqDataset): + raise TypeError("Datasets are expected to be of type FairseqDataset") + return self.datasets[split] + + def filter_indices_by_size( + self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + ): + """ + Filter examples that are too large + + Args: + indices (np.array): original array of sample indices + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + Returns: + np.array: array of filtered sample indices + """ + indices, ignored = dataset.filter_indices_by_size(indices, max_positions) + if len(ignored) > 0: + if not ignore_invalid_inputs: + raise Exception( + ( + "Size of sample #{} is invalid (={}) since max_positions={}, " + "skip this example with --skip-invalid-size-inputs-valid-test" + ).format(ignored[0], dataset.size(ignored[0]), max_positions) + ) + logger.warning( + ( + "{:,} samples have invalid sizes and will be skipped, " + "max_positions={}, first few sample ids={}" + ).format(len(ignored), max_positions, ignored[:10]) + ) + return indices + + def can_reuse_epoch_itr(self, dataset): + # We can reuse the epoch iterator across epochs as long as the dataset + # hasn't disabled it. We default to ``False`` here, although in practice + # this will be ``True`` for most datasets that inherit from + # ``FairseqDataset`` due to the base implementation there. + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 1). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + skip_remainder_batch (bool, optional): if set, discard the last + batch in each training epoch, as the last batch is often smaller than + local_batch_size * distributed_word_size (default: ``True``). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. + update_epoch_batch_itr (bool optional): if true then donot use the cached + batch iterator for the epoch + + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + can_reuse_epoch_itr = ( + not disable_iterator_cache + and not update_epoch_batch_itr + and self.can_reuse_epoch_itr(dataset) + ) + logger.info(f"can_reuse_epoch_itr = {can_reuse_epoch_itr}") + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch)) + return self.dataset_to_epoch_iter[dataset] + + assert isinstance(dataset, FairseqDataset) + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + def make_batches(dataset, epoch): + logger.info(f"creating new batches for epoch {epoch}") + + # get indices ordered by example size + with data_utils.numpy_seed(seed + epoch): + indices = dataset.ordered_indices() + + # filter examples that are too large + if max_positions is not None: + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + + # create mini-batches with given size constraints + batches = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + return batches + + reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True) + persistent_workers = getattr(self.cfg, "persistent_workers", True) + rebuild_batches = getattr(self.cfg, "rebuild_batches", False) + logger.info(f"reuse_dataloader = {reuse_dataloader}") + logger.info(f"rebuild_batches = {rebuild_batches}") + + if rebuild_batches: + logger.info("batches will be rebuilt for each epoch") + batch_sampler = make_batches + else: + batch_sampler = make_batches(dataset, epoch) + + # return a reusable, sharded iterator + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, + grouped_shuffling=grouped_shuffling, + reuse_dataloader=reuse_dataloader, + persistent_workers=persistent_workers, + ) + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + + return epoch_iter + + def build_model(self, cfg: FairseqDataclass, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + cfg (FairseqDataclass): configuration object + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(cfg, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, cfg) + return model + + def build_criterion(self, cfg: DictConfig, from_checkpoint=False): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + cfg (omegaconf.DictConfig): configration object + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + prefix_allowed_tokens_fn=None, + ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + from fairseq.sequence_generator import ( + SequenceGenerator, + SequenceGeneratorWithAlignment, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + else: + seq_gen_cls = SequenceGenerator + + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + """ + Do forward and backward, and return the loss as computed by *criterion* + for the given *model* and *sample*. + + Args: + sample (dict): the mini-batch. The format is defined by the + :class:`~fairseq.data.FairseqDataset`. + model (~fairseq.models.BaseFairseqModel): the model + criterion (~fairseq.criterions.FairseqCriterion): the criterion + optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + update_num (int): the current update + ignore_grad (bool): multiply loss by 0 if this is set to True + + Returns: + tuple: + - the loss + - the sample size, which is used as the denominator for the + gradient + - logging outputs to display while training + """ + model.train() + model.set_num_updates(update_num) + with torch.autograd.profiler.record_function("forward"): + with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output + + def optimizer_step(self, optimizer, model, update_num): + optimizer.step() + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + raise NotImplementedError + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, constraints=constraints + ) + + def begin_epoch(self, epoch, model): + """Hook function called before the start of each epoch.""" + pass + + def begin_valid_epoch(self, epoch, model): + """Hook function called before the start of each validation epoch.""" + pass + + def aggregate_logging_outputs(self, logging_outputs, criterion): + """[deprecated] Aggregate logging outputs from data parallel training.""" + utils.deprecation_warning( + "The aggregate_logging_outputs API is deprecated. " + "Please use the reduce_metrics API instead." + ) + with metrics.aggregate() as agg: + self.reduce_metrics(logging_outputs, criterion) + return agg.get_smoothed_values() + + def reduce_metrics(self, logging_outputs, criterion): + """Aggregate logging outputs from data parallel training.""" + # backward compatibility for tasks that override aggregate_logging_outputs + base_func = FairseqTask.aggregate_logging_outputs + self_func = getattr(self, "aggregate_logging_outputs").__func__ + if self_func is not base_func: + utils.deprecation_warning( + "Tasks should implement the reduce_metrics API. " + "Falling back to deprecated aggregate_logging_outputs API." + ) + agg_logging_outputs = self.aggregate_logging_outputs( + logging_outputs, criterion + ) + for k, v in agg_logging_outputs.items(): + metrics.log_scalar(k, v) + return + + if not any("ntokens" in log for log in logging_outputs): + warnings.warn( + "ntokens not found in Criterion logging outputs, cannot log wpb or wps" + ) + else: + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + metrics.log_scalar("wpb", ntokens, priority=180, round=1) + metrics.log_speed("wps", ntokens, priority=90, round=1) + + if not any("nsentences" in log for log in logging_outputs): + warnings.warn( + "nsentences not found in Criterion logging outputs, cannot log bsz" + ) + else: + nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("bsz", nsentences, priority=190, round=1) + + criterion.__class__.reduce_metrics(logging_outputs) + + def state_dict(self): + if self.state is not None: + return self.state.state_dict + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + if self.state is not None: + self.state.merge_state_dict(state_dict) + + def max_positions(self): + """Return the max input length allowed by the task.""" + return None + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary` (if applicable + for this task).""" + return None + + def build_tokenizer(self, args): + """Build the pre-tokenizer for this task.""" + return encoders.build_tokenizer(args) + + def build_bpe(self, args): + """Build the tokenizer for this task.""" + return encoders.build_bpe(args) + + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + tokens = [ + self.source_dictionary.encode_line( + encode_fn(src_str), add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + return tokens, lengths + + +class LegacyFairseqTask(FairseqTask): + def __init__(self, args: Namespace): + super().__init__(None) + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.args, "data", "") + + def build_model(self, args: Namespace, from_checkpoint=False): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + + model = models.build_model(args, self, from_checkpoint) + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/tasks/frm_text_to_speech.py b/fairseq/tasks/frm_text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..667f5f8ee45694a992a44a8c975e98f34c5b207a --- /dev/null +++ b/fairseq/tasks/frm_text_to_speech.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from fairseq.data.audio.frm_text_to_speech_dataset import FrmTextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.text_to_speech import TextToSpeechTask + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +@register_task("frm_text_to_speech") +class FrmTextToSpeechTask(TextToSpeechTask): + @staticmethod + def add_args(parser): + TextToSpeechTask.add_args(parser) + parser.add_argument("--do_chunk", action="store_true", help="train on chunks") + parser.add_argument("--chunk_bound", default=-1, type=int) + parser.add_argument("--chunk_init", default=50, type=int) + parser.add_argument("--chunk_incr", default=5, type=int) + parser.add_argument("--add_eos", action="store_true") + parser.add_argument("--dedup", action="store_true") + parser.add_argument("--ref_fpu", default=-1, type=float) + + def load_dataset(self, split, **unused_kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = FrmTextToSpeechDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id, + do_chunk=self.args.do_chunk, + chunk_bound=self.args.chunk_bound, + chunk_init=self.args.chunk_init, + chunk_incr=self.args.chunk_incr, + add_eos=self.args.add_eos, + dedup=self.args.dedup, + ref_fpu=self.args.ref_fpu, + ) diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..1a3605f14dec1a9d1893cc03891991235b78199c --- /dev/null +++ b/fairseq/tasks/hubert_pretraining.py @@ -0,0 +1,191 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) + + +@dataclass +class HubertPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + +@register_task("hubert_pretraining", dataclass=HubertPretrainingConfig) +class HubertPretrainingTask(FairseqTask): + + cfg: HubertPretrainingConfig + + def __init__( + self, + cfg: HubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"HubertPretrainingTask Config {cfg}") + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + @classmethod + def setup_task( + cls, cfg: HubertPretrainingConfig, **kwargs + ) -> "HubertPretrainingTask": + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + procs = [LabelEncoder(dict) for dict in dicts] + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] + + # hubert v1: pad_audio=True, random_crop=False; + self.datasets[split] = HubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_keep_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: + return indices diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5324b3d9bbc131a95ab4763fccdd8884729f6 --- /dev/null +++ b/fairseq/tasks/language_modeling.py @@ -0,0 +1,383 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task +from omegaconf import II + + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +@dataclass +class LanguageModelingConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + self_target: bool = field(default=False, metadata={"help": "include self target"}) + future_target: bool = field( + default=False, metadata={"help": "include future target"} + ) + past_target: bool = field(default=False, metadata={"help": "include past target"}) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend beginning of sentence token ()"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + pad_to_fixed_length: Optional[bool] = field( + default=False, + metadata={"help": "pad to fixed length"}, + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, + metadata={"help": "boolean to pad to fixed batch size"}, + ) + + # TODO common vars below add to parent + seed: int = II("common.seed") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + use_plasma_view: bool = II("common.use_plasma_view") + plasma_path: str = II("common.plasma_path") + + +@register_task("language_modeling", dataclass=LanguageModelingConfig) +class LanguageModelingTask(LegacyFairseqTask): + """ + Train a language model. + + Args: + dictionary (~fairseq.data.Dictionary): the dictionary for the input of + the language model + output_dictionary (~fairseq.data.Dictionary): the dictionary for the + output of the language model. In most cases it will be the same as + *dictionary*, but could possibly be a more limited version of the + dictionary (if ``--output-dictionary-size`` is used). + targets (List[str]): list of the target types that the language model + should predict. Can be one of "self", "future", and "past". + Defaults to "future". + + .. note:: + + The language modeling task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate`, :mod:`fairseq-interactive` and + :mod:`fairseq-eval-lm`. + + The language modeling task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.language_modeling_parser + :prog: + """ + + def __init__(self, args, dictionary, output_dictionary=None, targets=None): + super().__init__(args) + self.dictionary = dictionary + self.output_dictionary = output_dictionary or dictionary + + if targets is None: + targets = ["future"] + self.targets = targets + + @classmethod + def setup_dictionary(cls, args, **kwargs): + dictionary = None + output_dictionary = None + if args.data: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + return (dictionary, output_dictionary) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) + + # upgrade old checkpoints + if getattr(args, "exclude_self_target", False): + args.self_target = False + + targets = [] + if getattr(args, "self_target", False): + targets.append("self") + if getattr(args, "future_target", False): + targets.append("future") + if getattr(args, "past_target", False): + targets.append("past") + if len(targets) == 0: + # standard language modeling + targets = ["future"] + + return cls(args, dictionary, output_dictionary, targets=targets) + + def build_model(self, args, from_checkpoint=False): + model = super().build_model(args, from_checkpoint) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError( + "Unsupported language modeling target: {}".format(target) + ) + + return model + + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> MonolingualDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, valid1, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + # each process has its own copy of the raw data (likely to be an np.memmap) + dataset = data_utils.load_indexed_dataset( + split_path, self.dictionary, self.args.dataset_impl, combine=combine + ) + if dataset is None: + raise FileNotFoundError(f"Dataset not found: {split} ({split_path})") + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + use_plasma_view=self.args.use_plasma_view, + split_path=split_path, + plasma_path=self.args.plasma_path, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = ( + self.args.batch_size_valid if "valid" in split else self.args.batch_size + ) + + self.datasets[split] = MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=self.dictionary, + tgt_vocab=self.output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=True, + targets=self.targets, + add_bos_token=self.args.add_bos_token, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. + """ + dataset = StripTokenDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionary.eos(), + ) + src_dataset = PrependTokenDataset( + dataset, + token=( + self.source_dictionary.bos() + if getattr(self.args, "add_bos_token", False) + else self.source_dictionary.eos() + ), + ) + tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, pad_idx=self.source_dictionary.pad(), left_pad=False + ), + }, + sizes=[np.array(src_lengths)], + ) + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + bos_token = self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the language_modeling task is not supported" + ) + + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + if prefix_tokens[:, 0].eq(bos_token).all(): + prefix_tokens = prefix_tokens[:, 1:] + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dictionary diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..975497654926b64fff6c4960f54c4e6932e7fce1 --- /dev/null +++ b/fairseq/tasks/legacy_masked_lm.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os + +import numpy as np +from fairseq import tokenizer, utils +from fairseq.data import ConcatDataset, Dictionary, data_utils, indexed_dataset +from fairseq.data.legacy.block_pair_dataset import BlockPairDataset +from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset +from fairseq.data.legacy.masked_lm_dictionary import BertDictionary +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("legacy_masked_lm") +class LegacyMaskedLMTask(LegacyFairseqTask): + """ + Task for training Masked LM (BERT) model. + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments" + " per sample for BERT dataset", + ) + parser.add_argument( + "--break-mode", default="doc", type=str, help="mode for breaking sentence" + ) + parser.add_argument("--shuffle-dataset", action="store_true", default=False) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + @classmethod + def load_dictionary(cls, filename): + return BertDictionary.load(filename) + + @classmethod + def build_dictionary( + cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8 + ): + d = BertDictionary() + for filename in filenames: + Dictionary.add_file_to_dictionary( + filename, d, tokenizer.tokenize_line, workers + ) + d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor) + return d + + @property + def target_dictionary(self): + return self.dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = BertDictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + + return cls(args, dictionary) + + def load_dataset(self, split, epoch=1, combine=False): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + loaded_datasets = [] + + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + logger.info("data_path", data_path) + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + path = os.path.join(data_path, split_k) + ds = indexed_dataset.make_dataset( + path, + impl=self.args.dataset_impl, + fix_lua_indexing=True, + dictionary=self.dictionary, + ) + + if ds is None: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + with data_utils.numpy_seed(self.seed + k): + loaded_datasets.append( + BlockPairDataset( + ds, + self.dictionary, + ds.sizes, + self.args.tokens_per_sample, + break_mode=self.args.break_mode, + doc_break_size=1, + ) + ) + + logger.info( + "{} {} {} examples".format(data_path, split_k, len(loaded_datasets[-1])) + ) + + if not combine: + break + + if len(loaded_datasets) == 1: + dataset = loaded_datasets[0] + sizes = dataset.sizes + else: + dataset = ConcatDataset(loaded_datasets) + sizes = np.concatenate([ds.sizes for ds in loaded_datasets]) + + self.datasets[split] = MaskedLMDataset( + dataset=dataset, + sizes=sizes, + vocab=self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.dictionary.mask(), + classif_token_idx=self.dictionary.cls(), + sep_token_idx=self.dictionary.sep(), + shuffle=self.args.shuffle_dataset, + seed=self.seed, + ) diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..b064907a50d69f925e424a5045e71608de881f13 --- /dev/null +++ b/fairseq/tasks/masked_lm.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field + +import numpy as np +from omegaconf import II, MISSING, OmegaConf + +from fairseq import utils +from fairseq.data import ( + Dictionary, + IdDataset, + MaskTokensDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PrependTokenDataset, + RightPadDataset, + RightPaddingMaskDataset, + SortDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from .language_modeling import SAMPLE_BREAK_MODE_CHOICES, SHORTEN_METHOD_CHOICES + +logger = logging.getLogger(__name__) + + +@dataclass +class MaskedLMConfig(FairseqDataclass): + data: str = field( + default=MISSING, + metadata={ + "help": "colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + mask_prob: float = field( + default=0.15, + metadata={"help": "probability of replacing a token with mask"}, + ) + leave_unmasked_prob: float = field( + default=0.1, + metadata={"help": "probability that a masked token is unmasked"}, + ) + random_token_prob: float = field( + default=0.1, + metadata={"help": "probability of replacing a token with a random token"}, + ) + freq_weighted_replacement: bool = field( + default=False, + metadata={"help": "sample random replacement words based on word frequencies"}, + ) + mask_whole_words: bool = field( + default=False, + metadata={"help": "mask whole words; you may also want to set --bpe"}, + ) + mask_multiple_length: int = field( + default=1, + metadata={"help": "repeat the mask indices multiple times"}, + ) + mask_stdev: float = field( + default=0.0, + metadata={"help": "stdev of the mask length"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + + include_target_tokens: bool = field( + default=False, + metadata={ + "help": "include target tokens in model input. this is used for data2vec" + }, + ) + include_index: bool = field( + default=True, + metadata={"help": "include index in model input. this is used for data2vec"}, + ) + skip_masking: bool = field( + default=False, + metadata={"help": "skip masking at dataset"}, + ) + # subsample_train: float = field( + # default=1, + # metadata={"help": "shorten training set for debugging"}, + # ) + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) + + +@register_task("masked_lm", dataclass=MaskedLMConfig) +class MaskedLMTask(FairseqTask): + + cfg: MaskedLMConfig + + """Task for training masked language models (e.g., BERT, RoBERTa).""" + + def __init__(self, cfg: MaskedLMConfig, dictionary=None): + super().__init__(cfg) + self.dictionary = dictionary or self.load_dict(cfg) + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + + @classmethod + def setup_task(cls, cfg: MaskedLMConfig, **kwargs): + dictionary = cls.load_dict(cfg) + return cls(cfg, dictionary) + + @classmethod + def load_dict(cls, cfg): + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return dictionary + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + return PrependTokenDataset(dataset, self.source_dictionary.bos()) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + # create masked input and targets + mask_whole_words = ( + get_whole_word_mask(self.args, self.source_dictionary) + if self.cfg.mask_whole_words + else None + ) + + src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( + dataset, + self.source_dictionary, + pad_idx=self.source_dictionary.pad(), + mask_idx=self.mask_idx, + seed=self.cfg.seed, + mask_prob=self.cfg.mask_prob, + leave_unmasked_prob=self.cfg.leave_unmasked_prob, + random_token_prob=self.cfg.random_token_prob, + freq_weighted_replacement=self.cfg.freq_weighted_replacement, + mask_whole_words=mask_whole_words, + mask_multiple_length=self.cfg.mask_multiple_length, + mask_stdev=self.cfg.mask_stdev, + skip_masking=self.cfg.skip_masking, + ) + + with data_utils.numpy_seed(self.cfg.seed): + shuffle = np.random.permutation(len(src_dataset)) + + target_dataset = RightPadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + ) + + if self.cfg.d2v2_multi: + dataset = self._d2v2_multi_dataset(src_dataset) + else: + dataset = self._regular_dataset(src_dataset, target_dataset) + + self.datasets[split] = SortDataset( + dataset, sort_order=[shuffle, src_dataset.sizes] + ) + + def _regular_dataset(self, src_dataset, target_dataset): + input_dict = { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + } + if self.cfg.include_target_tokens: + input_dict["target_tokens"] = target_dataset + if self.cfg.include_index: + input_dict["src_id"] = IdDataset() + + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "target": target_dataset, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], + ) + return dataset + + def _d2v2_multi_dataset(self, src_dataset): + input_dict = { + "source": RightPadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_dataset), + } + + dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": input_dict, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + }, + sizes=[src_dataset.sizes], + ) + return dataset + + def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): + src_dataset = RightPadDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + self.cfg.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + pad_idx=self.source_dictionary.pad(), + ) + src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) + src_dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + }, + sizes=src_lengths, + ) + if sort: + src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) + return src_dataset + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + def begin_epoch(self, epoch, model): + model.set_epoch(epoch) + + def max_positions(self): + return self.cfg.tokens_per_sample diff --git a/fairseq/tasks/multilingual_denoising.py b/fairseq/tasks/multilingual_denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5ee345542940fedc899bdaa2947994b0b59554 --- /dev/null +++ b/fairseq/tasks/multilingual_denoising.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from omegaconf import II + +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + DenoisingDataset, + Dictionary, + PrependTokenDataset, + ResamplingDataset, + SortDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.tasks import register_task + +from .denoising import DenoisingConfig, DenoisingTask + +logger = logging.getLogger(__name__) + + +@dataclass +class MultilingualDenoisingConfig(DenoisingConfig): + multilang_sampling_alpha: float = field( + default=1.0, + metadata={"help": "smoothing alpha for sample ratios across multiple datasets"}, + ) + add_lang_token: bool = field( + default=False, + metadata={"help": ""}, + ) + langs: Optional[str] = field( + default=None, + metadata={"help": "language ids we are considering"}, + ) + no_whole_word_mask_langs: str = field( + default="", + metadata={ + "help": "languages without spacing between words don't support whole word masking" + }, + ) + train_subset: str = II("common.train_subset") + valid_subset: str = II("common.valid_subset") + + +@register_task("multilingual_denoising", dataclass=MultilingualDenoisingConfig) +class MultilingualDenoisingTask(DenoisingTask): + + cfg: MultilingualDenoisingConfig + + @classmethod + def setup_task(cls, cfg: MultilingualDenoisingConfig, **kwargs): + """Setup the task.""" + paths = cfg.data.split(":") + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + + data_path = paths[0] + if cfg.langs is None: + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) + else: + languages = cfg.langs.split(",") + + if cfg.add_lang_token: + for lang in languages: + dictionary.add_symbol("[{}]".format(lang)) + + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle_instance"): + cfg.shuffle_instance = False + return cls(cfg, dictionary) + + def __init__(self, cfg: MultilingualDenoisingConfig, dictionary): + super().__init__(cfg, dictionary) + self.dictionary = dictionary + + # add mask token + self.mask_idx = self.dictionary.add_symbol("") + self.cfg = cfg + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling probability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.cfg.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = self.cfg.data.split(":") + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + if self.cfg.langs is None: + languages = sorted( + [ + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ] + ) + else: + languages = self.cfg.langs.split(",") + for name in languages: + p = os.path.join(data_path, name) + assert os.path.exists(p), "data not found: {}".format(p) + + logger.info("Training on {0} languages: {1}".format(len(languages), languages)) + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} + ) + + mask_whole_words = get_whole_word_mask(self.cfg.bpe, self.dictionary) + language_without_segmentations = self.cfg.no_whole_word_mask_langs.split(",") + lang_datasets = [] + for language in languages: + split_path = os.path.join(data_path, language, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + end_token = ( + self.source_dictionary.index("[{}]".format(language)) + if self.cfg.add_lang_token + else self.source_dictionary.eos() + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, # one less for + pad=self.source_dictionary.pad(), + eos=end_token, + break_mode=self.cfg.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, end_token) + + lang_mask_whole_words = ( + mask_whole_words + if language not in language_without_segmentations + else None + ) + lang_dataset = DenoisingDataset( + dataset, + dataset.sizes, + self.dictionary, + self.mask_idx, + lang_mask_whole_words, + shuffle=self.cfg.shuffle_instance, + seed=self.cfg.seed, + mask=self.cfg.mask, + mask_random=self.cfg.mask_random, + insert=self.cfg.insert, + rotate=self.cfg.rotate, + permute_sentences=self.cfg.permute_sentences, + bpe=self.cfg.bpe, + replace_length=self.cfg.replace_length, + mask_length=self.cfg.mask_length, + poisson_lambda=self.cfg.poisson_lambda, + eos=None + if not self.cfg.add_lang_token + else self.source_dictionary.index("[{}]".format(language)), + ) + lang_datasets.append(lang_dataset) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + int(dataset_lengths.sum()), + ) + ) + if split == self.cfg.train_subset: + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(dataset_lengths) + logger.info( + "Sample probability by language: {}".format( + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + } + ) + ) + size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths + logger.info( + "Up/Down Sampling ratio by language: {}".format( + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + } + ) + ) + + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.cfg.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + dataset = ConcatDataset( + resampled_lang_datasets, + ) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + if split in self.cfg.valid_subset: + self.cfg.valid_subset = self.cfg.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.cfg.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) diff --git a/fairseq/tasks/multilingual_language_modeling.py b/fairseq/tasks/multilingual_language_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd5e5954d6692e5772dbda63f57f5c9669a575c --- /dev/null +++ b/fairseq/tasks/multilingual_language_modeling.py @@ -0,0 +1,627 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from omegaconf import II + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + ResamplingDataset, + SortDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +def lang_token(lang): + return f"<{lang}>" + + +@dataclass +class MultilingualLanguageModelingConfig(FairseqDataclass): + # TODO common var add to parent + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + self_target: bool = field(default=False, metadata={"help": "include self target"}) + future_target: bool = field( + default=False, metadata={"help": "include future target"} + ) + past_target: bool = field(default=False, metadata={"help": "include past target"}) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend lang id token "} + ) + max_source_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + pad_to_fixed_length: Optional[bool] = field( + default=False, metadata={"help": "pad to fixed length"} + ) + pad_to_fixed_bsz: Optional[bool] = field( + default=False, metadata={"help": "boolean to pad to fixed batch size"} + ) + + multilang_sampling_alpha: Optional[float] = field( + default=1.0, + metadata={ + "help": "smoothing alpha for sample rations across multiple datasets" + }, + ) + + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + + langs: str = field( + default="", + metadata={ + "help": "comma-separated list of languages (default: all directories in data path)" + }, + ) + baseline_model_langs: str = field( + default="", + metadata={ + "help": "comma-separated list of languages in the baseline model (default: none)" + }, + ) + # TODO: legacy parameter kept for compatibility + baseline_model: str = field( + default="", + metadata={"help": "path to the baseline model (default: none)"}, + ) + + lang_to_offline_shard_ratio: str = field( + default="", + metadata={ + "help": "absolute path of tsv file location to indicate lang to offline shard ratio.", + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + batch_size: Optional[int] = II("dataset.batch_size") + batch_size_valid: Optional[int] = II("dataset.batch_size_valid") + train_subset: str = II("common.train_subset") + valid_subset: str = II("common.valid_subset") + + +@register_task( + "multilingual_language_modeling", dataclass=MultilingualLanguageModelingConfig +) +class MultilingualLanguageModelingTask(LegacyFairseqTask): + """ + Train a language model. + + Args: + dictionary (~fairseq.data.Dictionary): the dictionary for the input of + the language model + output_dictionary (~fairseq.data.Dictionary): the dictionary for the + output of the language model. In most cases it will be the same as + *dictionary*, but could possibly be a more limited version of the + dictionary (if ``--output-dictionary-size`` is used). + targets (List[str]): list of the target types that the language model + should predict. Can be one of "self", "future", and "past". + Defaults to "future". + + .. note:: + + The language modeling task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate`, :mod:`fairseq-interactive` and + :mod:`fairseq-eval-lm`. + + The language modeling task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.language_modeling_parser + :prog: + """ + + def __init__(self, args, dictionary, output_dictionary=None, targets=None): + super().__init__(args) + self.dictionary = dictionary + self.output_dictionary = output_dictionary or dictionary + + if targets is None: + targets = ["future"] + self.targets = targets + + @staticmethod + def _get_langs(args, epoch=1): + paths = utils.split_paths(args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + languages = sorted( + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ) + if args.langs: + keep_langs = set(args.langs.split(",")) + languages = [lang for lang in languages if lang in keep_langs] + assert len(languages) == len(keep_langs) + + return languages, data_path + + @classmethod + def setup_dictionary(cls, args, **kwargs): + dictionary = None + output_dictionary = None + if args.data: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + if args.add_bos_token: + languages, _ = cls._get_langs(args) + logger.info("----------------") + for lang in languages: + dictionary.add_symbol(lang_token(lang)) + logger.info(f"add language token: {lang_token(lang)}") + logger.info("----------------") + + logger.info("dictionary: {} types".format(len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + return (dictionary, output_dictionary) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dictionary, output_dictionary = cls.setup_dictionary(args, **kwargs) + + # upgrade old checkpoints + if hasattr(args, "exclude_self_target"): + args.self_target = not args.exclude_self_target + + targets = [] + if getattr(args, "self_target", False): + targets.append("self") + if getattr(args, "future_target", False): + targets.append("future") + if getattr(args, "past_target", False): + targets.append("past") + if len(targets) == 0: + # standard language modeling + targets = ["future"] + + return cls(args, dictionary, output_dictionary, targets=targets) + + def build_model(self, args, from_checkpoint=False): + model = super().build_model(args, from_checkpoint) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError( + f"Unsupported language modeling target: {target} not in {model.supported_targets}" + ) + + return model + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split: str, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + languages, data_path = MultilingualLanguageModelingTask._get_langs( + self.args, epoch + ) + lang_to_offline_shard_ratio = None + if self.args.lang_to_offline_shard_ratio != "": + lang_to_offline_shard_ratio = {} + assert os.path.exists( + self.args.lang_to_offline_shard_ratio + ), "provided offline shard ratio file doesn't exist: {0}".format( + self.args.lang_to_offline_shard_ratio + ) + with open(self.args.lang_to_offline_shard_ratio) as fin: + for line in fin: + lang, ratio = line.strip().split("\t") + ratio = float(ratio) + lang_to_offline_shard_ratio[lang] = ratio + + logger.info( + "Found offline sharded ratio: %s", + lang_to_offline_shard_ratio, + ) + + if split == self.args.train_subset: + logger.info( + "Training on {0} languages: {1}".format(len(languages), languages) + ) + else: + logger.info( + "Evaluating on {0} languages: {1}".format(len(languages), languages) + ) + + tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token) + + fixed_pad_length = None + if self.args.pad_to_fixed_length: + fixed_pad_length = self.args.tokens_per_sample + + pad_to_bsz = None + if self.args.pad_to_fixed_bsz: + pad_to_bsz = ( + self.args.batch_size_valid if "valid" in split else self.args.batch_size + ) + + lang_datasets = [] + for lang_id, language in enumerate(languages): + split_path = os.path.join(data_path, language, split) + dataset = data_utils.load_indexed_dataset( + split_path, self.dictionary, self.args.dataset_impl, combine=combine + ) + # print('len(dataset) =', len(dataset)) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + tokens_per_sample, + self.args.seed, + ) + + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + tokens_per_sample, + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + src_lang_idx, tgt_lang_idx = None, None + if self.args.add_bos_token: + src_lang_idx = self.dictionary.index(lang_token(language)) + tgt_lang_idx = self.output_dictionary.index(lang_token(language)) + + lang_datasets.append( + MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=self.dictionary, + tgt_vocab=self.output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=True, + targets=self.targets, + fixed_pad_length=fixed_pad_length, + pad_to_bsz=pad_to_bsz, + add_bos_token=self.args.add_bos_token, + src_lang_idx=src_lang_idx, + tgt_lang_idx=tgt_lang_idx, + ) + ) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + dataset_lengths.sum(), + ) + ) + if split == self.args.train_subset: + dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths)) + if lang_to_offline_shard_ratio is not None: + dataset_lengths_ratio_multiplier = [] + for lang in languages: + assert ( + lang in lang_to_offline_shard_ratio + ), "Lang: {0} missing in offline shard ratio file: {1}".format( + lang, + self.args.lang_to_offline_shard_ratio, + ) + dataset_lengths_ratio_multiplier.append( + lang_to_offline_shard_ratio[lang] + ) + dataset_lengths_ratio_multiplier = np.array( + dataset_lengths_ratio_multiplier + ) + true_dataset_lengths = ( + dataset_lengths * dataset_lengths_ratio_multiplier + ) + else: + true_dataset_lengths = dataset_lengths + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(true_dataset_lengths) + + logger.info( + "Sample probability by language: %s", + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + }, + ) + size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths + # TODO: add an option for shrinking all size ratios to below 1 + # if self.args.multilang_sampling_alpha != 1: + # size_ratio /= size_ratio.max() + + # Fix numeric errors in size ratio computation + # 0.999999999999999999 -> 1 + # 1.000000000000000002 -> 1 + for i in range(len(size_ratio)): + size_ratio[i] = round(size_ratio[i], 8) + + logger.info( + "Up/Down Sampling ratio by language: %s", + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + }, + ) + logger.info( + "Actual dataset size by language: %s", + { + lang: "{0:.2f}".format(len(lang_datasets[id])) + for id, lang in enumerate(languages) + }, + ) + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] > 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + logger.info( + "Resampled dataset size by language: %s", + { + lang: "{0:.2f}".format(len(resampled_lang_datasets[id])) + for id, lang in enumerate(languages) + }, + ) + dataset = ConcatDataset(resampled_lang_datasets) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + # [TODO]: This is hacky for now to print validation ppl for each + # language individually. Maybe need task API changes to allow it + # in more generic ways. + if split in self.args.valid_subset: + self.args.valid_subset = self.args.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.args.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) + + def build_dataset_for_inference( + self, src_tokens, src_lengths, language="en_XX", **kwargs + ): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. + """ + dataset = StripTokenDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionary.eos(), + ) + + src_lang_idx = self.dictionary.index(lang_token(language)) + src_dataset = PrependTokenDataset( + dataset, + token=( + (src_lang_idx or self.source_dictionary.bos()) + if getattr(self.args, "add_bos_token", False) + else self.source_dictionary.eos() + ), + ) + + max_seq_len = max(src_lengths) + 1 + tgt_dataset = AppendTokenDataset(dataset, token=self.source_dictionary.pad()) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + pad_length=max_seq_len, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + pad_length=max_seq_len, + ), + }, + sizes=[np.array(src_lengths)], + ) + + @torch.no_grad() + def inference_step( + self, + generator, + models, + sample, + language="en_XX", + prefix_tokens=None, + constraints=None, + ): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + src_lang_idx = self.dictionary.index(lang_token(language)) + bos_token = src_lang_idx or self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the language_modeling task is not supported" + ) + + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + if prefix_tokens[:, 0].eq(bos_token).all(): + prefix_tokens = prefix_tokens[:, 1:] + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dictionary diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..156d085aa4e65292d28f0d5c8e3fefd0e7d60e17 --- /dev/null +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import numpy as np +import torch + +from fairseq import utils +from fairseq.data import ( + ConcatDataset, + Dictionary, + IdDataset, + MaskTokensDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PadDataset, + PrependTokenDataset, + RawLabelDataset, + ResamplingDataset, + SortDataset, + TokenBlockDataset, + data_utils, + encoders, +) +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("multilingual_masked_lm") +class MultiLingualMaskedLMTask(LegacyFairseqTask): + """Task for training masked language models (e.g., BERT, RoBERTa).""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--sample-break-mode", + default="complete", + choices=["none", "complete", "complete_doc", "eos"], + help='If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.', + ) + parser.add_argument( + "--tokens-per-sample", + default=512, + type=int, + help="max number of total tokens over all segments " + "per sample for BERT dataset", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.1, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.1, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--freq-weighted-replacement", + action="store_true", + help="sample random replacement words based on word frequencies", + ) + parser.add_argument( + "--mask-whole-words", + default=False, + action="store_true", + help="mask whole words; you may also want to set --bpe", + ) + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=1.0, + help="smoothing alpha for sample rations across multiple datasets", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + + # add mask token + self.mask_idx = dictionary.add_symbol("") + + @classmethod + def setup_task(cls, args, **kwargs): + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def _get_whole_word_mask(self): + # create masked input and targets + if self.args.mask_whole_words: + bpe = encoders.build_bpe(self.args) + if bpe is not None: + + def is_beginning_of_word(i): + if i < self.source_dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = self.source_dictionary[i] + if tok.startswith("madeupword"): + return True + try: + return bpe.is_beginning_of_word(tok) + except ValueError: + return True + + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(self.source_dictionary)))) + ) + else: + mask_whole_words = None + return mask_whole_words + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob**self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + languages = sorted( + name + for name in os.listdir(data_path) + if os.path.isdir(os.path.join(data_path, name)) + ) + + logger.info("Training on {0} languages: {1}".format(len(languages), languages)) + logger.info( + "Language to id mapping: ", {lang: id for id, lang in enumerate(languages)} + ) + + mask_whole_words = self._get_whole_word_mask() + lang_datasets = [] + for lang_id, language in enumerate(languages): + split_path = os.path.join(data_path, language, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.args.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode=self.args.sample_break_mode, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + + src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( + dataset, + self.source_dictionary, + pad_idx=self.source_dictionary.pad(), + mask_idx=self.mask_idx, + seed=self.args.seed, + mask_prob=self.args.mask_prob, + leave_unmasked_prob=self.args.leave_unmasked_prob, + random_token_prob=self.args.random_token_prob, + freq_weighted_replacement=self.args.freq_weighted_replacement, + mask_whole_words=mask_whole_words, + ) + + lang_dataset = NestedDictionaryDataset( + { + "net_input": { + "src_tokens": PadDataset( + src_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + "target": PadDataset( + tgt_dataset, + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_dataset, reduce=True), + "lang_id": RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]), + }, + sizes=[src_dataset.sizes], + ) + lang_datasets.append(lang_dataset) + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + logger.info( + "loaded total {} blocks for all languages".format( + dataset_lengths.sum(), + ) + ) + if split == self.args.train_subset: + # For train subset, additionally up or down sample languages. + sample_probs = self._get_sample_prob(dataset_lengths) + logger.info( + "Sample probability by language: ", + { + lang: "{0:.4f}".format(sample_probs[id]) + for id, lang in enumerate(languages) + }, + ) + size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths + logger.info( + "Up/Down Sampling ratio by language: ", + { + lang: "{0:.2f}".format(size_ratio[id]) + for id, lang in enumerate(languages) + }, + ) + + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + dataset = ConcatDataset(resampled_lang_datasets) + else: + dataset = ConcatDataset(lang_datasets) + lang_splits = [split] + for lang_id, lang_dataset in enumerate(lang_datasets): + split_name = split + "_" + languages[lang_id] + lang_splits.append(split_name) + self.datasets[split_name] = lang_dataset + + # [TODO]: This is hacky for now to print validation ppl for each + # language individually. Maybe need task API changes to allow it + # in more generic ways. + if split in self.args.valid_subset: + self.args.valid_subset = self.args.valid_subset.replace( + split, ",".join(lang_splits) + ) + + with data_utils.numpy_seed(self.args.seed + epoch): + shuffle = np.random.permutation(len(dataset)) + + self.datasets[split] = SortDataset( + dataset, + sort_order=[ + shuffle, + dataset.sizes, + ], + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True): + src_dataset = PadDataset( + TokenBlockDataset( + src_tokens, + src_lengths, + self.args.tokens_per_sample - 1, # one less for + pad=self.source_dictionary.pad(), + eos=self.source_dictionary.eos(), + break_mode="eos", + ), + pad_idx=self.source_dictionary.pad(), + left_pad=False, + ) + src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos()) + src_dataset = NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + }, + }, + sizes=src_lengths, + ) + if sort: + src_dataset = SortDataset(src_dataset, sort_order=[src_lengths]) + return src_dataset + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..cef7656691aeb8042e42d24eccd328a85fe32e27 --- /dev/null +++ b/fairseq/tasks/multilingual_translation.py @@ -0,0 +1,463 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +from collections import OrderedDict +from argparse import ArgumentError + +import torch +from fairseq import options, utils +from fairseq.logging import metrics +from fairseq.data import ( + Dictionary, + LanguagePairDataset, + RoundRobinZipDatasets, + TransformEosLangPairDataset, +) +from fairseq.models import FairseqMultiModel +from fairseq.tasks.translation import load_langpair_dataset + +from . import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +def _lang_token(lang: str): + return "__{}__".format(lang) + + +def _lang_token_index(dic: Dictionary, lang: str): + """Return language token index.""" + idx = dic.index(_lang_token(lang)) + assert idx != dic.unk_index, "cannot find language token for lang {}".format(lang) + return idx + + +@register_task("multilingual_translation") +class MultilingualTranslationTask(LegacyFairseqTask): + """A task for training multiple translation models simultaneously. + + We iterate round-robin over batches from multiple language pairs, ordered + according to the `--lang-pairs` argument. + + The training loop is roughly: + + for i in range(len(epoch)): + for lang_pair in args.lang_pairs: + batch = next_batch_for_lang_pair(lang_pair) + loss = criterion(model_for_lang_pair(lang_pair), batch) + loss.backward() + optimizer.step() + + In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset + (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that + implements the `FairseqMultiModel` interface. + + During inference it is required to specify a single `--source-lang` and + `--target-lang`, which indicates the inference langauge direction. + `--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to + the same value as training. + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument('data', metavar='DIR', help='path to data directory') + parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr') + parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', + help='source language (only needed for inference)') + parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', + help='target language (only needed for inference)') + parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', + help='pad the source on the left (default: True)') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left (default: False)') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'], + metavar='SRCTGT', + help='replace beginning-of-sentence in source sentence with source or target ' + 'language token. (src/tgt)') + parser.add_argument('--decoder-langtok', action='store_true', + help='replace beginning-of-sentence in target sentence with target language token') + # fmt: on + + def __init__(self, args, dicts, training): + super().__init__(args) + self.dicts = dicts + self.training = training + if training: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] + # eval_lang_pairs for multilingual translation is usually all of the + # lang_pairs. However for other multitask settings or when we want to + # optimize for certain languages we want to use a different subset. Thus + # the eval_lang_pairs class variable is provided for classes that extend + # this class. + self.eval_lang_pairs = self.lang_pairs + # model_lang_pairs will be used to build encoder-decoder model pairs in + # models.build_model(). This allows multitask type of sub-class can + # build models other than the input lang_pairs + self.model_lang_pairs = self.lang_pairs + self.langs = list(dicts.keys()) + + @classmethod + def setup_task(cls, args, **kwargs): + dicts, training = cls.prepare(args, **kwargs) + return cls(args, dicts, training) + + @classmethod + def update_args(cls, args): + args.left_pad_source = utils.eval_bool(args.left_pad_source) + args.left_pad_target = utils.eval_bool(args.left_pad_target) + + if args.lang_pairs is None: + raise ValueError( + "--lang-pairs is required. List all the language pairs in the training objective." + ) + if isinstance(args.lang_pairs, str): + args.lang_pairs = args.lang_pairs.split(",") + + @classmethod + def prepare(cls, args, **kargs): + cls.update_args(args) + sorted_langs = sorted( + list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) + ) + if args.source_lang is not None or args.target_lang is not None: + training = False + else: + training = True + + # load dictionaries + dicts = OrderedDict() + for lang in sorted_langs: + paths = utils.split_paths(args.data) + assert len(paths) > 0 + dicts[lang] = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(lang)) + ) + if len(dicts) > 0: + assert dicts[lang].pad() == dicts[sorted_langs[0]].pad() + assert dicts[lang].eos() == dicts[sorted_langs[0]].eos() + assert dicts[lang].unk() == dicts[sorted_langs[0]].unk() + if args.encoder_langtok is not None or args.decoder_langtok: + for lang_to_add in sorted_langs: + dicts[lang].add_symbol(_lang_token(lang_to_add)) + logger.info("[{}] dictionary: {} types".format(lang, len(dicts[lang]))) + return dicts, training + + def get_encoder_langtok(self, src_lang, tgt_lang): + if self.args.encoder_langtok is None: + return self.dicts[src_lang].eos() + if self.args.encoder_langtok == "src": + return _lang_token_index(self.dicts[src_lang], src_lang) + else: + return _lang_token_index(self.dicts[src_lang], tgt_lang) + + def get_decoder_langtok(self, tgt_lang): + if not self.args.decoder_langtok: + return self.dicts[tgt_lang].eos() + return _lang_token_index(self.dicts[tgt_lang], tgt_lang) + + def alter_dataset_langtok( + self, + lang_pair_dataset, + src_eos=None, + src_lang=None, + tgt_eos=None, + tgt_lang=None, + ): + if self.args.encoder_langtok is None and not self.args.decoder_langtok: + return lang_pair_dataset + + new_src_eos = None + if ( + self.args.encoder_langtok is not None + and src_eos is not None + and src_lang is not None + and tgt_lang is not None + ): + new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang) + else: + src_eos = None + + new_tgt_bos = None + if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None: + new_tgt_bos = self.get_decoder_langtok(tgt_lang) + else: + tgt_eos = None + + return TransformEosLangPairDataset( + lang_pair_dataset, + src_eos=src_eos, + new_src_eos=new_src_eos, + tgt_bos=tgt_eos, + new_tgt_bos=new_tgt_bos, + ) + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + def language_pair_dataset(lang_pair): + src, tgt = lang_pair.split("-") + langpair_dataset = load_langpair_dataset( + data_path, + split, + src, + self.dicts[src], + tgt, + self.dicts[tgt], + combine=True, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + ) + return self.alter_dataset_langtok( + langpair_dataset, + src_eos=self.dicts[src].eos(), + src_lang=src, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ) + + self.datasets[split] = RoundRobinZipDatasets( + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in self.lang_pairs + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) + + lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang) + return RoundRobinZipDatasets( + OrderedDict( + [ + ( + lang_pair, + self.alter_dataset_langtok( + LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary + ), + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + ), + ) + ] + ), + eval_key=lang_pair, + ) + + def build_model(self, args, from_checkpoint=False): + def check_args(): + messages = [] + if ( + len(set(self.args.lang_pairs).symmetric_difference(args.lang_pairs)) + != 0 + ): + messages.append( + "--lang-pairs should include all the language pairs {}.".format( + args.lang_pairs + ) + ) + if self.args.encoder_langtok != args.encoder_langtok: + messages.append( + "--encoder-langtok should be {}.".format(args.encoder_langtok) + ) + if self.args.decoder_langtok != args.decoder_langtok: + messages.append( + "--decoder-langtok should {} be set.".format( + "" if args.decoder_langtok else "not" + ) + ) + + if len(messages) > 0: + raise ValueError(" ".join(messages)) + + # Update args -> the fact that the constructor here + # changes the args object doesn't mean you get the same one here + self.update_args(args) + + # Check if task args are consistant with model args + check_args() + + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + if not isinstance(model, FairseqMultiModel): + raise ValueError( + "MultilingualTranslationTask requires a FairseqMultiModel architecture" + ) + return model + + def _per_lang_pair_train_loss( + self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad + ): + loss, sample_size, logging_output = criterion( + model.models[lang_pair], sample[lang_pair] + ) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + from collections import defaultdict + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) + curr_lang_pairs = [ + lang_pair + for lang_pair in self.model_lang_pairs + if sample[lang_pair] is not None and len(sample[lang_pair]) != 0 + ] + + for idx, lang_pair in enumerate(curr_lang_pairs): + + def maybe_no_sync(): + if ( + self.args.distributed_world_size > 1 + and hasattr(model, "no_sync") + and idx < len(curr_lang_pairs) - 1 + ): + return model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + + with maybe_no_sync(): + loss, sample_size, logging_output = self._per_lang_pair_train_loss( + lang_pair, + model, + update_num, + criterion, + sample, + optimizer, + ignore_grad, + ) + agg_loss += loss.detach().item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] + return agg_loss, agg_sample_size, agg_logging_output + + def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample): + return criterion(model.models[lang_pair], sample[lang_pair]) + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + from collections import defaultdict + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, defaultdict(float) + for lang_pair in self.eval_lang_pairs: + if ( + lang_pair not in sample + or sample[lang_pair] is None + or len(sample[lang_pair]) == 0 + ): + continue + loss, sample_size, logging_output = self._per_lang_pair_valid_loss( + lang_pair, model, criterion, sample + ) + agg_loss += loss.data.item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[f"{lang_pair}:{k}"] += logging_output[k] + return agg_loss, agg_sample_size, agg_logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if self.args.decoder_langtok: + bos_token = _lang_token_index( + self.target_dictionary, self.args.target_lang + ) + else: + bos_token = self.target_dictionary.eos() + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + bos_token=bos_token, + ) + + def reduce_metrics(self, logging_outputs, criterion): + with metrics.aggregate(): + # pass 'sample_size', 'nsentences', 'ntokens' stats to fairseq_task + super().reduce_metrics(logging_outputs, criterion) + for k in ["sample_size", "nsentences", "ntokens"]: + metrics.log_scalar(k, sum(l[k] for l in logging_outputs)) + + @property + def source_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.source_lang] + + @property + def target_dictionary(self): + if self.training: + return next(iter(self.dicts.values())) + else: + return self.dicts[self.args.target_lang] + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + if len(self.datasets.values()) == 0: + return { + "%s-%s" + % (self.args.source_lang, self.args.target_lang): ( + self.args.max_source_positions, + self.args.max_target_positions, + ) + } + return OrderedDict( + [ + (key, (self.args.max_source_positions, self.args.max_target_positions)) + for split in self.datasets.keys() + for key in self.datasets[split].datasets.keys() + ] + ) diff --git a/fairseq/tasks/nlu_finetuning.py b/fairseq/tasks/nlu_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..a335021335a417aaaf6e6a3b3a02f525ed933a46 --- /dev/null +++ b/fairseq/tasks/nlu_finetuning.py @@ -0,0 +1,477 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import torch +import json + +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any + +from fairseq.data import AddTargetDataset, Dictionary, encoders +from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +from . import register_task +from .. import utils +from ..logging import metrics + + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary): + self.dictionary = dictionary + + def __call__(self, label): + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False + ) + + +def label_len_fn(label): + return len(label.split(" ")) + + +@dataclass +class NLUFinetuningConfig(AudioPretrainingConfig): + # Options for reporting WER metrics during validation. Only applicable to + # Seq2Seq models during fine-tuning + eval_wer: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_parse: bool = field( + default=False, metadata={"help": "compute WER for Seq2Seq models"} + ) + eval_wer_config: GenerationConfig = field( + default_factory=lambda: GenerationConfig(), + metadata={"help": "beam search config for evaluating wer during training"}, + ) + eval_wer_tokenizer: Any = field( + default=None, + metadata={"help": "tokenizer config for evaluating wer during training"}, + ) + eval_wer_post_process: str = field( + default="letter", + metadata={ + "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)" + }, + ) + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_detok: Optional[str] = field( + default=None, + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); " + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: str = field( + default="{}", metadata={"help": "args for building the tokenizer, if needed"} + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, metadata={"help": "remove BPE before computing BLEU"} + ) + eval_bleu_args: str = field( + default="{}", + metadata={ + "help": "generation args for BLUE scoring, e.g., " + '\'{"beam": 4, "lenpen": 0.6}\'' + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + autoregressive: bool = field( + default=False, + metadata={ + "help": "required for autoregressive decoders (like seq2seq models); " + "adds 'prev_output_tokens' to input and appends eos to target" + }, + ) + + +@register_task("nlu_finetuning", dataclass=NLUFinetuningConfig) +class NLUFinetuningTask(AudioPretrainingTask): + """ """ + + cfg: NLUFinetuningConfig + + def __init__( + self, + cfg: NLUFinetuningConfig, + ): + super().__init__(cfg) + self.blank_symbol = "" + + self.state.add_factory("target_dictionary", self.load_target_dictionary) + + def load_target_dictionary(self): + if self.cfg.labels: + dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt") + return Dictionary.load(dict_path) + return None + + def load_dataset(self, split: str, task_cfg: NLUFinetuningConfig = None, **kwargs): + super().load_dataset(split, task_cfg, **kwargs) + + task_cfg = task_cfg or self.cfg + assert task_cfg.labels is not None + text_compression_level = getattr( + TextCompressionLevel, str(self.cfg.text_compression_level) + ) + data_path = self.cfg.data + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") + skipped_indices = getattr(self.datasets[split], "skipped_indices", set()) + text_compressor = TextCompressor(level=text_compression_level) + with open(label_path, "r") as f: + labels = [ + text_compressor.compress(l) + for i, l in enumerate(f) + if i not in skipped_indices + ] + + assert len(labels) == len(self.datasets[split]), ( + f"labels length ({len(labels)}) and dataset length " + f"({len(self.datasets[split])}) do not match" + ) + + process_label = LabelEncoder(self.target_dictionary) + + self.datasets[split] = AddTargetDataset( + self.datasets[split], + labels, + pad=self.target_dictionary.pad(), + eos=self.target_dictionary.eos(), + batch_targets=True, + process_label=process_label, + label_len_fn=label_len_fn, + add_to_input=task_cfg.get("autoregressive", False), + text_compression_level=text_compression_level, + ) + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.state.target_dictionary + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_wer_parse and self.cfg.autoregressive: + metrics = self._inference_with_wer_parse( + self.sequence_generator, sample, model + ) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + logging_output["_num_em_errors"] = metrics["num_em_errors"] + logging_output["_num_ems"] = metrics["num_ems"] + logging_output["_num_tree_errors"] = metrics["num_tree_errors"] + logging_output["_num_trees"] = metrics["num_trees"] + if self.cfg.eval_wer and self.cfg.autoregressive: + metrics = self._inference_with_wer(self.sequence_generator, sample, model) + logging_output["_num_char_errors"] = metrics["num_char_errors"] + logging_output["_num_chars"] = metrics["num_chars"] + logging_output["_num_word_errors"] = metrics["num_word_errors"] + logging_output["_num_words"] = metrics["num_words"] + if self.cfg.eval_bleu and self.cfg.autoregressive: + metrics = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = metrics.sys_len + logging_output["_bleu_ref_len"] = metrics.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(metrics.counts) == 4 + for i in range(4): + logging_output[f"_bleu_counts_{i}"] = metrics.counts[i] + logging_output[f"_bleu_totals_{i}"] = metrics.totals[i] + return loss, sample_size, logging_output + + def build_model(self, model_cfg: FairseqDataclass): + model = super().build_model(model_cfg) + + if (self.cfg.eval_wer or self.cfg.eval_wer_parse) and self.cfg.autoregressive: + self.sequence_generator = self.build_generator( + [model], + self.cfg.eval_wer_config, + ) + if self.cfg.eval_wer_tokenizer: + self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None + if self.cfg.eval_bleu and self.cfg.autoregressive: + assert self.cfg.eval_bleu_detok is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + gen_args = json.loads(self.cfg.eval_bleu_args) + gen_args = Namespace(**gen_args) + self.sequence_generator = self.build_generator([model], gen_args) + + return model + + def _inference_with_wer_parse(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + def decode_to_list(toks): + def token_string(i): + if i == self.target_dictionary.unk(): + return self.target_dictionary.unk_string(False) + else: + return self.target_dictionary[i] + + return [token_string(i) for i in toks] + + def is_ont_token(token): + return "[" in token or "]" in token + + def post_process(l): + o = [] + for w in l: + if w == self.target_dictionary.eos_word or w == "|": + continue + if w == "_": + o.append(" ") + else: + o.append(w) + if is_ont_token(w): + o.append(" ") + return o + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + num_em_errors, num_ems = 0, 0 + num_tree_errors, num_trees = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp_tokens = gen_out[i][0]["tokens"] + # hyp = decode(hyp_tokens) + ref_tokens = utils.strip_pad( + sample["target"][i], self.target_dictionary.pad() + ) + # ref = decode(ref_tokens) + hyp_list = decode_to_list(hyp_tokens) + ref_list = decode_to_list(ref_tokens) + + hyp_list = post_process(hyp_list) + ref_list = post_process(ref_list) + + hyp = "".join(hyp_list).strip() + ref = "".join(ref_list).strip() + num_chars += len(ref) + num_char_errors += editdistance.eval(hyp, ref) + hyp_words = hyp.split() + ref_words = ref.split() + hyp_tree = [word for word in hyp_list if ("[" in word or "]" in word)] + ref_tree = [word for word in ref_list if ("[" in word or "]" in word)] + # num_word_errors += editdistance.eval(hyp_words, ref_words) + hyp_before = decode(hyp_tokens).split() + ref_before = decode(ref_tokens).split() + + num_word_errors += editdistance.eval(hyp_before, ref_before) + num_words += len(ref_before) + if hyp != ref: + num_em_errors += 1 + if hyp_tree != ref_tree: + num_tree_errors += 1 + num_ems += 1 + num_trees += 1 + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + "num_ems": num_ems, + "num_em_errors": num_em_errors, + "num_trees": num_trees, + "num_tree_errors": num_tree_errors, + } + + def _inference_with_wer(self, generator, sample, model): + import editdistance + + def decode(toks): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_wer_post_process, + escape_unk=True, + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + num_word_errors, num_char_errors = 0, 0 + num_chars, num_words = 0, 0 + gen_out = self.inference_step(generator, [model], sample, None) + for i in range(len(gen_out)): + hyp = decode(gen_out[i][0]["tokens"]) + ref = decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + ) + num_char_errors += editdistance.eval(hyp, ref) + num_chars += len(ref) + hyp_words = hyp.split() + ref_words = ref.split() + num_word_errors += editdistance.eval(hyp_words, ref_words) + num_words += len(ref_words) + + return { + "num_char_errors": num_char_errors, + "num_chars": num_chars, + "num_word_errors": num_word_errors, + "num_words": num_words, + } + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, is_ref): + s = self.target_dictionary.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False)) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), + is_ref=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("H-{} {}".format(sample["id"][0], hyps[0])) + logger.info("T-{} {}".format(sample["id"][0], refs[0])) + + eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a" + return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + if self.cfg.eval_wer or self.cfg.eval_wer_parse: + zero = torch.scalar_tensor(0.0) + num_char_errors = sum( + log.get("_num_char_errors", zero) for log in logging_outputs + ) + num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs) + num_word_errors = sum( + log.get("_num_word_errors", zero) for log in logging_outputs + ) + num_words = sum(log.get("_num_words", zero) for log in logging_outputs) + metrics.log_scalar("_num_char_errors", num_char_errors) + metrics.log_scalar("_num_chars", num_chars) + metrics.log_scalar("_num_word_errors", num_word_errors) + metrics.log_scalar("_num_words", num_words) + if num_chars > 0: + metrics.log_derived( + "uer", + lambda meters: meters["_num_char_errors"].sum + * 100.0 + / meters["_num_chars"].sum + if meters["_num_chars"].sum > 0 + else float("nan"), + ) + if num_words > 0: + metrics.log_derived( + "wer", + lambda meters: meters["_num_word_errors"].sum + * 100.0 + / meters["_num_words"].sum + if meters["_num_words"].sum > 0 + else float("nan"), + ) + if self.cfg.eval_wer_parse: + num_em_errors = sum( + log.get("_num_em_errors", zero) for log in logging_outputs + ) + num_ems = sum(log.get("_num_ems", zero) for log in logging_outputs) + metrics.log_scalar("_num_em_errors", num_em_errors) + metrics.log_scalar("_num_ems", num_ems) + num_tree_errors = sum( + log.get("_num_tree_errors", zero) for log in logging_outputs + ) + num_trees = sum(log.get("_num_trees", zero) for log in logging_outputs) + metrics.log_scalar("_num_tree_errors", num_tree_errors) + metrics.log_scalar("_num_trees", num_trees) + + if num_ems > 0: + metrics.log_derived( + "em_error", + lambda meters: meters["_num_em_errors"].sum + * 100.0 + / meters["_num_ems"].sum + if meters["_num_ems"].sum > 0 + else float("nan"), + ) + if num_trees > 0: + metrics.log_derived( + "tree_error", + lambda meters: meters["_num_tree_errors"].sum + * 100.0 + / meters["_num_trees"].sum + if meters["_num_trees"].sum > 0 + else float("nan"), + ) + + if self.cfg.eval_bleu: + len_keys = ["_bleu_sys_len", "_bleu_ref_len"] + count_keys = [f"_bleu_counts_{i}" for i in range(4)] + total_keys = [f"_bleu_totals_{i}" for i in range(4)] + for k in len_keys + count_keys + total_keys: + metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs)) + + import sacrebleu + + metrics.log_derived( + "bleu", + lambda meters: sacrebleu.compute_bleu( + correct=[meters[k].sum for k in count_keys], + total=[meters[k].sum for k in total_keys], + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + smooth_method="exp", + ).score, + ) diff --git a/fairseq/tasks/online_backtranslation.py b/fairseq/tasks/online_backtranslation.py new file mode 100644 index 0000000000000000000000000000000000000000..da24fe8981cd6a2b6f953b3e6646082c1758b0b5 --- /dev/null +++ b/fairseq/tasks/online_backtranslation.py @@ -0,0 +1,683 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import json +import logging +import math +import os +from argparse import Namespace +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import Dict, Sequence, Tuple +from argparse import ArgumentError + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import fairseq +from fairseq import options, utils +from fairseq.logging import metrics +from fairseq.data import ( + FairseqDataset, + LanguagePairDataset, + NoisingDataset, + PrependTokenDataset, + RoundRobinZipDatasets, + TransformEosLangPairDataset, + data_utils, + encoders, +) +from fairseq.sequence_generator import SequenceGenerator +from fairseq.tasks import register_task +from fairseq.tasks.translation import TranslationTask, load_langpair_dataset + +logger = logging.getLogger(__name__) + + +class PiecewiseLinearFn: + """Piecewise linear function. Can be configured with a string.""" + + def __init__(self, pieces: Sequence[Tuple[int, float]]): + assert pieces == sorted( + pieces + ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}" + + self.pieces = pieces + + def __call__(self, x: int) -> float: + for i, (x_a, y_a) in enumerate(self.pieces[:-1]): + x_b, y_b = self.pieces[i + 1] + if x_a <= x <= x_b: + return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a) + + return self.pieces[-1][1] + + @staticmethod + def from_string(configuration: str) -> "PiecewiseLinearFn": + """ + Parse the configuration of lambda coefficient (for scheduling). + x = "3" # lambda will be a constant equal to x + x = "0:1,1000:0" # lambda will start from 1 and linearly decrease + # to 0 during the first 1000 iterations + x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 + # iterations, then will linearly increase to 1 until iteration 2000 + """ + if isinstance(configuration, float): + return PiecewiseLinearFn([(0, configuration)]) + + try: + parts = configuration.split(",") + if len(parts) == 1: + v = float(configuration) + return PiecewiseLinearFn([(0, v)]) + + split = [s.split(":") for s in parts] + pieces = [(int(t), float(v)) for t, v in split] + return PiecewiseLinearFn(pieces) + except Exception: + raise ValueError( + f"Invalid PiecewiseLinearFn configuration: {configuration!r}" + ) + + @staticmethod + def one() -> "PiecewiseLinearFn": + return PiecewiseLinearFn([(0, 1.0)]) + + +@register_task("online_backtranslation") +class OnlineBackTranslationTask(TranslationTask): + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + # Generic translation args + parser.add_argument('data', help='colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner; \ + however, valid and test data are always in the first directory to \ + avoid the need for repeating them in all directories') + parser.add_argument('--mono-langs', metavar='MONO_LANGS', + help='monolingual languages for training') + parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS', + help='language pairs for validation') + parser.add_argument('--load-alignments', action='store_true', + help='load the binarized alignments') + parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL', + help='pad the source on the left') + parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', + help='pad the target on the left') + parser.add_argument('--upsample-primary', default=1, type=int, + help='amount to upsample primary dataset') + try: + parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence') + parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence') + except ArgumentError: + # this might have already been defined. Once we transition this to hydra it should be fine to add it here. + pass + parser.add_argument('--truncate-source', action='store_true', default=False, + help='truncate source to max-source-positions') + parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N', + help='if >0, then bucket source and target lengths into N ' + 'buckets and pad accordingly; this is useful on TPUs ' + 'to minimize the number of compilations') + + # Denoising args + parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N', + help='maximum word shuffle distance for denoising autoencoding data generation') + parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N', + help='word dropout probability for denoising autoencoding data generation') + parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N', + help='word blanking probability for denoising autoencoding data generation') + + # Backtranslation args + parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N', + help='back-translation weight') + parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N', + help='denoising auto-encoder weight') + + # Evaluation args + parser.add_argument('--generate-one-by-one', action='store_true', + help='generate one sentence at a time for backtranslation') + + parser.add_argument('--eval-bleu', action='store_true', + help='evaluation with BLEU scores') + parser.add_argument('--eval-bleu-detok', type=str, default="space", + help='detokenize before computing BLEU (e.g., "moses"); ' + 'required if using --eval-bleu; use "space" to ' + 'disable detokenization; see fairseq.data.encoders ' + 'for other options') + parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON', + help='args for building the tokenizer, if needed') + parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False, + help='compute tokenized BLEU instead of sacrebleu') + parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None, + help='remove BPE before computing BLEU') + parser.add_argument('--eval-bleu-args', type=str, metavar='JSON', + help='generation args for BLUE scoring, ' + 'e.g., \'{"beam": 4, "lenpen": 0.6}\'') + parser.add_argument('--eval-bleu-print-samples', action='store_true', + help='print sample generations during validation') + # fmt: on + + def __init__(self, args, common_dict, mono_langs, valid_lang_pairs): + super().__init__(args, common_dict, common_dict) + self.common_dict = common_dict + self.mono_langs = mono_langs + self.valid_lang_pairs = valid_lang_pairs + + self.SHOW_SAMPLES_INTERVAL = 1000 + # Start by showing samples + self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL + self.SHOW_SAMPLES_NUMBER = 5 + self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt) + self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae) + + self.args = args + self.data = utils.split_paths(self.args.data) + if len(self.data) == 1: + shards = list(Path(self.data[0]).glob("shard*")) + if len(shards) > 0: + # keep this as strings, since it can also be a manifold path + old_data = self.data + self.data = [str(shard) for shard in shards] + logging.warning(f"Expanded data directory {old_data} to {self.data}") + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + args.left_pad_source = options.eval_bool(args.left_pad_source) + args.left_pad_target = options.eval_bool(args.left_pad_target) + + paths = utils.split_paths(args.data) + assert len(paths) > 0 + assert args.mono_langs is not None + + mono_langs = args.mono_langs.split(",") + valid_lang_pairs = args.valid_lang_pairs.split(",") + + # load dictionary + dict_path = os.path.join(paths[0], "dict.txt") + common_dict = cls.load_dictionary(dict_path) + + return cls(args, common_dict, mono_langs, valid_lang_pairs) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split == "train": + data_path = self.data[(epoch - 1) % len(self.data)] + dataset = self.load_train_dataset(data_path) + else: + # valid/test should always be the same. + dataset = self.load_translation_dataset(split, self.data[0]) + + self.datasets[split] = dataset + return dataset + + def load_train_dataset(self, data_path: str) -> FairseqDataset: + """The training dataset is made of backtranslation dataset and denoising dataset.""" + data = [] + for lang in self.mono_langs: + train_path = os.path.join(data_path, lang, "train") + # TODO: could we do the BT using denoise sample ? + # this would half the data loading work + data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang))) + data.append( + (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang)) + ) + + return RoundRobinZipDatasets(OrderedDict(data)) + + def _langpair_dataset( + self, src: FairseqDataset, tgt: FairseqDataset + ) -> LanguagePairDataset: + return LanguagePairDataset( + src, + src.sizes, + self.dictionary, + tgt=tgt, + tgt_sizes=tgt.sizes, + tgt_dict=self.dictionary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + # TODO: should we shuffle ? we are already sorting batch by sizes so ? + # shuffle=True, + ) + + def _prepend_lang_bos_to_target( + self, dataset: LanguagePairDataset, lang: str + ) -> LanguagePairDataset: + bos = _lang_token_index(self.dictionary, lang) + return TransformEosLangPairDataset( + dataset, + src_eos=self.dictionary.eos(), + new_src_eos=self.dictionary.eos(), + tgt_bos=self.dictionary.eos(), + new_tgt_bos=bos, + ) + + def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """The BT dataset is generated with (tgt, tgt) pairs. + The actual translation to a (generated_src, tgt) pair + is done on the fly during training. + """ + mono_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + assert mono_dataset is not None, f"No dataset found for {lang}" + + mono_dataset_src = PrependTokenDataset( + mono_dataset, _lang_token_index(self.dictionary, lang) + ) + + mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset) + logger.info( + f"mono_lang = {lang} " + f"lang token index = {_lang_token_index(self.dictionary, lang)} " + f"lang token = {_lang_token(lang)}" + ) + + mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang) + return mono_dataset_bt + + def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset: + """Classic denoising dataset""" + dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + noisy_dataset = NoisingDataset( + dataset, + self.dictionary, + seed=1, + max_word_shuffle_distance=self.args.max_word_shuffle_distance, + word_dropout_prob=self.args.word_dropout_prob, + word_blanking_prob=self.args.word_blanking_prob, + ) + noisy_dataset = PrependTokenDataset( + noisy_dataset, _lang_token_index(self.dictionary, lang) + ) + + clean_dataset = data_utils.load_indexed_dataset( + data_path, self.common_dict, self.args.dataset_impl + ) + denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset) + denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang) + return denoising_dataset + + def load_translation_dataset( + self, split: str, data_path: str, combine: bool = False + ): + # only judging with one language pair for the moment, + # since ConcatDataset doesn't work as expected + assert len(self.valid_lang_pairs) == 1, "For now..." + valid_lang_pair = self.valid_lang_pairs[0] + src, tgt = valid_lang_pair.split("-") + + # use the same function than TranslationTask + src_tgt_dt = load_langpair_dataset( + data_path, + split, + src, + self.common_dict, + tgt, + self.common_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + load_alignments=self.args.load_alignments, + truncate_source=self.args.truncate_source, + num_buckets=self.args.num_batch_buckets, + shuffle=(split != "test"), + prepend_bos_src=_lang_token_index(self.dictionary, src), + ) + + src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt) + src_tgt_eos_dt.args = self.args + return src_tgt_eos_dt + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + raise NotImplementedError + + def build_model(self, args, from_checkpoint=False): + # torch.autograd.set_detect_anomaly(True) + model = super().build_model(args, from_checkpoint) + + add_secial_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs) + + self.sequence_generators = {} + for mono_lang in self.mono_langs: + self.sequence_generators[mono_lang] = SequenceGenerator( + [model], + tgt_dict=self.dictionary, + beam_size=1, + max_len_a=1.3, + max_len_b=5, + min_len=5, + # keep 1 to be able to prepend bos + max_len=model.max_decoder_positions() - 1, + ) + + if getattr(args, "eval_bleu", False): + assert getattr(args, "eval_bleu_detok", None) is not None, ( + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" + ) + detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}") + self.tokenizer = encoders.build_tokenizer( + Namespace( + tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args + ) + ) + + gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") + self.bleu_sequence_generator = self.build_generator( + [model], Namespace(**gen_args) + ) + + return model + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.common_dict + + def display_samples_once_in_a_while(self, smp, mono_lang, other_lang): + self._show_samples_ctr += 1 + if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL: + return + self._show_samples_ctr = 0 + + ln = smp["net_input"]["src_tokens"].shape[0] + + logger.info( + f"(r:{self.args.distributed_rank}) : " + f"{other_lang} ---> {mono_lang} " + f"({other_lang} was generated by back-translation.) {ln} samples" + ) + + for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)): + src_tokens = smp["net_input"]["src_tokens"][i] + tgt_tokens = smp["target"][i] + + src_str = self.dictionary.string(src_tokens, "sentencepiece") + tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece") + logger.info( + f"\n{i}\t\t[{other_lang} generated] {src_str}\n" + f"\t\t[{mono_lang} original ] {tgt_str}\n" + f"\t\t[ src tokens] {src_tokens}\n" + ) + + def backtranslate_sample(self, smp, orig_lang, other_lang) -> None: + """ + * WARNING: smp is modified in place. + * At the start of this function, `smp` has the same input and target: + |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (from data) __en__ hello world | __en__ hello world | + |--------------------------------------------------------| + + * We call generator.generate(smp, bos_token = token("ro")), + and copy the result as input + * At the end, `smp` has the translation to other language. + |--------------------------------------------------------| + | smp['net_input']['src_tokens'] | smp['target'] | + | (generated) __ro__ salut lume | __en__ hello world | + |--------------------------------------------------------| + + """ + bos_token = _lang_token_index(self.dictionary, other_lang) + generated = self.sequence_generators[orig_lang].generate( + models=[], sample=smp, bos_token=bos_token + ) + + max_lngth = max([gn[0]["tokens"].size(0) for gn in generated]) + net_input = smp["net_input"] + n_src_tokens = torch.empty( + size=(len(generated), max_lngth + 1), dtype=net_input["src_tokens"].dtype + ) + n_src_lengths = torch.empty( + len(generated), dtype=net_input["src_lengths"].dtype + ) + + for i, gn in enumerate(generated): + tokens = gn[0]["tokens"] + tokens_size = tokens.size(0) + padding_needed = max_lngth - tokens_size + tokens = torch.cat([tokens.new([bos_token]), tokens]) + tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad()) + n_src_tokens[i] = tokens + n_src_lengths[i] = tokens_size + 1 + + device = net_input["src_tokens"].device + # This seems to be important + del net_input["src_tokens"] + del net_input["src_lengths"] + net_input["src_tokens"] = n_src_tokens.to(device) + net_input["src_lengths"] = n_src_lengths.to(device) + + def generate(self, smp, model): + model.eval() + orig_lang = ( + self.dictionary[smp["net_input"]["src_tokens"][0][0]] + .replace(" ", "") + .replace("_", "") + ) + bos_token = smp["net_input"]["prev_output_tokens"][0][0] + with torch.no_grad(): + generated = self.sequence_generators[orig_lang].generate( + models=[model], sample=smp, bos_token=bos_token + ) + return generated + + def get_other_lang(self, lang): + # TODO: allow more complex mapping + if lang != self.mono_langs[0]: + return self.mono_langs[0] + if len(self.mono_langs) == 2: + return self.mono_langs[1] + return self.mono_langs[np.random.randint(1, len(self.mono_langs))] + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + + model.train() + model.set_num_updates(update_num) + + agg_loss, agg_sample_size = 0.0, 0.0 + agg_logging_output: Dict[str, float] = defaultdict(float) + + dataset_keys = self.datasets["train"].datasets.keys() + + weights = { + "BT": self.lambda_bt(update_num), + "DENOISE": self.lambda_dae(update_num), + } + log_keys = {"BT": "bt_", "DENOISE": "dae_"} + + for dataset_key in dataset_keys: + smp = sample[dataset_key] + mono_lang, task_subtype = dataset_key.split("-") + if weights[task_subtype] == 0: + continue + + if task_subtype == "BT": + with torch.autograd.profiler.record_function("backtranslation"): + model.eval() + # TODO: Could we translate to several language at once ? + # this would allow to share encoder_out and maximize GPU usage. + other_lang = self.get_other_lang(mono_lang) + self.backtranslate_sample(smp, mono_lang, other_lang) + self.display_samples_once_in_a_while(smp, mono_lang, other_lang) + model.train() + + # Like in FairseqTask.train_step + with torch.autograd.profiler.record_function("forward"): + loss, sample_size, logging_output = criterion(model, smp) + loss *= weights[task_subtype] + if ignore_grad: + loss *= 0 + with torch.autograd.profiler.record_function("backward"): + optimizer.backward(loss) + + agg_loss += loss.item() + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[log_keys[task_subtype] + k] += logging_output[k] + agg_logging_output[k] += logging_output[k] + + return agg_loss, agg_sample_size, agg_logging_output + + def get_bos_token_from_sample(self, sample): + net_input = sample["net_input"] + source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item() + source_lang_token = self.dictionary[source_lang_token_id].replace("_", "") + target_lang_token_id = _lang_token_index( + self.dictionary, self.get_other_lang(source_lang_token) + ) + + return target_lang_token_id + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs) + if bt_sample_size: + bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs) + bt_loss_sum *= 1 / bt_sample_size / math.log(2) + metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3) + + bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs) + bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs) + bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2) + metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3) + metrics.log_derived( + "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg) + ) + + dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs) + if dae_sample_size: + dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs) + dae_loss_sum *= 1 / dae_sample_size / math.log(2) + metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3) + + dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs) + dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs) + dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2) + metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3) + metrics.log_derived( + "dae_ppl", + lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg), + ) + + +@torch.no_grad() +def extend_embedding( + emb: nn.Module, new_vocab_size: int, copy_from_token_id: int +) -> None: + old_emb_data = emb.weight.data + (old_vocab_size, dim) = old_emb_data.shape + assert new_vocab_size >= old_vocab_size + + if new_vocab_size > old_vocab_size: + emb.weight.data = torch.zeros((new_vocab_size, dim)) + emb.weight.data[:old_vocab_size, :] = old_emb_data + # initialize new embeddings + emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id] + if hasattr(emb, "num_embeddings"): + emb.num_embeddings = new_vocab_size + if hasattr(emb, "out_features"): + emb.out_features = new_vocab_size + + if getattr(emb, "bias", None) is None: + return + + # Fix the bias. + # Bias shape can be different from the previous vocab size + # if the weight matrix was shared and alread extended but not the bias. + (old_vocab_size,) = emb.bias.shape + assert new_vocab_size >= old_vocab_size + if new_vocab_size > old_vocab_size: + old_bias = emb.bias.data + new_bias = torch.zeros( + (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device + ) + new_bias[:old_vocab_size] = old_bias + emb.bias.data = new_bias + + +def add_secial_tokens_to_dict_and_model( + dictionary: "fairseq.data.Dictionary", + model: nn.Module, + mono_langs: Sequence[str], +) -> None: + embs = model.encoder.embed_tokens + vocab_size, embedding_dim = embs.weight.shape + + # The model may or may not have a '' embedding yet + assert ( + len(dictionary) <= vocab_size <= len(dictionary) + 1 + ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})" + # TODO: we should reuse the pretrained model dict which already has + dictionary.add_symbol("") + + for lang in mono_langs: + lang_token = _lang_token(lang) + dictionary.add_symbol(lang_token) + logger.info( + f"dictionary: {len(dictionary)} -> {vocab_size} tokens " + f"after adding {len(mono_langs)} lang tokens." + ) + + if len(dictionary) <= vocab_size: + return + + extend_embedding(embs, len(dictionary), dictionary.bos()) + dec_embs = model.decoder.embed_tokens + extend_embedding(dec_embs, len(dictionary), dictionary.bos()) + lm_head = model.decoder.output_projection + extend_embedding(lm_head, len(dictionary), dictionary.bos()) + assert lm_head.weight.shape == (len(dictionary), embedding_dim) + + +def _lang_token(lang: str) -> str: + return f"__{lang}__" + + +def _lang_token_index(dictionary, lang: str) -> int: + return dictionary.index(_lang_token(lang)) + + +@contextlib.contextmanager +def assert_weights_have_changed(model: nn.Module): + def checksum(model: nn.Module) -> float: + return sum(p.sum().item() for p in model.parameters()) + + initial_checksum = checksum(model) + yield model + final_checksum = checksum(model) + logger.info( + f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}" + ) + assert initial_checksum != final_checksum, "Model hasn't changed !" diff --git a/fairseq/tasks/semisupervised_translation.py b/fairseq/tasks/semisupervised_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..432b8a52ca122bca3e3f24a1fd493da33614e742 --- /dev/null +++ b/fairseq/tasks/semisupervised_translation.py @@ -0,0 +1,485 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from collections import OrderedDict + +from fairseq import utils +from fairseq.data import ( + BacktranslationDataset, + IndexedCachedDataset, + IndexedDataset, + IndexedRawTextDataset, + LanguagePairDataset, + NoisingDataset, + RoundRobinZipDatasets, + data_utils, + indexed_dataset, +) +from fairseq.models import FairseqMultiModel +from fairseq.sequence_generator import SequenceGenerator + +from . import register_task +from .multilingual_translation import MultilingualTranslationTask + + +logger = logging.getLogger(__name__) + + +def _get_bt_dataset_key(lang_pair): + return "bt:" + lang_pair + + +def _get_denoising_dataset_key(lang_pair): + return "denoising:" + lang_pair + + +# ported from UnsupervisedMT +def parse_lambda_config(x): + """ + Parse the configuration of lambda coefficient (for scheduling). + x = "3" # lambda will be a constant equal to x + x = "0:1,1000:0" # lambda will start from 1 and linearly decrease + # to 0 during the first 1000 iterations + x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000 + # iterations, then will linearly increase to 1 until iteration 2000 + """ + split = x.split(",") + if len(split) == 1: + return float(x), None + else: + split = [s.split(os.pathsep) for s in split] + assert all(len(s) == 2 for s in split) + assert all(k.isdigit() for k, _ in split) + assert all( + int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1) + ) + return float(split[0][1]), [(int(k), float(v)) for k, v in split] + + +@register_task("semisupervised_translation") +class SemisupervisedTranslationTask(MultilingualTranslationTask): + """A task for training multiple translation models simultaneously. + + We iterate round-robin over batches from multiple language pairs, ordered + according to the `--lang-pairs` argument. + + The training loop is roughly: + + for i in range(len(epoch)): + for lang_pair in args.lang_pairs: + batch = next_batch_for_lang_pair(lang_pair) + loss = criterion(model_for_lang_pair(lang_pair), batch) + loss.backward() + optimizer.step() + + In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset + (e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that + implements the `FairseqMultiModel` interface. + + During inference it is required to specify a single `--source-lang` and + `--target-lang`, instead of `--lang-pairs`. + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + MultilingualTranslationTask.add_args(parser) + parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG', + help='cross-entropy reconstruction coefficient (parallel data). ' + 'use fixed weight during training if set to floating point number. ' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG', + help='Cross-entropy reconstruction coefficient (denoising autoencoding)' + 'use fixed weight during training if set to floating point number. ' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG', + help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data)' + 'use fixed weight during training if set to floating point number. ' + 'use piecewise linear function over number of updates to schedule the ' + 'weight with the format: w0:step0,w1:step1,...') + parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N', + help='generate back-translated sequences of maximum length ax + b, where x is the ' + 'source length') + parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N', + help='generate back-translated sequences of maximum length ax + b, where x is the ' + 'source length') + parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N', + help='beam size used in beam search of online back-translation') + parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N', + help='maximum word shuffle distance for denoising autoencoding data generation') + parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N', + help='word dropout probability for denoising autoencoding data generation') + parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N', + help='word blanking probability for denoising autoencoding data generation') + # fmt: on + + def __init__(self, args, dicts, training): + super().__init__(args, dicts, training) + self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config( + args.lambda_parallel_config + ) + self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config( + args.lambda_otf_bt_config + ) + self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config( + args.lambda_denoising_config + ) + if self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None: + denoising_lang_pairs = [ + "%s-%s" % (tgt, tgt) + for tgt in {lang_pair.split("-")[1] for lang_pair in args.lang_pairs} + ] + self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs + self.backtranslate_datasets = {} + self.backtranslators = {} + + @classmethod + def setup_task(cls, args, **kwargs): + dicts, training = MultilingualTranslationTask.prepare(args, **kwargs) + return cls(args, dicts, training) + + def load_dataset(self, split, epoch=1, **kwargs): + """Load a dataset split.""" + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + def split_exists(split, src, tgt, lang): + if src is not None: + filename = os.path.join( + data_path, "{}.{}-{}.{}".format(split, src, tgt, lang) + ) + else: + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, src, tgt) + ) + return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl) + + def load_indexed_dataset(path, dictionary): + return data_utils.load_indexed_dataset( + path, dictionary, self.args.dataset_impl + ) + + # load parallel datasets + src_datasets, tgt_datasets = {}, {} + if ( + self.lambda_parallel > 0.0 + or self.lambda_parallel_steps is not None + or not split.startswith("train") + ): + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + if split_exists(split, src, tgt, src): + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, src, tgt) + ) + elif split_exists(split, tgt, src, src): + prefix = os.path.join( + data_path, "{}.{}-{}.".format(split, tgt, src) + ) + else: + continue + src_datasets[lang_pair] = load_indexed_dataset( + prefix + src, self.dicts[src] + ) + tgt_datasets[lang_pair] = load_indexed_dataset( + prefix + tgt, self.dicts[tgt] + ) + logger.info( + "parallel-{} {} {} examples".format( + data_path, split, len(src_datasets[lang_pair]) + ) + ) + if len(src_datasets) == 0: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + # back translation datasets + backtranslate_datasets = {} + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and split.startswith("train"): + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + if not split_exists(split, tgt, None, tgt): + raise FileNotFoundError( + "Dataset not found: backtranslation {} ({})".format( + split, data_path + ) + ) + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) + dataset = load_indexed_dataset(filename, self.dicts[tgt]) + lang_pair_dataset_tgt = LanguagePairDataset( + dataset, + dataset.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + lang_pair_dataset = LanguagePairDataset( + dataset, + dataset.sizes, + src_dict=self.dicts[src], + tgt=dataset, + tgt_sizes=dataset.sizes, + tgt_dict=self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ) + backtranslate_datasets[lang_pair] = BacktranslationDataset( + tgt_dataset=self.alter_dataset_langtok( + lang_pair_dataset_tgt, + src_eos=self.dicts[tgt].eos(), + src_lang=tgt, + tgt_lang=src, + ), + backtranslation_fn=self.backtranslators[lang_pair], + src_dict=self.dicts[src], + tgt_dict=self.dicts[tgt], + output_collater=self.alter_dataset_langtok( + lang_pair_dataset=lang_pair_dataset, + src_eos=self.dicts[src].eos(), + src_lang=src, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ).collater, + ) + logger.info( + "backtranslate-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(backtranslate_datasets[lang_pair]), + ) + ) + self.backtranslate_datasets[lang_pair] = backtranslate_datasets[ + lang_pair + ] + + # denoising autoencoder + noising_datasets = {} + if ( + self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None + ) and split.startswith("train"): + for lang_pair in self.lang_pairs: + _, tgt = lang_pair.split("-") + if not split_exists(split, tgt, None, tgt): + continue + filename = os.path.join( + data_path, "{}.{}-None.{}".format(split, tgt, tgt) + ) + tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt]) + tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt]) + noising_dataset = NoisingDataset( + tgt_dataset1, + self.dicts[tgt], + seed=1, + max_word_shuffle_distance=self.args.max_word_shuffle_distance, + word_dropout_prob=self.args.word_dropout_prob, + word_blanking_prob=self.args.word_blanking_prob, + ) + noising_datasets[lang_pair] = self.alter_dataset_langtok( + LanguagePairDataset( + noising_dataset, + tgt_dataset1.sizes, + self.dicts[tgt], + tgt_dataset2, + tgt_dataset2.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ), + src_eos=self.dicts[tgt].eos(), + src_lang=tgt, + tgt_eos=self.dicts[tgt].eos(), + tgt_lang=tgt, + ) + logger.info( + "denoising-{}: {} {} {} examples".format( + tgt, + data_path, + split, + len(noising_datasets[lang_pair]), + ) + ) + + def language_pair_dataset(lang_pair): + src, tgt = lang_pair.split("-") + src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair] + return self.alter_dataset_langtok( + LanguagePairDataset( + src_dataset, + src_dataset.sizes, + self.dicts[src], + tgt_dataset, + tgt_dataset.sizes, + self.dicts[tgt], + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + ), + self.dicts[src].eos(), + src, + self.dicts[tgt].eos(), + tgt, + ) + + self.datasets[split] = RoundRobinZipDatasets( + OrderedDict( + [ + (lang_pair, language_pair_dataset(lang_pair)) + for lang_pair in src_datasets.keys() + ] + + [ + (_get_bt_dataset_key(lang_pair), dataset) + for lang_pair, dataset in backtranslate_datasets.items() + ] + + [ + (_get_denoising_dataset_key(lang_pair), dataset) + for lang_pair, dataset in noising_datasets.items() + ] + ), + eval_key=None + if self.training + else "%s-%s" % (self.args.source_lang, self.args.target_lang), + ) + + def build_model(self, args, from_checkpoint=False): + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + if not isinstance(model, FairseqMultiModel): + raise ValueError( + "SemisupervisedTranslationTask requires a FairseqMultiModel architecture" + ) + + # create SequenceGenerator for each model that has backtranslation dependency on it + self.sequence_generators = {} + if ( + self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None + ) and self.training: + for lang_pair in self.lang_pairs: + src, tgt = lang_pair.split("-") + key = "{}-{}".format(tgt, src) + self.sequence_generators[key] = SequenceGenerator( + [model.models[key]], + tgt_dict=self.dicts[src], + beam_size=args.bt_beam_size, + max_len_a=args.bt_max_len_a, + max_len_b=args.bt_max_len_b, + ) + decoder_lang_tok_idx = self.get_decoder_langtok(src) + + def backtranslate_fn( + sample, + model=model.models[key], + bos_token=decoder_lang_tok_idx, + sequence_generator=self.sequence_generators[key], + ): + return sequence_generator.generate( + [model], + sample, + bos_token=bos_token, + ) + + self.backtranslators[lang_pair] = backtranslate_fn + + return model + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + + if update_num > 0: + self.update_step(update_num) + + agg_loss, agg_sample_size, agg_logging_output = 0.0, 0.0, {} + + def forward_backward(model, samples, logging_output_key, weight): + nonlocal agg_loss, agg_sample_size, agg_logging_output + if samples is None or len(samples) == 0: + return + loss, sample_size, logging_output = criterion(model, samples) + if ignore_grad: + loss *= 0 + else: + loss *= weight + optimizer.backward(loss) + agg_loss += loss.detach().item() + # TODO make summing of the sample sizes configurable + agg_sample_size += sample_size + for k in logging_output: + agg_logging_output[k] += logging_output[k] + agg_logging_output[logging_output_key] += logging_output[k] + + if self.lambda_parallel > 0.0: + for lang_pair in self.lang_pairs: + forward_backward( + model.models[lang_pair], + sample[lang_pair], + lang_pair, + self.lambda_parallel, + ) + + if self.lambda_otf_bt > 0.0: + for lang_pair in self.lang_pairs: + sample_key = _get_bt_dataset_key(lang_pair) + forward_backward( + model.models[lang_pair], + sample[sample_key], + sample_key, + self.lambda_otf_bt, + ) + + if self.lambda_denoising > 0.0: + for lang_pair in self.lang_pairs: + _, tgt = lang_pair.split("-") + sample_key = _get_denoising_dataset_key(lang_pair) + forward_backward( + model.models["{0}-{0}".format(tgt)], + sample[sample_key], + sample_key, + self.lambda_denoising, + ) + + return agg_loss, agg_sample_size, agg_logging_output + + def update_step(self, num_updates): + def lambda_step_func(config, n_iter): + """ + Update a lambda value according to its schedule configuration. + """ + ranges = [ + i + for i in range(len(config) - 1) + if config[i][0] <= n_iter < config[i + 1][0] + ] + if len(ranges) == 0: + assert n_iter >= config[-1][0] + return config[-1][1] + assert len(ranges) == 1 + i = ranges[0] + x_a, y_a = config[i] + x_b, y_b = config[i + 1] + return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a) + + if self.lambda_parallel_steps is not None: + self.lambda_parallel = lambda_step_func( + self.lambda_parallel_steps, num_updates + ) + if self.lambda_denoising_steps is not None: + self.lambda_denoising = lambda_step_func( + self.lambda_denoising_steps, num_updates + ) + if self.lambda_otf_bt_steps is not None: + self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates) diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..de80addaf20e902a04f251bb6d0e3712fc5439d9 --- /dev/null +++ b/fairseq/tasks/sentence_prediction.py @@ -0,0 +1,303 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import contextlib +from dataclasses import dataclass, field +from typing import Optional +from omegaconf import MISSING, II, open_dict, OmegaConf + +import numpy as np +from fairseq.data import ( + ConcatSentencesDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + OffsetTokensDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + RightPaddingMaskDataset, + RollDataset, + SortDataset, + StripTokenDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import FairseqDataclass, FairseqTask, register_task +from fairseq.dataclass import ChoiceEnum + + +logger = logging.getLogger(__name__) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + + +@dataclass +class SentencePredictionConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + num_classes: int = field( + default=-1, + metadata={"help": "number of classes or regression targets"}, + ) + init_token: Optional[int] = field( + default=None, + metadata={"help": "add token at the beginning of each batch item"}, + ) + separator_token: Optional[int] = field( + default=None, + metadata={"help": "add separator token between inputs"}, + ) + no_shuffle: bool = field( + default=False, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed tokens_per_sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + add_prev_output_tokens: bool = field( + default=False, + metadata={ + "help": "add prev_output_tokens to sample, used for encoder-decoder arch" + }, + ) + max_positions: int = field( + default=512, + metadata={"help": "max tokens per example"}, + ) + + regression_target: bool = II("criterion.regression_target") + classification_head_name: str = II("criterion.classification_head_name") + seed: int = II("common.seed") + + d2v2_multi: bool = field( + default=False, + metadata={"help": "prepare dataset for data2vec_multi"}, + ) + + +@register_task("sentence_prediction", dataclass=SentencePredictionConfig) +class SentencePredictionTask(FairseqTask): + """ + Sentence (or sentence pair) prediction (classification or regression) task. + + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + def __init__(self, cfg, data_dictionary, label_dictionary): + super().__init__(cfg) + self.dictionary = data_dictionary + self._label_dictionary = label_dictionary + + @classmethod + def load_dictionary(cls, filename): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + dictionary.add_symbol("") + return dictionary + + @classmethod + def setup_task(cls, cfg, **kwargs): + assert cfg.num_classes > 0, "Must set task.num_classes" + + # load data dictionary + data_dict = cls.load_dictionary( + os.path.join(cfg.data, "input0", "dict.txt"), + ) + logger.info("[input] dictionary: {} types".format(len(data_dict))) + + # load label dictionary + if not cfg.regression_target: + label_dict = cls.load_dictionary( + os.path.join(cfg.data, "label", "dict.txt"), + ) + logger.info("[label] dictionary: {} types".format(len(label_dict))) + else: + label_dict = data_dict + return cls(cfg, data_dict, label_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + def get_path(key, split): + return os.path.join(self.cfg.data, key, split) + + def make_dataset(key, dictionary): + split_path = get_path(key, split) + + try: + dataset = data_utils.load_indexed_dataset( + split_path, + dictionary, + combine=combine, + ) + except Exception as e: + if "StorageException: [404] Path not found" in str(e): + logger.warning(f"dataset {e} not found") + dataset = None + else: + raise e + return dataset + + input0 = make_dataset("input0", self.source_dictionary) + assert input0 is not None, "could not find dataset: {}".format( + get_path("input0", split) + ) + input1 = make_dataset("input1", self.source_dictionary) + + if self.cfg.init_token is not None: + input0 = PrependTokenDataset(input0, self.cfg.init_token) + + if input1 is None: + src_tokens = input0 + else: + if self.cfg.separator_token is not None: + input1 = PrependTokenDataset(input1, self.cfg.separator_token) + + src_tokens = ConcatSentencesDataset(input0, input1) + + with data_utils.numpy_seed(self.cfg.seed): + shuffle = np.random.permutation(len(src_tokens)) + + src_tokens = maybe_shorten_dataset( + src_tokens, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.max_positions(), + self.cfg.seed, + ) + + if self.cfg.d2v2_multi: + net_input = { + "source": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "id": IdDataset(), + "padding_mask": RightPaddingMaskDataset(src_tokens), + } + else: + net_input = { + "src_tokens": RightPadDataset( + src_tokens, + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset(src_tokens, reduce=False), + } + if self.cfg.add_prev_output_tokens: + prev_tokens_dataset = RightPadDataset( + RollDataset(src_tokens, 1), + pad_idx=self.dictionary.pad(), + ) + net_input.update( + prev_output_tokens=prev_tokens_dataset, + ) + + dataset = { + "id": IdDataset(), + "net_input": net_input, + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens, reduce=True), + } + + if not self.cfg.regression_target: + label_dataset = make_dataset("label", self.label_dictionary) + if label_dataset is not None: + dataset.update( + target=OffsetTokensDataset( + StripTokenDataset( + label_dataset, + id_to_strip=self.label_dictionary.eos(), + ), + offset=-self.label_dictionary.nspecial, + ) + ) + else: + label_path = "{0}.label".format(get_path("label", split)) + if os.path.exists(label_path): + + def parse_regression_target(i, line): + values = line.split() + assert ( + len(values) == self.cfg.num_classes + ), f'expected num_classes={self.cfg.num_classes} regression target values on line {i}, found: "{line}"' + return [float(x) for x in values] + + with open(label_path) as h: + dataset.update( + target=RawLabelDataset( + [ + parse_regression_target(i, line.strip()) + for i, line in enumerate(h.readlines()) + ] + ) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[src_tokens.sizes], + ) + + if self.cfg.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], + ) + + logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, cfg, from_checkpoint=False): + from fairseq import models + + with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack(): + cfg.max_positions = self.cfg.max_positions + + model = models.build_model(cfg, self, from_checkpoint) + + model.register_classification_head( + self.cfg.classification_head_name, + num_classes=self.cfg.num_classes, + ) + + return model + + def max_positions(self): + return self.cfg.max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + @property + def label_dictionary(self): + return self._label_dictionary diff --git a/fairseq/tasks/sentence_prediction_adapters.py b/fairseq/tasks/sentence_prediction_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..afe556962621ba509d6a784d33c40e1c5406a6fb --- /dev/null +++ b/fairseq/tasks/sentence_prediction_adapters.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import contextlib +from omegaconf import open_dict, OmegaConf + +from fairseq.tasks import register_task +from fairseq.tasks.sentence_prediction import ( + SentencePredictionTask, + SentencePredictionConfig, +) + + +logger = logging.getLogger(__name__) + + +@register_task("sentence_prediction_adapters", dataclass=SentencePredictionConfig) +class SentencePredictionAdapterTask(SentencePredictionTask): + def build_model(self, cfg): + from fairseq import models + + with open_dict(cfg) if OmegaConf.is_config(cfg) else contextlib.ExitStack(): + cfg.max_positions = self.cfg.max_positions + + model = models.build_model(cfg, self) + + model.register_classification_head( + self.cfg.classification_head_name, + num_classes=self.cfg.num_classes, + ) + + logger.info("Freezing Embedding Parameters") + for parameter in model.encoder.sentence_encoder.embed_positions.parameters(): + parameter.requires_grad = False + for ( + parameter + ) in model.encoder.sentence_encoder.layernorm_embedding.parameters(): + parameter.requires_grad = False + for parameter in model.encoder.sentence_encoder.embed_tokens.parameters(): + parameter.requires_grad = False + + logger.info("Freezing Adapters") + for k, v in model.encoder.sentence_encoder.layers._modules.items(): + logger.info("Freezing Adapters in Layer " + str(k)) + if hasattr(v, "adapter_layer_norm"): + logger.info("Freezing Adapter LN") + for parameter in v.adapter_layer_norm.parameters(): + parameter.requires_grad = False + for parameter in v.adapter_modules.parameters(): + parameter.requires_grad = False + + return model diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..57f63aab6725922d1b07b5cf67c45a44356f454f --- /dev/null +++ b/fairseq/tasks/sentence_ranking.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import numpy as np +from fairseq import utils +from fairseq.data import ( + ConcatSentencesDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + NumSamplesDataset, + PrependTokenDataset, + RawLabelDataset, + RightPadDataset, + SortDataset, + TruncateDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.tasks import LegacyFairseqTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("sentence_ranking") +class SentenceRankingTask(LegacyFairseqTask): + """ + Ranking task on multiple sentences. + + Args: + dictionary (Dictionary): the dictionary for the input of the task + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", metavar="FILE", help="file prefix for data") + parser.add_argument( + "--num-classes", type=int, help="number of sentences to be ranked" + ) + parser.add_argument( + "--init-token", + type=int, + help="add token at the beginning of each batch item", + ) + parser.add_argument( + "--separator-token", type=int, help="add separator token between inputs" + ) + parser.add_argument("--no-shuffle", action="store_true") + parser.add_argument( + "--shorten-method", + default="none", + choices=["none", "truncate", "random_crop"], + help="if not none, shorten sequences that exceed --tokens-per-sample", + ) + parser.add_argument( + "--shorten-data-split-list", + default="", + help="comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)', + ) + parser.add_argument( + "--max-option-length", type=int, help="max length for each option" + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + + @classmethod + def load_dictionary(cls, args, filename, source=True): + """Load the dictionary from the filename + + Args: + filename (str): the filename + """ + dictionary = Dictionary.load(filename) + dictionary.add_symbol("") + return dictionary + + @classmethod + def setup_task(cls, args, **kwargs): + assert ( + args.criterion == "sentence_ranking" + ), "Must set --criterion=sentence_ranking" + + # load data dictionary + data_dict = cls.load_dictionary( + args, + os.path.join(args.data, "input0", "dict.txt"), + source=True, + ) + logger.info("[input] dictionary: {} types".format(len(data_dict))) + return SentenceRankingTask(args, data_dict) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test).""" + + def get_path(type, split): + return os.path.join(self.args.data, type, split) + + def make_dataset(type, dictionary): + split_path = get_path(type, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.source_dictionary, + self.args.dataset_impl, + combine=combine, + ) + return dataset + + input0 = make_dataset("input0", self.source_dictionary) + input_options = [ + make_dataset("input{idx}".format(idx=idx + 1), self.source_dictionary) + for idx in range(self.args.num_classes) + ] + + if self.args.separator_token is not None: + input0 = PrependTokenDataset(input0, self.args.separator_token) + + src_tokens = [] + for input_option in input_options: + if self.args.init_token is not None: + input_option = PrependTokenDataset(input_option, self.args.init_token) + if self.args.max_option_length is not None: + input_option = TruncateDataset( + input_option, self.args.max_option_length + ) + src_token = ConcatSentencesDataset(input_option, input0) + src_token = maybe_shorten_dataset( + src_token, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.max_positions, + self.args.seed, + ) + src_tokens.append(src_token) + + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_tokens[0])) + + dataset = { + "id": IdDataset(), + "nsentences": NumSamplesDataset(), + "ntokens": NumelDataset(src_tokens[0], reduce=True), + } + + for src_token_idx in range(len(src_tokens)): + dataset.update( + { + "net_input{idx}".format(idx=src_token_idx + 1): { + "src_tokens": RightPadDataset( + src_tokens[src_token_idx], + pad_idx=self.source_dictionary.pad(), + ), + "src_lengths": NumelDataset( + src_tokens[src_token_idx], reduce=False + ), + } + } + ) + + label_path = "{}.label".format(get_path("label", split)) + if os.path.exists(label_path): + with open(label_path) as h: + dataset.update( + target=RawLabelDataset([int(x.strip()) for x in h.readlines()]) + ) + + nested_dataset = NestedDictionaryDataset( + dataset, + sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])], + ) + + if self.args.no_shuffle: + dataset = nested_dataset + else: + dataset = SortDataset( + nested_dataset, + # shuffle + sort_order=[shuffle], + ) + + logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset))) + + self.datasets[split] = dataset + return self.datasets[split] + + def build_model(self, args, from_checkpoint=False): + from fairseq import models + + model = models.build_model(args, self, from_checkpoint) + + model.register_classification_head( + getattr(args, "ranking_head_name", "sentence_classification_head"), + num_classes=1, + ) + + return model + + def max_positions(self): + return self.args.max_positions + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary diff --git a/fairseq/tasks/simultaneous_translation.py b/fairseq/tasks/simultaneous_translation.py new file mode 100644 index 0000000000000000000000000000000000000000..9576b26801dcc9c1433b0e8632926117c0d50aea --- /dev/null +++ b/fairseq/tasks/simultaneous_translation.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.tasks.translation import TranslationTask, TranslationConfig + +try: + import examples.simultaneous_translation # noqa + + import_successful = True +except BaseException: + import_successful = False + + +logger = logging.getLogger(__name__) + + +def check_import(flag): + if not flag: + raise ImportError( + "'examples.simultaneous_translation' is not correctly imported. " + "Please considering `pip install -e $FAIRSEQ_DIR`." + ) + + +@register_task("simul_speech_to_text") +class SimulSpeechToTextTask(SpeechToTextTask): + def __init__(self, args, tgt_dict): + check_import(import_successful) + super().__init__(args, tgt_dict) + + +@register_task("simul_text_to_text", dataclass=TranslationConfig) +class SimulTextToTextTask(TranslationTask): + def __init__(self, cfg, src_dict, tgt_dict): + check_import(import_successful) + super().__init__(cfg, src_dict, tgt_dict) diff --git a/fairseq/tasks/span_masked_lm.py b/fairseq/tasks/span_masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..d746aa154c815b339451e04131642dc9d419bb2a --- /dev/null +++ b/fairseq/tasks/span_masked_lm.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from omegaconf import II, MISSING + +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + StripTokenDataset, + TokenBlockDataset, + data_utils, +) +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.data.span_mask_tokens_dataset import SpanMaskedTokensDataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +from ..data.indexed_dataset import get_available_dataset_impl + +logger = logging.getLogger(__name__) + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) + + +@dataclass +class SpanMaskedLMConfig(FairseqDataclass): + shuffle: bool = field( + default=False, + ) + noise_density: float = field( + default=0.15, + metadata={"help": "What fraction of the tokens to select as noise"}, + ) + mean_noise_span_length: float = field( + default=3, + metadata={"help": "Mean noise span length, must be >= 1"}, + ) + data: str = field( + default=MISSING, + metadata={ + "help": "colon separated path to data directories list, " + "will be iterated upon during epochs in round-robin manner" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + include_target_tokens: bool = field( + default=False, + metadata={ + "help": "include target tokens in model input. this is used for data2vec" + }, + ) + + +@register_task("span_masked_lm", dataclass=SpanMaskedLMConfig) +class SpanMaskedLMTask(FairseqTask): + """ + Span masked language modeling task. (ie. T5) + """ + + cfg: SpanMaskedLMConfig + + def __init__(self, cfg, dictionary): + super().__init__(cfg) + self.dictionary = dictionary + + @classmethod + def setup_task(cls, cfg: SpanMaskedLMConfig, **kwargs): + """Setup the task.""" + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(dictionary))) + if not hasattr(cfg, "shuffle"): + cfg.shuffle = False + return cls(cfg, dictionary) + + def _load_dataset_split(self, split, epoch, combine): + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + dataset = data_utils.load_indexed_dataset( + split_path, + self.dictionary, + self.cfg.dataset_impl, + combine=combine, + ) + if dataset is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + dataset = StripTokenDataset(dataset, self.dictionary.eos()) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.cfg.shorten_data_split_list, + self.cfg.shorten_method, + self.cfg.tokens_per_sample, + self.cfg.seed, + ) + + # create continuous blocks of tokens + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.cfg.tokens_per_sample - 2, # one less for and one for + pad=self.dictionary.pad(), + eos=self.dictionary.eos(), + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + logger.info("loaded {} blocks from: {}".format(len(dataset), split_path)) + + # prepend beginning-of-sentence token (, equiv. to [CLS] in BERT) + dataset = PrependTokenDataset(dataset, self.source_dictionary.bos()) + dataset = AppendTokenDataset(dataset, self.source_dictionary.eos()) + return dataset + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + dataset = self._load_dataset_split(split, epoch, combine) + + self.datasets[split] = SpanMaskedTokensDataset( + dataset, + self.dictionary, + noise_density=self.cfg.noise_density, + mean_noise_span_length=self.cfg.mean_noise_span_length, + shuffle=self.cfg.shuffle, + seed=self.cfg.seed, + ) + logger.info( + "Split: {0}, Loaded {1} samples of span_masked_tokens_dataset".format( + split, + len(self.datasets[split]), + ) + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We assume that the input begins with a + bos symbol (``) and ends with an eos symbol (``). + """ + pad = self.source_dictionary.pad() + eos = self.source_dictionary.eos() + src_dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=self.cfg.tokens_per_sample - 2, # for and + pad=pad, + eos=eos, + break_mode=self.cfg.sample_break_mode, + document_sep_len=0, + ) + prev_output_tokens = PrependTokenDataset( + StripTokenDataset(src_dataset, eos), eos + ) + src_dataset = PadDataset(src_dataset, pad_idx=pad, left_pad=False) + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": src_dataset, + "src_lengths": NumelDataset(src_dataset, reduce=False), + "prev_output_tokens": PadDataset( + prev_output_tokens, pad_idx=pad, left_pad=False + ), + }, + "target": src_dataset, + }, + sizes=[np.array(src_lengths)], + ) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.dictionary + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.dictionary diff --git a/fairseq/tasks/speech_dlm_task.py b/fairseq/tasks/speech_dlm_task.py new file mode 100644 index 0000000000000000000000000000000000000000..340732b928122356ab2183d050b504da1773e91a --- /dev/null +++ b/fairseq/tasks/speech_dlm_task.py @@ -0,0 +1,561 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import Optional +from collections import OrderedDict + +import numpy as np +import torch +from fairseq import utils +from fairseq.data import ( + AppendTokenDataset, + Dictionary, + IdDataset, + LMContextWindowDataset, + MonolingualDataset, + NestedDictionaryDataset, + NumelDataset, + PadDataset, + PrependTokenDataset, + SpeechDLMDataset, + StripTokenDataset, + TokenBlockDataset, + TruncatedDictionary, + data_utils, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.data.shorten_dataset import maybe_shorten_dataset +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import LegacyFairseqTask, register_task +from omegaconf import II + + +SAMPLE_BREAK_MODE_CHOICES = ChoiceEnum(["none", "complete", "complete_doc", "eos"]) +SHORTEN_METHOD_CHOICES = ChoiceEnum(["none", "truncate", "random_crop"]) +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechDLMConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, metadata={"help": "path to data directory"} + ) + channels: Optional[str] = field( + default=None, + metadata={ + "help": 'comma-separated list of channels to load e.g., "unitA,unitB"' + "(default: load all possible channels in the data path)" + }, + ) + channel_weights: Optional[str] = field( + default=None, + metadata={ + "help": "comma-separated list of weights for different losses" + "(default: None, which means all losses are treated equally)" + }, + ) + sample_break_mode: SAMPLE_BREAK_MODE_CHOICES = field( + default="none", + metadata={ + "help": 'If omitted or "none", fills each sample with tokens-per-sample ' + 'tokens. If set to "complete", splits samples only at the end ' + "of sentence, but may include multiple sentences per sample. " + '"complete_doc" is similar but respects doc boundaries. ' + 'If set to "eos", includes only one sentence per sample.' + }, + ) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sample for LM dataset"}, + ) + output_dictionary_size: int = field( + default=-1, metadata={"help": "limit the size of output dictionary"} + ) + # str type is a workaround to put **default=True** here + next_unit_prediction: str = field( + default="False", + metadata={ + "help": "Perform Next Unit Prediction, expected str input ('True' or 'False')" + }, + ) + edge_unit_prediction: str = field( + default="True", + metadata={ + "help": "Perform Edge Unit Prediction, expected str input ('True' or 'False')" + }, + ) + duration_prediction: str = field( + default="True", + metadata={ + "help": "Perform Duration Prediction, expected str input ('True' or 'False')" + }, + ) + delayed_duration_target: str = field( + default="True", + metadata={ + "help": "Perform Delayed Duration Prediction, expected str input ('True' or 'False')" + "(default: 'True')" + }, + ) + max_target_durations: Optional[int] = field( + default=256, + metadata={"help": "max duration considered (cut off to this value)"}, + ) + add_bos_token: bool = field( + default=False, metadata={"help": "prepend beginning of sentence token ()"} + ) + max_target_positions: Optional[int] = field( + default=None, metadata={"help": "max number of tokens in the target sequence"} + ) + shorten_method: SHORTEN_METHOD_CHOICES = field( + default="none", + metadata={ + "help": "if not none, shorten sequences that exceed --tokens-per-sample" + }, + ) + shorten_data_split_list: str = field( + default="", + metadata={ + "help": "comma-separated list of dataset splits to apply shortening to, " + 'e.g., "train,valid" (default: all dataset splits)' + }, + ) + # TODO common vars below add to parent + seed: int = II("common.seed") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") + + +@register_task("speech_dlm_task", dataclass=SpeechDLMConfig) +class SpeechDLMTask(LegacyFairseqTask): + """Task for the SpeechDLM model as described in the paper: + https://arxiv.org/pdf/2203.16502.pdf + + It create a multi-channel dataset (SpeechDLMDataset) from multiple + dictionaries. + + Args: + dictionaries (Dict[str, ~fairseq.data.Dictionary]): the dictionaries for + each input channel of the SpeechDLM model + output_dictionaries (Dict[str, ~fairseq.data.Dictionary]): the dictionaries + for the output of each channel of the SpeechDLM model. In most cases it + will be the same as *dictionaries*. + targets (List[str]): list of the target types that the SpeechDLM model + should predict. Can be one of "next", "edge", "duration". + Defaults to "next". + + .. note:: + + The SpeechDLM task is only compatible with + :mod:`fairseq-train` and :mod:`fairseq-validate`. + To generate new samples, please refer to example codes + at examples/textless_nlp/dgslm . + """ + + def __init__(self, args, dicts, output_dicts=None, targets=None): + super().__init__(args) + self.dicts = dicts + self.output_dicts = output_dicts or dicts + + if targets is None: + targets = ["next"] + self.targets = targets + + self.channels = list(dicts.keys()) + + if args.channel_weights is not None: + self.channel_weights = [float(w) for w in args.channel_weights.split(",")] + else: + self.channel_weights = [1.0 for _ in self.channels] + assert len(self.channel_weights) == len( + self.channels + ), "number of channel_weights must be the same as number of channels" + + assert str(args.next_unit_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.next_unit_prediction}" + assert str(args.edge_unit_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.edge_unit_prediction}" + assert str(args.duration_prediction).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.duration_prediction}" + assert str(args.delayed_duration_target).lower() in [ + "true", + "false", + ], f"Expected to be a string of boolean, found {args.delayed_duration_target}" + self.next_unit_prediction = bool( + str(args.next_unit_prediction).lower() == "true" + ) + self.edge_unit_prediction = bool( + str(args.edge_unit_prediction).lower() == "true" + ) + self.duration_prediction = bool(str(args.duration_prediction).lower() == "true") + self.delayed_duration_target = bool( + str(args.delayed_duration_target).lower() == "true" + ) + + self.max_target_durations = args.max_target_durations + + @classmethod + def setup_dictionary(cls, args, **kwargs): + """The dictionaries will be a dict over channel keys and values of type + ~fairseq.data.Dictionary. + """ + paths = utils.split_paths(args.data) + assert len(paths) > 0 + data_path = paths[0] + + dicts = None + output_dicts = None + if args.channels is None: + sorted_channels = sorted( + name[5:-4] + for name in os.listdir(data_path) + if name[:5] == "dict." and name[-4:] == ".txt" + ) + else: + sorted_channels = sorted(args.channels.split(",")) + logger.info("channels: {}".format(sorted_channels)) + # load dictionaries + dicts = OrderedDict() + output_dicts = OrderedDict() + for channel in sorted_channels: + dictionary = Dictionary.load( + os.path.join(data_path, "dict.{}.txt".format(channel)) + ) + logger.info("[{}] dictionary: {} types".format(channel, len(dictionary))) + output_dictionary = dictionary + if args.output_dictionary_size >= 0: + output_dictionary = TruncatedDictionary( + dictionary, args.output_dictionary_size + ) + dicts[channel] = dictionary + output_dicts[channel] = output_dictionary + if len(dicts) > 0: + assert dicts[channel].pad() == dicts[sorted_channels[0]].pad() + assert dicts[channel].bos() == dicts[sorted_channels[0]].bos() + assert dicts[channel].eos() == dicts[sorted_channels[0]].eos() + assert dicts[channel].unk() == dicts[sorted_channels[0]].unk() + return (dicts, output_dicts) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + dicts, output_dicts = cls.setup_dictionary(args, **kwargs) + + targets = [] + if str(getattr(args, "next_unit_prediction", "false")).lower() == "true": + targets.append("next") + if str(getattr(args, "edge_unit_prediction", "false")).lower() == "true": + targets.append("edge") + if str(getattr(args, "duration_prediction", "false")).lower() == "true": + targets.append("duration") + if len(targets) == 0: + # standard language modeling + targets = ["next"] + + return cls(args, dicts, output_dicts, targets=targets) + + def build_model(self, args): + model = super().build_model(args) + for target in self.targets: + if target not in model.supported_targets: + raise ValueError("Unsupported SpeechDLM target: {}".format(target)) + return model + + def load_dataset( + self, split: str, epoch=1, combine=False, **kwargs + ) -> SpeechDLMDataset: + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + + data_path = paths[(epoch - 1) % len(paths)] + + channel_datasets = {} + for channel in self.channels: + split_path = os.path.join(data_path, split + "." + channel) + dictionary = self.dicts[channel] + output_dictionary = self.output_dicts[channel] + + dataset = data_utils.load_indexed_dataset( + split_path, dictionary, self.args.dataset_impl, combine=combine + ) + + if dataset is None: + raise FileNotFoundError( + "[{}] Dataset not found: {} ({})".format(channel, split, split_path) + ) + + dataset = maybe_shorten_dataset( + dataset, + split, + self.args.shorten_data_split_list, + self.args.shorten_method, + self.args.tokens_per_sample, + self.args.seed, + ) + + dataset = TokenBlockDataset( + dataset, + dataset.sizes, + self.args.tokens_per_sample, + pad=dictionary.pad(), + eos=dictionary.eos(), + break_mode=self.args.sample_break_mode, + include_targets=True, + ) + + add_eos_for_other_targets = ( + self.args.sample_break_mode is not None + and self.args.sample_break_mode != "none" + ) + + channel_datasets[channel] = MonolingualDataset( + dataset=dataset, + sizes=dataset.sizes, + src_vocab=dictionary, + tgt_vocab=output_dictionary, + add_eos_for_other_targets=add_eos_for_other_targets, + shuffle=False, + targets=["future"], + add_bos_token=self.args.add_bos_token, + ) + + self.datasets[split] = SpeechDLMDataset( + datasets=channel_datasets, + targets=self.targets, + max_target_durations=self.max_target_durations, + shuffle=True, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + """ + Generate batches for inference. We prepend an eos token to src_tokens + (or bos if `--add-bos-token` is set) and we append a to target. + This is convenient both for generation with a prefix and LM scoring. + """ + src_datasets = {} + tgt_datasets = {} + for channel in src_tokens[0]: + dataset = StripTokenDataset( + TokenBlockDataset( + [src_tokens[i][channel] for i in range(len(src_tokens))], + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionaries[channel].pad(), + eos=self.source_dictionaries[channel].eos(), + break_mode="eos", + ), + # remove eos from (end of) target sequence + self.source_dictionaries[channel].eos(), + ) + src_dataset = PrependTokenDataset( + dataset, + token=( + self.source_dictionaries[channel].bos() + if getattr(self.args, "add_bos_token", False) + else self.source_dictionaries[channel].eos() + ), + ) + tgt_dataset = AppendTokenDataset( + dataset, token=self.source_dictionaries[channel].pad() + ) + + src_datasets[channel] = src_dataset + tgt_datasets[channel] = tgt_dataset + + return NestedDictionaryDataset( + { + "id": IdDataset(), + "net_input": { + "src_tokens": OrderedDict( + [ + ( + channel, + PadDataset( + src_datasets[channel], + pad_idx=self.source_dictionaries[channel].pad(), + left_pad=False, + ), + ) + for channel in src_datasets + ] + ), + "src_lengths": NumelDataset( + next(iter(src_datasets.values())), reduce=False + ), + }, + "target": OrderedDict( + [ + ( + channel, + PadDataset( + tgt_datasets[channel], + pad_idx=self.source_dictionaries[channel].pad(), + left_pad=False, + ), + ) + for channel in tgt_datasets + ] + ), + }, + sizes=[np.array(src_lengths)], + ) + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + # Generation will always be conditioned on bos_token + if getattr(self.args, "add_bos_token", False): + bos_token = self.source_dictionary.bos() + else: + bos_token = self.source_dictionary.eos() + + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the SpeechDLM task is not supported" + ) + # SequenceGenerator doesn't use src_tokens directly, we need to + # pass the `prefix_tokens` argument instead + if prefix_tokens is None: + prefix_tokens = {} + for channel in sample["net_input"]["src_tokens"]: + if sample["net_input"]["src_tokens"][channel].nelement(): + prefix_tokens_channel = sample["net_input"]["src_tokens"][ + channel + ] + if prefix_tokens_channel[:, 0].eq(bos_token).all(): + prefix_tokens_channel = prefix_tokens_channel[:, 1:] + prefix_tokens[channel] = prefix_tokens_channel + else: + prefix_tokens = None + break + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.dicts[self.channels[0]] + + @property + def target_dictionary(self): + """Return the :class:`~fairseq.data.Dictionary` for the language + model.""" + return self.output_dicts[self.channels[0]] + + @property + def source_dictionaries(self): + """Return the dict of :class:`~fairseq.data.Dictionary` for the + multichannel language model.""" + return self.dicts + + @property + def target_dictionaries(self): + """Return the dict of :class:`~fairseq.data.Dictionary` for the + multichannel language model.""" + return self.output_dicts + + def build_generator(self, models, args, extra_gen_cls_kwargs=None): + + from fairseq.models.speech_dlm.sequence_generator import ( + multichannel_search, + MultichannelSequenceGenerator, + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + assert ( + sampling_topk < 0 or sampling + ), "--sampling-topk requires sampling (not beam search)" + assert ( + sampling_topp < 0 or sampling + ), "--sampling-topp requires sampling (not beam search)" + + if sampling: + search_strategy = multichannel_search.ContiguousMultichannelSampling( + self.target_dictionaries, sampling_topk, sampling_topp + ) + else: + search_strategy = multichannel_search.ContiguousMultichannelBeamSearch( + self.target_dictionaries + ) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + + return MultichannelSequenceGenerator( + models, + self.target_dictionaries, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 500), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + duration_temperature=getattr(args, "duration_temperature", 1.0), + **extra_gen_cls_kwargs, + ) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..5aaaa95a90cce0ae3d4bc8cf9b79b312dd342b3f --- /dev/null +++ b/fairseq/tasks/speech_to_speech.py @@ -0,0 +1,597 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import logging +import math +from argparse import Namespace +from pathlib import Path +from typing import List + +import torch +import torch.nn as nn + +from fairseq import utils +from fairseq.data import Dictionary +from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig +from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + TextTargetMultitaskData, +) +from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.tasks.speech_to_text import DummyMultiTask +from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion + +logger = logging.getLogger(__name__) + + +class StackUnitSequenceGenerator(nn.Module): + def __init__(self, tgt_dict, vocab_size): + super().__init__() + self.pad = tgt_dict.pad() + self.eos = tgt_dict.eos() + self.unk = tgt_dict.unk() + self.offset = len(tgt_dict) - vocab_size + self.vocab_size = vocab_size + + def pack_units(self, input: torch.Tensor, n_frames_per_step) -> torch.Tensor: + if n_frames_per_step <= 1: + return input + + bsz, _, n = input.shape + assert n == n_frames_per_step + + scale = [ + pow(self.vocab_size, n_frames_per_step - 1 - i) + for i in range(n_frames_per_step) + ] + scale = torch.LongTensor(scale).squeeze(0).to(input.device) + mask = input >= self.offset + res = ((input - self.offset) * scale * mask).sum(dim=2) + self.offset + return res + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + # currently only support viterbi search for stacked units + model = models[0] + model.eval() + + max_len = model.max_decoder_positions() + # TODO: incorporate max_len_a and max_len_b + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len, _ = src_tokens.size() + n_frames_per_step = model.decoder.n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + incremental_state = {} + pred_out, attn, scores = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + + prev_output_tokens = src_lengths.new_zeros((bsz, 1)).long().fill_(self.eos) + for _ in range(max_len): + cur_out, cur_extra = model.forward_decoder( + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + ) + + lprobs = model.get_normalized_probs([cur_out], log_probs=True) + # never select pad, unk + lprobs[:, :, self.pad] = -math.inf + lprobs[:, :, self.unk] = -math.inf + + cur_pred_lprob, cur_pred_out = torch.max(lprobs, dim=2) + scores.append(cur_pred_lprob) + pred_out.append(cur_pred_out) + + prev_output_tokens = torch.cat( + ( + prev_output_tokens, + self.pack_units( + cur_pred_out.view(bsz, 1, n_frames_per_step), n_frames_per_step + ), + ), + dim=1, + ) + + attn.append(cur_extra["attn"][0]) + + cur_finished = torch.any(cur_pred_out.squeeze(1) == self.eos, dim=1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + + pred_out = torch.cat(pred_out, dim=1).view(bsz, -1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + scores = torch.cat(scores, dim=1) + eos_idx = (pred_out == self.eos).nonzero(as_tuple=True) + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(max_len) + for b, l in zip(eos_idx[0], eos_idx[1]): + out_lens[b] = min(l, out_lens[b]) + + hypos = [ + [ + { + "tokens": pred_out[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "positional_scores": scores[b, :out_len], + "score": utils.item(scores[b, :out_len].sum().data), + } + ] + for b, out_len in zip(range(bsz), out_lens) + ] + + return hypos + + +@register_task("speech_to_speech") +class SpeechToSpeechTask(LegacyFairseqTask): + @classmethod + def add_args(cls, parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=6000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument( + "--target-is-code", + action="store_true", + help="set if target is discrete unit instead of spectrogram", + ) + parser.add_argument( + "--target-code-size", type=int, default=None, help="# discrete units" + ) + parser.add_argument( + "--n-frames-per-step", + type=int, + default=1, + help="# stacked frames, use 0 for reduced discrete unit sequence", + ) + parser.add_argument("--eval-inference", action="store_true") + parser.add_argument( + "--eval-args", + type=str, + default="{}", + help='generation args for speech-to-unit model , e.g., \'{"beam": 5, "max_len_a": 1}\', as JSON string', + ) + parser.add_argument("--eos-prob-threshold", type=float, default=0.5) + parser.add_argument( + "--mcd-normalize-type", + type=str, + default="targ", + choices=["targ", "pred", "path"], + ) + parser.add_argument( + "--vocoder", + type=str, + default="griffin_lim", + choices=["griffin_lim", "hifigan", "code_hifigan"], + ) + parser.add_argument("--spec-bwd-max-iter", type=int, default=8) + parser.add_argument( + "--infer-target-lang", + type=str, + default="", + help="target language for inference", + ) + + def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): + super().__init__(args) + self.tgt_dict = tgt_dict + self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) + + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None + if getattr(args, "multitask_config_yaml", None) is not None: + multitask_cfg = MultitaskConfig( + Path(args.data) / args.multitask_config_yaml + ) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, + ) + self.multitask_tasks[task_name] = task_obj + if task_obj.is_first_pass_decoder: + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" + ) + + self._infer_tgt_lang_id = infer_tgt_lang_id + + @classmethod + def setup_task(cls, args, **kwargs): + data_cfg = data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) + tgt_dict = None + infer_tgt_lang_id = None + if args.target_is_code: + if data_cfg.prepend_tgt_lang_tag_as_bos: + # dictionary with language tags + dict_path = Path(args.data) / data_cfg.vocab_filename + if not dict_path.is_file(): + raise FileNotFoundError( + f"Dict has to be provided when setting prepend_tgt_lang_tag_as_bos: true, but dict not found: {dict_path}" + ) + tgt_dict = Dictionary.load(dict_path.as_posix()) + + # target langauge for inference + if args.infer_target_lang != "": + tgt_lang_tag = SpeechToTextDataset.LANG_TAG_TEMPLATE.format( + args.infer_target_lang + ) + infer_tgt_lang_id = tgt_dict.index(tgt_lang_tag) + assert infer_tgt_lang_id != tgt_dict.unk() + else: + assert args.target_code_size is not None + + tgt_dict = Dictionary() + for i in range(args.target_code_size): + tgt_dict.add_symbol(str(i)) + logger.info(f"dictionary size: " f"{len(tgt_dict):,}") + + if getattr(args, "train_subset", None) is not None: + if not all(s.startswith("train") for s in args.train_subset.split(",")): + raise ValueError('Train splits should be named like "train*".') + + assert args.n_frames_per_step >= 1 + assert ( + not args.eval_inference + or (args.target_is_code and args.vocoder == "code_hifigan") + or (not args.target_is_code and args.vocoder != "code_hifigan") + ) + + return cls(args, tgt_dict, infer_tgt_lang_id=infer_tgt_lang_id) + + def build_criterion(self, args): + from fairseq import criterions + + if len(self.multitask_tasks) > 0: + if self.args.target_is_code and not args._name.startswith("speech_to_unit"): + raise ValueError( + "set --criterion speech_to_unit for speech-to-unit loss with multitask" + ) + elif not self.args.target_is_code and not args._name.startswith( + "speech_to_spectrogram" + ): + raise ValueError( + "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" + ) + + return criterions.build_criterion(args, self) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + self.datasets[split] = SpeechToSpeechDatasetCreator.from_tsv( + root=self.args.data, + data_cfg=self.data_cfg, + splits=split, + is_train_split=split.startswith("train"), + epoch=epoch, + seed=self.args.seed, + target_is_code=self.args.target_is_code, + tgt_dict=self.target_dictionary, + n_frames_per_step=self.args.n_frames_per_step, + multitask=self.multitask_tasks, + ) + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + + @property + def source_dictionary(self): + return None + + def max_positions(self): + return self.args.max_source_positions, self.args.max_target_positions + + def build_model(self, args, from_checkpoint=False): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_transformed_channels + args.target_speaker_embed = self.data_cfg.target_speaker_embed is not None + args.n_frames_per_step = self.args.n_frames_per_step + + model = super().build_model(args, from_checkpoint) + + if len(self.multitask_tasks) > 0: + from fairseq.models.speech_to_speech.s2s_transformer import ( + S2STransformerMultitaskModelBase, + ) + + assert isinstance(model, S2STransformerMultitaskModelBase) + + if self.args.eval_inference: + self.eval_gen_args = json.loads(self.args.eval_args) + self.generator = self.build_generator( + [model], Namespace(**self.eval_gen_args) + ) + + return model + + def build_generator_dual_decoder( + self, + models, + args, + extra_gen_cls_kwargs=None, + ): + from examples.speech_to_speech.unity.sequence_generator_multi_decoder import ( + MultiDecoderSequenceGenerator, + ) + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + + if not self.args.target_is_code or self.args.eval_inference: + from fairseq.models.text_to_speech.vocoder import get_vocoder + + self.vocoder = get_vocoder(self.args, self.data_cfg) + self.vocoder = ( + self.vocoder.cuda() + if torch.cuda.is_available() and not self.args.cpu + else self.vocoder.cpu() + ) + + has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None + + if self.args.target_is_code: + if self.args.n_frames_per_step == 1: + if has_dual_decoder: + seq_generator = self.build_generator_dual_decoder( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + seq_generator = super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + assert ( + getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1 + ), "only support viterbi search for stacked units" + seq_generator = StackUnitSequenceGenerator( + self.tgt_dict, + self.args.target_code_size, + ) + else: + if has_dual_decoder: + if getattr(args, "teacher_forcing", False): + raise NotImplementedError + else: + from fairseq.speech_generator import MultiDecoderSpeechGenerator + + generator = MultiDecoderSpeechGenerator + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs[ + "symbols_to_strip_from_output" + ] = lang_token_ids_aux + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) + if self.eos_token_mt + else None + ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt + + seq_generator = generator( + models, + args, + self.vocoder, + self.data_cfg, + self.target_dictionary_mt, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + **extra_gen_cls_kwargs, + ) + else: + if getattr(args, "teacher_forcing", False): + from fairseq.speech_generator import ( + TeacherForcingAutoRegressiveSpeechGenerator, + ) + + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + else: + from fairseq.speech_generator import AutoRegressiveSpeechGenerator + + generator = AutoRegressiveSpeechGenerator + + seq_generator = generator( + models[0], + self.vocoder, + self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + ) + + return seq_generator + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + for task_name, task_obj in self.multitask_tasks.items(): + criterion.set_multitask_loss_weight( + task_name, task_obj.args.get_loss_weight(update_num) + ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() + + loss, sample_size, logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad + ) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + for task_name in self.multitask_tasks.keys(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if self.args.eval_inference: + hypos, inference_losses = self.valid_step_with_inference( + sample, model, self.generator + ) + for k, v in inference_losses.items(): + assert k not in logging_output + logging_output[k] = v + + return loss, sample_size, logging_output + + def valid_step_with_inference(self, sample, model, generator): + if self.args.target_is_code: + hypos = generator.generate([model], sample) + tgt_lens = ( + sample["target_lengths"] - 1 + ) * self.args.n_frames_per_step # strip + for b, (f, l) in enumerate(zip(sample["target"], tgt_lens)): + hypos[b][0]["targ_waveform"] = self.vocoder( + {"code": f[:l] - 4}, # remove , , , + dur_prediction=self.eval_gen_args.get("dur_prediction", False), + ) + if len(hypos[b][0]["tokens"]) > 0: + hypos[b][0]["waveform"] = self.vocoder( + {"code": hypos[b][0]["tokens"] - 4}, + dur_prediction=self.eval_gen_args.get("dur_prediction", False), + ) + else: + hypos[b][0]["waveform"] = torch.flip( + hypos[b][0]["targ_waveform"], dims=[0] + ) + else: + hypos = [ + [hypo] for hypo in generator.generate(model, sample, has_targ=True) + ] + + losses = { + "mcd_loss": 0.0, + "targ_frames": 0.0, + "pred_frames": 0.0, + "path_frames": 0.0, + "nins": 0.0, + "ndel": 0.0, + } + rets = batch_mel_cepstral_distortion( + [hypo[0]["targ_waveform"] for hypo in hypos], + [hypo[0]["waveform"] for hypo in hypos], + self.data_cfg.output_sample_rate, + normalize_type=None, + ) + for d, extra in rets: + pathmap = extra[-1] + losses["mcd_loss"] += d.item() + losses["targ_frames"] += pathmap.size(0) + losses["pred_frames"] += pathmap.size(1) + losses["path_frames"] += pathmap.sum().item() + losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() + losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() + losses["norm_frames"] = losses[ + f"{getattr(self.args, 'mcd_normalize_type', 'targ')}_frames" + ] + + return hypos, losses + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if self._infer_tgt_lang_id is not None: + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + bos_token=self._infer_tgt_lang_id, + ) + else: + return super().inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..884082112a6763e0edb126e7ee21c83e40127823 --- /dev/null +++ b/fairseq/tasks/speech_to_text.py @@ -0,0 +1,350 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from argparse import Namespace +from pathlib import Path +from typing import List + +from fairseq.data import Dictionary, encoders +from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import MultitaskConfig +from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig, + SpeechToTextDataset, + SpeechToTextDatasetCreator, + TextTargetMultitaskData, +) +from fairseq.tasks import LegacyFairseqTask, register_task + +logger = logging.getLogger(__name__) + + +@register_task("speech_to_text") +class SpeechToTextTask(LegacyFairseqTask): + @classmethod + def add_args(cls, parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=6000, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) + self.speaker_to_id = self._get_speaker_to_id() + if ( + self.data_cfg.prepend_tgt_lang_tag + and self.data_cfg.prepend_bos_and_append_tgt_lang_tag + ): + raise ValueError( + "Please set only one of the two options to avoid adding target token multiple times" + ) + + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None + if getattr(args, "multitask_config_yaml", None) is not None: + multitask_cfg = MultitaskConfig( + Path(args.data) / args.multitask_config_yaml + ) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, + ) + self.multitask_tasks[task_name] = task_obj + if task_obj.is_first_pass_decoder: + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" + ) + + def _get_speaker_to_id(self): + speaker_to_id = None + speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") + if speaker_set_filename is not None: + speaker_set_path = Path(self.args.data) / speaker_set_filename + with open(speaker_set_path) as f: + speaker_to_id = {r.strip(): i for i, r in enumerate(f)} + return speaker_to_id + + @classmethod + def setup_task(cls, args, **kwargs): + data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) + dict_path = Path(args.data) / data_cfg.vocab_filename + if not dict_path.is_file(): + raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}") + tgt_dict = Dictionary.load(dict_path.as_posix()) + logger.info( + f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" + ) + + if getattr(args, "train_subset", None) is not None: + if not all(s.startswith("train") for s in args.train_subset.split(",")): + raise ValueError('Train splits should be named like "train*".') + return cls(args, tgt_dict) + + def build_criterion(self, args): + from fairseq import criterions + + if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1: + raise ValueError( + 'Please set "--ignore-prefix-size 1" since ' + "target language ID token is prepended as BOS." + ) + return criterions.build_criterion(args, self) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( + root=self.args.data, + cfg=self.data_cfg, + splits=split, + tgt_dict=self.tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + speaker_to_id=self.speaker_to_id, + multitask=self.multitask_tasks, + ) + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + + @property + def source_dictionary(self): + return None + + def max_positions(self): + return self.args.max_source_positions, self.args.max_target_positions + + def build_model(self, args, from_checkpoint=False): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_channels + args.speaker_to_id = self.speaker_to_id + return super(SpeechToTextTask, self).build_model(args, from_checkpoint) + + def build_generator_dual_decoder( + self, + models, + args, + extra_gen_cls_kwargs, + ): + from examples.speech_to_speech.unity.sequence_generator_multi_decoder import ( + MultiDecoderSequenceGenerator, + ) + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux) + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None + ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 0), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + len_penalty_mt=getattr(args, "lenpen_mt", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: + raise ValueError( + 'Please set "--prefix-size 1" since ' + "target language ID token is prepended as BOS." + ) + lang_token_ids = { + i + for s, i in self.tgt_dict.indices.items() + if SpeechToTextDataset.is_lang_tag(s) + } + + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids + + eos_token = ( + args.eos_token + if "eos_token" in args and args.eos_token is not None + else self.data_cfg.config.get("eos_token", None) + ) + + if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token: + raise Warning( + "Please provide --eos_token to replace eos in sequence generator" + ) + + eos_id = self.tgt_dict.index(eos_token) if eos_token else None + extra_gen_cls_kwargs["eos"] = eos_id + + has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None + + if has_dual_decoder: + return self.build_generator_dual_decoder( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + return super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + for task_name, task_obj in self.multitask_tasks.items(): + criterion.set_multitask_loss_weight( + task_name, task_obj.args.get_loss_weight(update_num) + ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() + + loss, sample_size, logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad + ) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + for task_name, task_obj in self.multitask_tasks.items(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + return loss, sample_size, logging_output + + def build_tokenizer(self, args): + logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) + + def build_bpe(self, args): + logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}") + return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) + + def get_interactive_tokens_and_lengths(self, lines, encode_fn): + n_frames = [get_features_or_waveform(p).shape[0] for p in lines] + return lines, n_frames + + def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): + return SpeechToTextDataset( + "interactive", False, self.data_cfg, src_tokens, src_lengths + ) + + +class DummyMultiTask(LegacyFairseqTask): + def __init__(self, args, tgt_dict, first_pass=False): + super().__init__(args) + self.tgt_dict = tgt_dict + self.first_pass = first_pass + + @property + def target_dictionary(self): + return self.tgt_dict + + @property + def is_first_pass_decoder(self): + return self.first_pass + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + if self.args.decoder_type == "ctc": + model = models[0] # only support single model + encoder_out = model(**sample) + if hasattr(model, "get_logits"): + emissions = model.get_logits( + encoder_out + ) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return generator.decode( + emissions.transpose(0, 1).float().cpu().contiguous() + ) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if self.args.decoder_type == "ctc": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, self.tgt_dict) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") diff --git a/fairseq/tasks/speech_ulm_task.py b/fairseq/tasks/speech_ulm_task.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d3019d5049f4f0ba1673d8b3e77c99db951887 --- /dev/null +++ b/fairseq/tasks/speech_ulm_task.py @@ -0,0 +1,224 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import sys +import torch +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +from fairseq.data import Dictionary +from fairseq.data.codedataset import ExpressiveCodeDataConfig, CodeDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING, DictConfig + + +logger = logging.getLogger(__name__) + + +class UnitDictionary(Dictionary): + """ + A fixed-sized Dictionary that operates on integer-valued tokens + wth a trivial (identity) token <-> id mapping. + Special symbols (bos, eos, ...) have ids above n_units. + """ + + def __init__( + self, + *, # begin keyword-only arguments + n_units, + bos="", + pad="", + eos="", + unk="", + extra_special_symbols=None, + clip=False, + ): + self.n_units = n_units + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.clip = clip + + self.symbols = [] + self.count = [] + self.indices = {} + for i in range(n_units): + self.add_symbol(str(i)) + + self.bos_index = self.add_symbol(bos) + self.pad_index = self.add_symbol(pad) + self.eos_index = self.add_symbol(eos) + self.unk_index = self.add_symbol(unk) + + if extra_special_symbols: + for s in extra_special_symbols: + self.add_symbol(s) + self.nspecial = len(self.symbols) + + def encode_line(self, line, append_eos=True, prepend_bos=False) -> torch.IntTensor: + words = [int(x) for x in line.split()] + if self.clip: + words = [min(self.n_units - 1, word) for word in words] + if prepend_bos: + words = [self.bos_index] + words + if append_eos: + words.append(self.eos_index) + ids = torch.IntTensor(words) + return ids + + +@dataclass +class SpeechUnitModelingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "Path to data config.json"}) + max_token_duration: int = field( + default=20, metadata={"help": "all token durations are capped to this value"} + ) + tokens_per_sample: int = field( + default=1024, metadata={"help": "tokens in a sample"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max target positions"} + ) + + # duration modeling + ignore_duration_input: bool = field( + default=False, metadata={"help": "whether token durations should be zeroed out"} + ) + discrete_duration: bool = field( + default=False, metadata={"help": "treat duration as discrete variable"} + ) + # F0 modeling + ignore_f0_input: bool = field( + default=False, metadata={"help": "whether F0 should be zeroed out"} + ) + discrete_f0: bool = field( + default=False, metadata={"help": "load quantized f0. get bin from config"} + ) + log_f0: bool = field( + default=False, metadata={"help": "whether f0 should be modeled in log space"} + ) + normalize_f0_mean: bool = field( + default=False, metadata={"help": "whether normalize f0 by speaker mean"} + ) + normalize_f0_std: bool = field( + default=False, metadata={"help": "whether normalize f0 by speaker stddev"} + ) + interpolate_f0: bool = field( + default=False, + metadata={"help": "whether interpolate f0 for non-voiced segments"}, + ) + + # input/output streams + stream_shifts: str = field( + default="0,0", + metadata={ + "help": ( + "comma-separated integer list denoting right-shift for " + "duration and pitch streams" + ) + }, + ) + + +@register_task("speech_unit_modeling", dataclass=SpeechUnitModelingConfig) +class SpeechUnitLanguageModelingTask(FairseqTask): + def __init__(self, cfg: SpeechUnitModelingConfig) -> None: + super().__init__(cfg) + assert not self.cfg.normalize_f0_std or self.cfg.normalize_f0_mean + + self.data_config = ExpressiveCodeDataConfig(cfg.data) + self._source_dictionary = self._target_dictionary = UnitDictionary( + n_units=self.data_config.n_units + ) + self._source_duration_dictionary = self._target_duration_dictionary = ( + UnitDictionary(n_units=self.cfg.max_token_duration + 1, clip=True) + if self.cfg.discrete_duration + else None + ) + self._source_f0_dictionary = self._target_f0_dictionary = ( + UnitDictionary(n_units=self.data_config.f0_vq_n_units) + if self.cfg.discrete_f0 + else None + ) + + self._channel_names = ["token", "duration", "f0"] + self._channel_sizes = [ + len(self.target_dictionary), + len(self.target_duration_dictionary) if self.cfg.discrete_duration else 1, + len(self.target_f0_dictionary) if self.cfg.discrete_f0 else 1, + ] + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return self._source_dictionary + + @property + def source_duration_dictionary(self) -> Optional[Dictionary]: + return self._source_duration_dictionary + + @property + def source_f0_dictionary(self) -> Optional[Dictionary]: + return self._source_f0_dictionary + + @property + def channel_names(self) -> List[str]: + return self._channel_names + + @property + def channel_sizes(self) -> List[int]: + return self._channel_sizes + + @property + def dictionary(self) -> Optional[Dictionary]: + return self._source_dictionary + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self._target_dictionary + + @property + def target_duration_dictionary(self) -> Optional[Dictionary]: + return self._target_duration_dictionary + + @property + def target_f0_dictionary(self) -> Optional[Dictionary]: + return self._target_f0_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return [self._dictionaries[l] for l in self.cfg.labels] + + @classmethod + def setup_task( + cls, cfg: SpeechUnitModelingConfig, **kwargs + ) -> "SpeechUnitLanguageModelingTask": + return cls(cfg) + + def load_dataset(self, split: str, **kwargs) -> None: + self.datasets[split] = CodeDataset( + manifest=self.data_config.manifests[split], + dictionary=self.source_dictionary, + dur_dictionary=self.source_duration_dictionary, + f0_dictionary=self.source_f0_dictionary, + config=self.data_config, + discrete_dur=self.cfg.discrete_duration, + discrete_f0=self.cfg.discrete_f0, + log_f0=self.cfg.log_f0, + normalize_f0_mean=self.cfg.normalize_f0_mean, + normalize_f0_std=self.cfg.normalize_f0_std, + interpolate_f0=self.cfg.interpolate_f0, + shifts=self.cfg.stream_shifts, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def build_criterion(self, cfg: DictConfig): + import fairseq.criterions + + return fairseq.criterions.build_criterion(cfg, self) diff --git a/fairseq/tasks/text_to_speech.py b/fairseq/tasks/text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..82e7e6643af719a9bc0b4d5ef446365b8ef7e8fb --- /dev/null +++ b/fairseq/tasks/text_to_speech.py @@ -0,0 +1,501 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import os.path as op + +import torch +import torch.nn.functional as F +import numpy as np + +from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator +from fairseq.tasks import register_task +from fairseq.tasks.speech_to_text import SpeechToTextTask +from fairseq.speech_generator import ( + AutoRegressiveSpeechGenerator, + NonAutoregressiveSpeechGenerator, + TeacherForcingAutoRegressiveSpeechGenerator, +) + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +try: + from tensorboardX import SummaryWriter +except ImportError: + logger.info("Please install tensorboardX: pip install tensorboardX") + SummaryWriter = None + + +@register_task("text_to_speech") +class TextToSpeechTask(SpeechToTextTask): + @staticmethod + def add_args(parser): + parser.add_argument("data", help="manifest root path") + parser.add_argument( + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1200, + type=int, + metavar="N", + help="max number of tokens in the target sequence", + ) + parser.add_argument("--n-frames-per-step", type=int, default=1) + parser.add_argument("--eos-prob-threshold", type=float, default=0.5) + parser.add_argument("--eval-inference", action="store_true") + parser.add_argument("--eval-tb-nsample", type=int, default=8) + parser.add_argument("--vocoder", type=str, default="griffin_lim") + parser.add_argument("--spec-bwd-max-iter", type=int, default=8) + + def __init__(self, args, src_dict): + super().__init__(args, src_dict) + self.src_dict = src_dict + self.sr = self.data_cfg.config.get("features").get("sample_rate") + + self.tensorboard_writer = None + self.tensorboard_dir = "" + if args.tensorboard_logdir and SummaryWriter is not None: + self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra") + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + is_train_split = split.startswith("train") + pre_tokenizer = self.build_tokenizer(self.args) + bpe_tokenizer = self.build_bpe(self.args) + self.datasets[split] = TextToSpeechDatasetCreator.from_tsv( + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, + n_frames_per_step=self.args.n_frames_per_step, + speaker_to_id=self.speaker_to_id, + ) + + @property + def target_dictionary(self): + return None + + @property + def source_dictionary(self): + return self.src_dict + + def get_speaker_embeddings_path(self): + speaker_emb_path = None + if self.data_cfg.config.get("speaker_emb_filename") is not None: + speaker_emb_path = op.join( + self.args.data, self.data_cfg.config.get("speaker_emb_filename") + ) + return speaker_emb_path + + @classmethod + def get_speaker_embeddings(cls, args): + embed_speaker = None + if args.speaker_to_id is not None: + if args.speaker_emb_path is None: + embed_speaker = torch.nn.Embedding( + len(args.speaker_to_id), args.speaker_embed_dim + ) + else: + speaker_emb_mat = np.load(args.speaker_emb_path) + assert speaker_emb_mat.shape[1] == args.speaker_embed_dim + embed_speaker = torch.nn.Embedding.from_pretrained( + torch.from_numpy(speaker_emb_mat), + freeze=True, + ) + logger.info( + f"load speaker embeddings from {args.speaker_emb_path}. " + f"train embedding? {embed_speaker.weight.requires_grad}\n" + f"embeddings:\n{speaker_emb_mat}" + ) + return embed_speaker + + def build_model(self, cfg, from_checkpoint=False): + cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None) + cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None) + cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None) + cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None) + cfg.speaker_emb_path = self.get_speaker_embeddings_path() + model = super().build_model(cfg, from_checkpoint) + self.generator = None + if getattr(cfg, "eval_inference", False): + self.generator = self.build_generator([model], cfg) + return model + + def build_generator(self, models, cfg, vocoder=None, **unused): + if vocoder is None: + vocoder = self.build_default_vocoder() + model = models[0] + if getattr(model, "NON_AUTOREGRESSIVE", False): + return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg) + else: + generator = AutoRegressiveSpeechGenerator + if getattr(cfg, "teacher_forcing", False): + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + return generator( + model, + vocoder, + self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + ) + + def build_default_vocoder(self): + from fairseq.models.text_to_speech.vocoder import get_vocoder + + vocoder = get_vocoder(self.args, self.data_cfg) + if torch.cuda.is_available() and not self.args.cpu: + vocoder = vocoder.cuda() + else: + vocoder = vocoder.cpu() + return vocoder + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + if getattr(self.args, "eval_inference", False): + hypos, inference_losses = self.valid_step_with_inference( + sample, model, self.generator + ) + for k, v in inference_losses.items(): + assert k not in logging_output + logging_output[k] = v + + picked_id = 0 + if self.tensorboard_dir and (sample["id"] == picked_id).any(): + self.log_tensorboard( + sample, + hypos[: self.args.eval_tb_nsample], + model._num_updates, + is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False), + ) + return loss, sample_size, logging_output + + def valid_step_with_inference(self, sample, model, generator): + hypos = generator.generate(model, sample, has_targ=True) + + losses = { + "mcd_loss": 0.0, + "targ_frames": 0.0, + "pred_frames": 0.0, + "nins": 0.0, + "ndel": 0.0, + } + rets = batch_mel_cepstral_distortion( + [hypo["targ_waveform"] for hypo in hypos], + [hypo["waveform"] for hypo in hypos], + self.sr, + normalize_type=None, + ) + for d, extra in rets: + pathmap = extra[-1] + losses["mcd_loss"] += d.item() + losses["targ_frames"] += pathmap.size(0) + losses["pred_frames"] += pathmap.size(1) + losses["nins"] += (pathmap.sum(dim=1) - 1).sum().item() + losses["ndel"] += (pathmap.sum(dim=0) - 1).sum().item() + + return hypos, losses + + def log_tensorboard(self, sample, hypos, num_updates, is_na_model=False): + if self.tensorboard_writer is None: + self.tensorboard_writer = SummaryWriter(self.tensorboard_dir) + tb_writer = self.tensorboard_writer + for b in range(len(hypos)): + idx = sample["id"][b] + text = sample["src_texts"][b] + targ = hypos[b]["targ_feature"] + pred = hypos[b]["feature"] + attn = hypos[b]["attn"] + + if is_na_model: + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1)], + [f"target (idx={idx})", "output"], + attn, + "alignment", + ret_np=True, + suptitle=text, + ) + else: + eos_prob = hypos[b]["eos_prob"] + data = plot_tts_output( + [targ.transpose(0, 1), pred.transpose(0, 1), attn], + [f"target (idx={idx})", "output", "alignment"], + eos_prob, + "eos prob", + ret_np=True, + suptitle=text, + ) + + tb_writer.add_image( + f"inference_sample_{b}", data, num_updates, dataformats="HWC" + ) + + if hypos[b]["waveform"] is not None: + targ_wave = hypos[b]["targ_waveform"].detach().cpu().float() + pred_wave = hypos[b]["waveform"].detach().cpu().float() + tb_writer.add_audio( + f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr + ) + tb_writer.add_audio( + f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr + ) + + +def save_figure_to_numpy(fig): + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + return data + + +DEFAULT_V_MIN = np.log(1e-5) + + +def plot_tts_output( + data_2d, + title_2d, + data_1d, + title_1d, + figsize=(24, 4), + v_min=DEFAULT_V_MIN, + v_max=3, + ret_np=False, + suptitle="", +): + try: + import matplotlib.pyplot as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable + except ImportError: + raise ImportError("Please install Matplotlib: pip install matplotlib") + + data_2d = [ + x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x + for x in data_2d + ] + fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize) + if suptitle: + fig.suptitle(suptitle[:400]) # capped at 400 chars + axes = [axes] if len(data_2d) == 0 else axes + for ax, x, name in zip(axes, data_2d, title_2d): + ax.set_title(name) + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + im = ax.imshow( + x, + origin="lower", + aspect="auto", + vmin=max(x.min(), v_min), + vmax=min(x.max(), v_max), + ) + fig.colorbar(im, cax=cax, orientation="vertical") + + if isinstance(data_1d, torch.Tensor): + data_1d = data_1d.detach().cpu().numpy() + axes[-1].plot(data_1d) + axes[-1].set_title(title_1d) + plt.tight_layout() + + if ret_np: + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close(fig) + return data + + +def antidiag_indices(offset, min_i=0, max_i=None, min_j=0, max_j=None): + """ + for a (3, 4) matrix with min_i=1, max_i=3, min_j=1, max_j=4, outputs + + offset=2 (1, 1), + offset=3 (2, 1), (1, 2) + offset=4 (2, 2), (1, 3) + offset=5 (2, 3) + + constraints: + i + j = offset + min_j <= j < max_j + min_i <= offset - j < max_i + """ + if max_i is None: + max_i = offset + 1 + if max_j is None: + max_j = offset + 1 + min_j = max(min_j, offset - max_i + 1, 0) + max_j = min(max_j, offset - min_i + 1, offset + 1) + j = torch.arange(min_j, max_j) + i = offset - j + return torch.stack([i, j]) + + +def batch_dynamic_time_warping(distance, shapes=None): + """full batched DTW without any constraints + + distance: (batchsize, max_M, max_N) matrix + shapes: (batchsize,) vector specifying (M, N) for each entry + """ + # ptr: 0=left, 1=up-left, 2=up + ptr2dij = {0: (0, -1), 1: (-1, -1), 2: (-1, 0)} + + bsz, m, n = distance.size() + cumdist = torch.zeros_like(distance) + backptr = torch.zeros_like(distance).type(torch.int32) - 1 + + # initialize + cumdist[:, 0, :] = distance[:, 0, :].cumsum(dim=-1) + cumdist[:, :, 0] = distance[:, :, 0].cumsum(dim=-1) + backptr[:, 0, :] = 0 + backptr[:, :, 0] = 2 + + # DP with optimized anti-diagonal parallelization, O(M+N) steps + for offset in range(2, m + n - 1): + ind = antidiag_indices(offset, 1, m, 1, n) + c = torch.stack( + [ + cumdist[:, ind[0], ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1]], + ], + dim=2, + ) + v, b = c.min(axis=-1) + backptr[:, ind[0], ind[1]] = b.int() + cumdist[:, ind[0], ind[1]] = v + distance[:, ind[0], ind[1]] + + # backtrace + pathmap = torch.zeros_like(backptr) + for b in range(bsz): + i = m - 1 if shapes is None else (shapes[b][0] - 1).item() + j = n - 1 if shapes is None else (shapes[b][1] - 1).item() + dtwpath = [(i, j)] + while (i != 0 or j != 0) and len(dtwpath) < 10000: + assert i >= 0 and j >= 0 + di, dj = ptr2dij[backptr[b, i, j].item()] + i, j = i + di, j + dj + dtwpath.append((i, j)) + dtwpath = dtwpath[::-1] + indices = torch.from_numpy(np.array(dtwpath)) + pathmap[b, indices[:, 0], indices[:, 1]] = 1 + + return cumdist, backptr, pathmap + + +def compute_l2_dist(x1, x2): + """compute an (m, n) L2 distance matrix from (m, d) and (n, d) matrices""" + return torch.cdist(x1.unsqueeze(0), x2.unsqueeze(0), p=2).squeeze(0).pow(2) + + +def compute_rms_dist(x1, x2): + l2_dist = compute_l2_dist(x1, x2) + return (l2_dist / x1.size(1)).pow(0.5) + + +def get_divisor(pathmap, normalize_type): + if normalize_type is None: + return 1 + elif normalize_type == "len1": + return pathmap.size(0) + elif normalize_type == "len2": + return pathmap.size(1) + elif normalize_type == "path": + return pathmap.sum().item() + else: + raise ValueError(f"normalize_type {normalize_type} not supported") + + +def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): + d, s, x1, x2 = [], [], [], [] + for cur_y1, cur_y2 in zip(y1, y2): + assert cur_y1.ndim == 1 and cur_y2.ndim == 1 + cur_x1 = feat_fn(cur_y1) + cur_x2 = feat_fn(cur_y2) + x1.append(cur_x1) + x2.append(cur_x2) + + cur_d = dist_fn(cur_x1, cur_x2) + d.append(cur_d) + s.append(d[-1].size()) + max_m = max(ss[0] for ss in s) + max_n = max(ss[1] for ss in s) + d = torch.stack( + [F.pad(dd, (0, max_n - dd.size(1), 0, max_m - dd.size(0))) for dd in d] + ) + s = torch.LongTensor(s).to(d.device) + cumdists, backptrs, pathmaps = batch_dynamic_time_warping(d, s) + + rets = [] + itr = zip(s, x1, x2, d, cumdists, backptrs, pathmaps) + for (m, n), cur_x1, cur_x2, dist, cumdist, backptr, pathmap in itr: + cumdist = cumdist[:m, :n] + backptr = backptr[:m, :n] + pathmap = pathmap[:m, :n] + divisor = get_divisor(pathmap, normalize_type) + + distortion = cumdist[-1, -1] / divisor + ret = distortion, (cur_x1, cur_x2, dist, cumdist, backptr, pathmap) + rets.append(ret) + return rets + + +def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None): + """ + https://arxiv.org/pdf/2011.03568.pdf + + The root mean squared error computed on 13-dimensional MFCC using DTW for + alignment. MFCC features are computed from an 80-channel log-mel + spectrogram using a 50ms Hann window and hop of 12.5ms. + + y1: list of waveforms + y2: list of waveforms + sr: sampling rate + """ + + try: + import torchaudio + except ImportError: + raise ImportError("Please install torchaudio: pip install torchaudio") + + if mfcc_fn is None or mfcc_fn.sample_rate != sr: + melkwargs = { + "n_fft": int(0.05 * sr), + "win_length": int(0.05 * sr), + "hop_length": int(0.0125 * sr), + "f_min": 20, + "n_mels": 80, + "window_fn": torch.hann_window, + } + mfcc_fn = torchaudio.transforms.MFCC( + sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs + ).to(y1[0].device) + return batch_compute_distortion( + y1, + y2, + sr, + lambda y: mfcc_fn(y).transpose(-1, -2), + compute_rms_dist, + normalize_type, + ) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py new file mode 100644 index 0000000000000000000000000000000000000000..6897ebe116a28c03b1632deff04d6d570398d2f0 --- /dev/null +++ b/fairseq/tasks/translation.py @@ -0,0 +1,498 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import itertools +import json +import logging +import os +from typing import Optional +from argparse import Namespace +from omegaconf import II + +import numpy as np +from fairseq import utils +from fairseq.logging import metrics +from fairseq.data import ( + AppendTokenDataset, + ConcatDataset, + LanguagePairDataset, + PrependTokenDataset, + StripTokenDataset, + TruncateDataset, + data_utils, + encoders, + indexed_dataset, +) +from fairseq.data.indexed_dataset import get_available_dataset_impl +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + + +EVAL_BLEU_ORDER = 4 + + +logger = logging.getLogger(__name__) + + +def load_langpair_dataset( + data_path, + split, + src, + src_dict, + tgt, + tgt_dict, + combine, + dataset_impl, + upsample_primary, + left_pad_source, + left_pad_target, + max_source_positions, + max_target_positions, + prepend_bos=False, + load_alignments=False, + truncate_source=False, + append_source_id=False, + num_buckets=0, + shuffle=True, + pad_to_multiple=1, + prepend_bos_src=None, +): + def split_exists(split, src, tgt, lang, data_path): + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + + # infer langcode + if split_exists(split_k, src, tgt, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) + elif split_exists(split_k, tgt, src, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) + else: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + src_dataset = data_utils.load_indexed_dataset( + prefix + src, src_dict, dataset_impl + ) + if truncate_source: + src_dataset = AppendTokenDataset( + TruncateDataset( + StripTokenDataset(src_dataset, src_dict.eos()), + max_source_positions - 1, + ), + src_dict.eos(), + ) + src_datasets.append(src_dataset) + + tgt_dataset = data_utils.load_indexed_dataset( + prefix + tgt, tgt_dict, dataset_impl + ) + if tgt_dataset is not None: + tgt_datasets.append(tgt_dataset) + + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 + + if len(src_datasets) == 1: + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None + else: + sample_ratios = [1] * len(src_datasets) + sample_ratios[0] = upsample_primary + src_dataset = ConcatDataset(src_datasets, sample_ratios) + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None + + if prepend_bos: + assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") + src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) + if tgt_dataset is not None: + tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) + elif prepend_bos_src is not None: + logger.info(f"prepending src bos: {prepend_bos_src}") + src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src) + + eos = None + if append_source_id: + src_dataset = AppendTokenDataset( + src_dataset, src_dict.index("[{}]".format(src)) + ) + if tgt_dataset is not None: + tgt_dataset = AppendTokenDataset( + tgt_dataset, tgt_dict.index("[{}]".format(tgt)) + ) + eos = tgt_dict.index("[{}]".format(tgt)) + + align_dataset = None + if load_alignments: + align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) + if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): + align_dataset = data_utils.load_indexed_dataset( + align_path, None, dataset_impl + ) + + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None + return LanguagePairDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset_sizes, + tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + align_dataset=align_dataset, + eos=eos, + num_buckets=num_buckets, + shuffle=shuffle, + pad_to_multiple=pad_to_multiple, + ) + + +@dataclass +class TranslationConfig(FairseqDataclass): + data: Optional[str] = field( + default=None, + metadata={ + "help": "colon separated path to data directories list, will be iterated upon during epochs " + "in round-robin manner; however, valid and test data are always in the first directory " + "to avoid the need for repeating them in all directories" + }, + ) + source_lang: Optional[str] = field( + default=None, + metadata={ + "help": "source language", + "argparse_alias": "-s", + }, + ) + target_lang: Optional[str] = field( + default=None, + metadata={ + "help": "target language", + "argparse_alias": "-t", + }, + ) + load_alignments: bool = field( + default=False, metadata={"help": "load the binarized alignments"} + ) + left_pad_source: bool = field( + default=True, metadata={"help": "pad the source on the left"} + ) + left_pad_target: bool = field( + default=False, metadata={"help": "pad the target on the left"} + ) + max_source_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the source sequence"} + ) + max_target_positions: int = field( + default=1024, metadata={"help": "max number of tokens in the target sequence"} + ) + upsample_primary: int = field( + default=-1, metadata={"help": "the amount of upsample primary dataset"} + ) + truncate_source: bool = field( + default=False, metadata={"help": "truncate source to max-source-positions"} + ) + num_batch_buckets: int = field( + default=0, + metadata={ + "help": "if >0, then bucket source and target lengths into " + "N buckets and pad accordingly; this is useful on TPUs to minimize the number of compilations" + }, + ) + train_subset: str = II("dataset.train_subset") + dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( + "dataset.dataset_impl" + ) + required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") + + # options for reporting BLEU during validation + eval_bleu: bool = field( + default=False, metadata={"help": "evaluation with BLEU scores"} + ) + eval_bleu_args: Optional[str] = field( + default="{}", + metadata={ + "help": 'generation args for BLUE scoring, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' + }, + ) + eval_bleu_detok: str = field( + default="space", + metadata={ + "help": "detokenize before computing BLEU (e.g., 'moses'); required if using --eval-bleu; " + "use 'space' to disable detokenization; see fairseq.data.encoders for other options" + }, + ) + eval_bleu_detok_args: Optional[str] = field( + default="{}", + metadata={"help": "args for building the tokenizer, if needed, as JSON string"}, + ) + eval_tokenized_bleu: bool = field( + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + ) + eval_bleu_remove_bpe: Optional[str] = field( + default=None, + metadata={ + "help": "remove BPE before computing BLEU", + "argparse_const": "@@ ", + }, + ) + eval_bleu_print_samples: bool = field( + default=False, metadata={"help": "print sample generations during validation"} + ) + + +@register_task("translation", dataclass=TranslationConfig) +class TranslationTask(FairseqTask): + """ + Translate from one (source) language to another (target) language. + + Args: + src_dict (~fairseq.data.Dictionary): dictionary for the source language + tgt_dict (~fairseq.data.Dictionary): dictionary for the target language + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + """ + + cfg: TranslationConfig + + def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict): + super().__init__(cfg) + self.src_dict = src_dict + self.tgt_dict = tgt_dict + + @classmethod + def setup_task(cls, cfg: TranslationConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + args (argparse.Namespace): parsed command-line arguments + """ + + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + # find language pair automatically + if cfg.source_lang is None or cfg.target_lang is None: + cfg.source_lang, cfg.target_lang = data_utils.infer_language_pair(paths[0]) + if cfg.source_lang is None or cfg.target_lang is None: + raise Exception( + "Could not infer language pair, please provide it explicitly" + ) + + # load dictionaries + src_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang)) + ) + tgt_dict = cls.load_dictionary( + os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang)) + ) + assert src_dict.pad() == tgt_dict.pad() + assert src_dict.eos() == tgt_dict.eos() + assert src_dict.unk() == tgt_dict.unk() + logger.info("[{}] dictionary: {} types".format(cfg.source_lang, len(src_dict))) + logger.info("[{}] dictionary: {} types".format(cfg.target_lang, len(tgt_dict))) + + return cls(cfg, src_dict, tgt_dict) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + if split != self.cfg.train_subset: + # if not training data set, use the first shard for valid and test + paths = paths[:1] + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.cfg.source_lang, self.cfg.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, + load_alignments=self.cfg.load_alignments, + truncate_source=self.cfg.truncate_source, + num_buckets=self.cfg.num_batch_buckets, + shuffle=(split != "test"), + pad_to_multiple=self.cfg.required_seq_len_multiple, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + return LanguagePairDataset( + src_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) + + def build_model(self, cfg, from_checkpoint=False): + model = super().build_model(cfg, from_checkpoint) + if self.cfg.eval_bleu: + detok_args = json.loads(self.cfg.eval_bleu_detok_args) + self.tokenizer = encoders.build_tokenizer( + Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + ) + + gen_args = json.loads(self.cfg.eval_bleu_args) + self.sequence_generator = self.build_generator( + [model], Namespace(**gen_args) + ) + return model + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + if self.cfg.eval_bleu: + bleu = self._inference_with_bleu(self.sequence_generator, sample, model) + logging_output["_bleu_sys_len"] = bleu.sys_len + logging_output["_bleu_ref_len"] = bleu.ref_len + # we split counts into separate entries so that they can be + # summed efficiently across workers using fast-stat-sync + assert len(bleu.counts) == EVAL_BLEU_ORDER + for i in range(EVAL_BLEU_ORDER): + logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] + logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] + return loss, sample_size, logging_output + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + if self.cfg.eval_bleu: + + def sum_logs(key): + import torch + + result = sum(log.get(key, 0) for log in logging_outputs) + if torch.is_tensor(result): + result = result.cpu() + return result + + counts, totals = [], [] + for i in range(EVAL_BLEU_ORDER): + counts.append(sum_logs("_bleu_counts_" + str(i))) + totals.append(sum_logs("_bleu_totals_" + str(i))) + + if max(totals) > 0: + # log counts as numpy arrays -- log_scalar will sum them correctly + metrics.log_scalar("_bleu_counts", np.array(counts)) + metrics.log_scalar("_bleu_totals", np.array(totals)) + metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) + metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) + + def compute_bleu(meters): + import inspect + + try: + from sacrebleu.metrics import BLEU + + comp_bleu = BLEU.compute_bleu + except ImportError: + # compatibility API for sacrebleu 1.x + import sacrebleu + + comp_bleu = sacrebleu.compute_bleu + + fn_sig = inspect.getfullargspec(comp_bleu)[0] + if "smooth_method" in fn_sig: + smooth = {"smooth_method": "exp"} + else: + smooth = {"smooth": "exp"} + bleu = comp_bleu( + correct=meters["_bleu_counts"].sum, + total=meters["_bleu_totals"].sum, + sys_len=int(meters["_bleu_sys_len"].sum), + ref_len=int(meters["_bleu_ref_len"].sum), + **smooth, + ) + return round(bleu.score, 2) + + metrics.log_derived("bleu", compute_bleu) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.cfg.max_source_positions, self.cfg.max_target_positions) + + @property + def source_dictionary(self): + """Return the source :class:`~fairseq.data.Dictionary`.""" + return self.src_dict + + @property + def target_dictionary(self): + """Return the target :class:`~fairseq.data.Dictionary`.""" + return self.tgt_dict + + def _inference_with_bleu(self, generator, sample, model): + import sacrebleu + + def decode(toks, escape_unk=False): + s = self.tgt_dict.string( + toks.int().cpu(), + self.cfg.eval_bleu_remove_bpe, + # The default unknown string in fairseq is ``, but + # this is tokenized by sacrebleu as `< unk >`, inflating + # BLEU scores. Instead, we use a somewhat more verbose + # alternative that is unlikely to appear in the real + # reference, but doesn't get split into multiple tokens. + unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), + ) + if self.tokenizer: + s = self.tokenizer.decode(s) + return s + + gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) + hyps, refs = [], [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"])) + refs.append( + decode( + utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), + escape_unk=True, # don't count as matches to the hypo + ) + ) + if self.cfg.eval_bleu_print_samples: + logger.info("example hypothesis: " + hyps[0]) + logger.info("example reference: " + refs[0]) + if self.cfg.eval_tokenized_bleu: + return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") + else: + return sacrebleu.corpus_bleu(hyps, [refs]) diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd7a5b29f0e34699b5d5ef7574bc39b8c6052c9 --- /dev/null +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq import utils +from fairseq.data import LanguagePairDataset + +from . import register_task +from .translation import TranslationTask, load_langpair_dataset + + +@register_task("translation_from_pretrained_bart") +class TranslationFromPretrainedBARTTask(TranslationTask): + """ + Translate from source language to target language with a model initialized with a multilingual pretrain. + + Args: + src_dict (~fairseq.data.Dictionary): dictionary for the source language + tgt_dict (~fairseq.data.Dictionary): dictionary for the target language + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + + The translation task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.translation_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + TranslationTask.add_args(parser) + parser.add_argument('--langs', type=str, metavar='LANG', + help='comma-separated list of monolingual language, ' + 'for example, "en,de,fr". These should match the ' + 'langs from pretraining (and be in the same order). ' + 'You should always add all pretraining language idx ' + 'during finetuning.') + parser.add_argument('--prepend-bos', action='store_true', + help='prepend bos token to each sentence, which matches ' + 'mBART pretraining') + # fmt: on + + def __init__(self, args, src_dict, tgt_dict): + super().__init__(args, src_dict, tgt_dict) + self.langs = args.langs.split(",") + for d in [src_dict, tgt_dict]: + for l in self.langs: + d.add_symbol("[{}]".format(l)) + d.add_symbol("") + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.args.source_lang, self.args.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.args.dataset_impl, + upsample_primary=self.args.upsample_primary, + left_pad_source=self.args.left_pad_source, + left_pad_target=self.args.left_pad_target, + max_source_positions=getattr(self.args, "max_source_positions", 1024), + max_target_positions=getattr(self.args, "max_target_positions", 1024), + load_alignments=self.args.load_alignments, + prepend_bos=getattr(self.args, "prepend_bos", False), + append_source_id=True, + ) + + def build_generator(self, models, args, **unused): + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), + ) + else: + from fairseq.sequence_generator import SequenceGenerator + + return SequenceGenerator( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang)) + source_tokens = [] + for s_t in src_tokens: + s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) + source_tokens.append(s_t) + dataset = LanguagePairDataset( + source_tokens, + src_lengths, + self.source_dictionary, + tgt_dict=self.target_dictionary, + constraints=constraints, + ) + return dataset diff --git a/fairseq/tasks/translation_from_pretrained_xlm.py b/fairseq/tasks/translation_from_pretrained_xlm.py new file mode 100644 index 0000000000000000000000000000000000000000..a05f2891524a8b23482e206c1742c3b816b77afb --- /dev/null +++ b/fairseq/tasks/translation_from_pretrained_xlm.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from fairseq.data.legacy.masked_lm_dictionary import MaskedLMDictionary +from fairseq.tasks.translation import TranslationConfig, TranslationTask + +from . import register_task + + +@dataclass +class TranslationFromPretrainedXLMConfig(TranslationConfig): + pass + + +@register_task( + "translation_from_pretrained_xlm", dataclass=TranslationFromPretrainedXLMConfig +) +class TranslationFromPretrainedXLMTask(TranslationTask): + """ + Same as TranslationTask except use the MaskedLMDictionary class so that + we can load data that was binarized with the MaskedLMDictionary class. + + This task should be used for the entire training pipeline when we want to + train an NMT model from a pretrained XLM checkpoint: binarizing NMT data, + training NMT with the pretrained XLM checkpoint, and subsequent evaluation + of that trained model. + """ + + @classmethod + def load_dictionary(cls, filename): + """Load the masked LM dictionary from the filename + + Args: + filename (str): the filename + """ + return MaskedLMDictionary.load(filename) diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py new file mode 100644 index 0000000000000000000000000000000000000000..b45fecd1f40ae43829ef43633a04dcbfd77a4136 --- /dev/null +++ b/fairseq/tasks/translation_lev.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +import torch +from fairseq import utils +from fairseq.data import LanguagePairDataset +from fairseq.dataclass import ChoiceEnum +from fairseq.tasks import register_task +from fairseq.tasks.translation import ( + TranslationConfig, + TranslationTask, + load_langpair_dataset, +) +from fairseq.utils import new_arange + + +NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"]) + + +@dataclass +class TranslationLevenshteinConfig(TranslationConfig): + noise: NOISE_CHOICES = field( + default="random_delete", + metadata={"help": "type of noise"}, + ) + + +@register_task("translation_lev", dataclass=TranslationLevenshteinConfig) +class TranslationLevenshteinTask(TranslationTask): + """ + Translation (Sequence Generation) task for Levenshtein Transformer + See `"Levenshtein Transformer" `_. + """ + + cfg: TranslationLevenshteinConfig + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + # infer langcode + src, tgt = self.cfg.source_lang, self.cfg.target_lang + + self.datasets[split] = load_langpair_dataset( + data_path, + split, + src, + self.src_dict, + tgt, + self.tgt_dict, + combine=combine, + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=self.cfg.max_source_positions, + max_target_positions=self.cfg.max_target_positions, + prepend_bos=True, + ) + + def inject_noise(self, target_tokens): + def _random_delete(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + + max_len = target_tokens.size(1) + target_mask = target_tokens.eq(pad) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_( + target_tokens.eq(bos) | target_tokens.eq(eos), 0.0 + ) + target_score.masked_fill_(target_mask, 1) + target_score, target_rank = target_score.sort(1) + target_length = target_mask.size(1) - target_mask.float().sum( + 1, keepdim=True + ) + + # do not delete and (we assign 0 score for them) + target_cutoff = ( + 2 + + ( + (target_length - 2) + * target_score.new_zeros(target_score.size(0), 1).uniform_() + ).long() + ) + target_cutoff = target_score.sort(1)[1] >= target_cutoff + + prev_target_tokens = ( + target_tokens.gather(1, target_rank) + .masked_fill_(target_cutoff, pad) + .gather(1, target_rank.masked_fill_(target_cutoff, max_len).sort(1)[1]) + ) + prev_target_tokens = prev_target_tokens[ + :, : prev_target_tokens.ne(pad).sum(1).max() + ] + + return prev_target_tokens + + def _random_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_masks = ( + target_tokens.ne(pad) & target_tokens.ne(bos) & target_tokens.ne(eos) + ) + target_score = target_tokens.clone().float().uniform_() + target_score.masked_fill_(~target_masks, 2.0) + target_length = target_masks.sum(1).float() + target_length = target_length * target_length.clone().uniform_() + target_length = target_length + 1 # make sure to mask at least one token. + + _, target_rank = target_score.sort(1) + target_cutoff = new_arange(target_rank) < target_length[:, None].long() + prev_target_tokens = target_tokens.masked_fill( + target_cutoff.scatter(1, target_rank, target_cutoff), unk + ) + return prev_target_tokens + + def _full_mask(target_tokens): + pad = self.tgt_dict.pad() + bos = self.tgt_dict.bos() + eos = self.tgt_dict.eos() + unk = self.tgt_dict.unk() + + target_mask = ( + target_tokens.eq(bos) | target_tokens.eq(eos) | target_tokens.eq(pad) + ) + return target_tokens.masked_fill(~target_mask, unk) + + if self.cfg.noise == "random_delete": + return _random_delete(target_tokens) + elif self.cfg.noise == "random_mask": + return _random_mask(target_tokens) + elif self.cfg.noise == "full_mask": + return _full_mask(target_tokens) + elif self.cfg.noise == "no_noise": + return target_tokens + else: + raise NotImplementedError + + def build_generator(self, models, args, **unused): + # add models input to match the API for SequenceGenerator + from fairseq.iterative_refinement_generator import IterativeRefinementGenerator + + return IterativeRefinementGenerator( + self.target_dictionary, + eos_penalty=getattr(args, "iter_decode_eos_penalty", 0.0), + max_iter=getattr(args, "iter_decode_max_iter", 10), + beam_size=getattr(args, "iter_decode_with_beam", 1), + reranking=getattr(args, "iter_decode_with_external_reranker", False), + decoding_format=getattr(args, "decoding_format", None), + adaptive=not getattr(args, "iter_decode_force_max_iter", False), + retain_history=getattr(args, "retain_iter_history", False), + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + # Though see Susanto et al. (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.325/ + raise NotImplementedError( + "Constrained decoding with the translation_lev task is not supported" + ) + + return LanguagePairDataset( + src_tokens, src_lengths, self.source_dictionary, append_bos=True + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + model.train() + sample["prev_target"] = self.inject_noise(sample["target"]) + loss, sample_size, logging_output = criterion(model, sample) + if ignore_grad: + loss *= 0 + optimizer.backward(loss) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + model.eval() + with torch.no_grad(): + sample["prev_target"] = self.inject_noise(sample["target"]) + loss, sample_size, logging_output = criterion(model, sample) + return loss, sample_size, logging_output diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py new file mode 100644 index 0000000000000000000000000000000000000000..5db36a7c79ab291319339f5df15b70234154eda2 --- /dev/null +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -0,0 +1,441 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +import logging +import time + +import torch +from fairseq.data import ( + FairseqDataset, + LanguagePairDataset, + ListDataset, + data_utils, + iterators, +) +from fairseq.data.multilingual.multilingual_data_manager import ( + MultilingualDatasetManager, +) +from fairseq.data.multilingual.sampling_method import SamplingMethod +from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.utils import FileContentsAction + + +### +def get_time_gap(s, e): + return ( + datetime.datetime.fromtimestamp(e) - datetime.datetime.fromtimestamp(s) + ).__str__() + + +### + + +logger = logging.getLogger(__name__) + + +@register_task("translation_multi_simple_epoch") +class TranslationMultiSimpleEpochTask(LegacyFairseqTask): + """ + Translate from one (source) language to another (target) language. + + Args: + langs (List[str]): a list of languages that are being supported + dicts (Dict[str, fairseq.data.Dictionary]): mapping from supported languages to their dictionaries + training (bool): whether the task should be configured for training or not + + .. note:: + + The translation task is compatible with :mod:`fairseq-train`, + :mod:`fairseq-generate` and :mod:`fairseq-interactive`. + + The translation task provides the following additional command-line + arguments: + + .. argparse:: + :ref: fairseq.tasks.translation_parser + :prog: + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + # fmt: off + parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', + help='inference source language') + parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', + help='inference target language') + parser.add_argument('--lang-pairs', default=None, metavar='PAIRS', + help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr', + action=FileContentsAction) + parser.add_argument('--keep-inference-langtok', action='store_true', + help='keep language tokens in inference output (e.g. for analysis or debugging)') + + SamplingMethod.add_arguments(parser) + MultilingualDatasetManager.add_args(parser) + # fmt: on + + def __init__(self, args, langs, dicts, training): + super().__init__(args) + self.langs = langs + self.dicts = dicts + self.training = training + if training: + self.lang_pairs = args.lang_pairs + else: + self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)] + # eval_lang_pairs for multilingual translation is usually all of the + # lang_pairs. However for other multitask settings or when we want to + # optimize for certain languages we want to use a different subset. Thus + # the eval_lang_pairs class variable is provided for classes that extend + # this class. + self.eval_lang_pairs = self.lang_pairs + # model_lang_pairs will be used to build encoder-decoder model pairs in + # models.build_model(). This allows multitask type of sub-class can + # build models other than the input lang_pairs + self.model_lang_pairs = self.lang_pairs + self.source_langs = [d.split("-")[0] for d in self.lang_pairs] + self.target_langs = [d.split("-")[1] for d in self.lang_pairs] + self.check_dicts(self.dicts, self.source_langs, self.target_langs) + + self.sampling_method = SamplingMethod.build_sampler(args, self) + self.data_manager = MultilingualDatasetManager.setup_data_manager( + args, self.lang_pairs, langs, dicts, self.sampling_method + ) + + def check_dicts(self, dicts, source_langs, target_langs): + if self.args.source_dict is not None or self.args.target_dict is not None: + # no need to check whether the source side and target side are sharing dictionaries + return + src_dict = dicts[source_langs[0]] + tgt_dict = dicts[target_langs[0]] + for src_lang in source_langs: + assert ( + src_dict == dicts[src_lang] + ), "Diffrent dictionary are specified for different source languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all source languages" + for tgt_lang in target_langs: + assert ( + tgt_dict == dicts[tgt_lang] + ), "Diffrent dictionary are specified for different target languages; " + "TranslationMultiSimpleEpochTask only supports one shared dictionary across all target languages" + + @classmethod + def setup_task(cls, args, **kwargs): + langs, dicts, training = MultilingualDatasetManager.prepare( + cls.load_dictionary, args, **kwargs + ) + return cls(args, langs, dicts, training) + + def has_sharded_data(self, split): + return self.data_manager.has_sharded_data(split) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split. + + Args: + split (str): name of the split (e.g., train, valid, test) + """ + if split in self.datasets: + dataset = self.datasets[split] + if self.has_sharded_data(split): + if self.args.virtual_epoch_size is not None: + if dataset.load_next_shard: + shard_epoch = dataset.shard_epoch + else: + # no need to load next shard so skip loading + # also this avoid always loading from beginning of the data + return + else: + shard_epoch = epoch + else: + # estimate the shard epoch from virtual data size and virtual epoch size + shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) + logger.info(f"loading data for {split} epoch={epoch}/{shard_epoch}") + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + if split in self.datasets: + del self.datasets[split] + logger.info("old dataset deleted manually") + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + self.datasets[split] = self.data_manager.load_dataset( + split, + self.training, + epoch=epoch, + combine=combine, + shard_epoch=shard_epoch, + **kwargs, + ) + + def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): + if constraints is not None: + raise NotImplementedError( + "Constrained decoding with the multilingual_translation task is not supported" + ) + + src_data = ListDataset(src_tokens, src_lengths) + dataset = LanguagePairDataset(src_data, src_lengths, self.source_dictionary) + src_langtok_spec, tgt_langtok_spec = self.args.langtoks["main"] + if self.args.lang_tok_replacing_bos_eos: + dataset = self.data_manager.alter_dataset_langtok( + dataset, + src_eos=self.source_dictionary.eos(), + src_lang=self.args.source_lang, + tgt_eos=self.target_dictionary.eos(), + tgt_lang=self.args.target_lang, + src_langtok_spec=src_langtok_spec, + tgt_langtok_spec=tgt_langtok_spec, + ) + else: + dataset.src = self.data_manager.src_dataset_tranform_func( + self.args.source_lang, + self.args.target_lang, + dataset=dataset.src, + spec=src_langtok_spec, + ) + return dataset + + def build_generator( + self, + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=None, + ): + if not getattr(args, "keep_inference_langtok", False): + _, tgt_langtok_spec = self.args.langtoks["main"] + if tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + extra_gen_cls_kwargs["symbols_to_strip_from_output"] = {tgt_lang_tok} + + return super().build_generator( + models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + def build_model(self, args, from_checkpoint=False): + return super().build_model(args, from_checkpoint) + + def valid_step(self, sample, model, criterion): + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + return loss, sample_size, logging_output + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + _, tgt_langtok_spec = self.args.langtoks["main"] + if not self.args.lang_tok_replacing_bos_eos: + if prefix_tokens is None and tgt_langtok_spec: + tgt_lang_tok = self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.size(0) + prefix_tokens = ( + torch.LongTensor([[tgt_lang_tok]]).expand(bsz, 1).to(src_tokens) + ) + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + else: + return generator.generate( + models, + sample, + prefix_tokens=prefix_tokens, + bos_token=self.data_manager.get_decoder_langtok( + self.args.target_lang, tgt_langtok_spec + ) + if tgt_langtok_spec + else self.target_dictionary.eos(), + ) + + def reduce_metrics(self, logging_outputs, criterion): + super().reduce_metrics(logging_outputs, criterion) + + def max_positions(self): + """Return the max sentence length allowed by the task.""" + return (self.args.max_source_positions, self.args.max_target_positions) + + @property + def source_dictionary(self): + return self.data_manager.get_source_dictionary(self.source_langs[0]) + + @property + def target_dictionary(self): + return self.data_manager.get_target_dictionary(self.target_langs[0]) + + def create_batch_sampler_func( + self, + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, + required_batch_size_multiple=1, + seed=1, + ): + def construct_batch_sampler(dataset, epoch): + splits = [ + s for s, _ in self.datasets.items() if self.datasets[s] == dataset + ] + split = splits[0] if len(splits) > 0 else None + # NEW implementation + if epoch is not None: + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + # get indices ordered by example size + start_time = time.time() + logger.info(f"start batch sampler: mem usage: {data_utils.get_mem_usage()}") + + with data_utils.numpy_seed(seed): + indices = dataset.ordered_indices() + logger.info( + f"[{split}] @batch_sampler order indices time: {get_time_gap(start_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + # filter examples that are too large + if max_positions is not None: + my_time = time.time() + indices = self.filter_indices_by_size( + indices, dataset, max_positions, ignore_invalid_inputs + ) + logger.info( + f"[{split}] @batch_sampler filter_by_size time: {get_time_gap(my_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + # create mini-batches with given size constraints + my_time = time.time() + batch_sampler = dataset.batch_by_size( + indices, + max_tokens=max_tokens, + max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + + logger.info( + f"[{split}] @batch_sampler batch_by_size time: {get_time_gap(my_time, time.time())}" + ) + logger.info( + f"[{split}] per epoch batch_sampler set-up time: {get_time_gap(start_time, time.time())}" + ) + logger.info(f"mem usage: {data_utils.get_mem_usage()}") + + return batch_sampler + + return construct_batch_sampler + + # we need to override get_batch_iterator because we want to reset the epoch iterator each time + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + """ + Get an iterator that yields batches of data from the given dataset. + + Args: + dataset (~fairseq.data.FairseqDataset): dataset to batch + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + max_positions (optional): max sentence length supported by the + model (default: None). + ignore_invalid_inputs (bool, optional): don't raise Exception for + sentences that are too long (default: False). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + seed (int, optional): seed for random number generator for + reproducibility (default: 1). + num_shards (int, optional): shard the data iterator into N + shards (default: 1). + shard_id (int, optional): which shard of the data iterator to + return (default: 0). + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means the data will be loaded in the main process + (default: 0). + epoch (int, optional): the epoch to start the iterator from + (default: 0). + data_buffer_size (int, optional): number of batches to + preload (default: 0). + disable_iterator_cache (bool, optional): don't cache the + EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`) + (default: False). + grouped_shuffling (bool, optional): group batches with each groups + containing num_shards batches and shuffle groups. Reduces difference + between sequence lengths among workers for batches sorted by length. + update_epoch_batch_itr (bool optional): if true then donot use the cached + batch iterator for the epoch + + Returns: + ~fairseq.iterators.EpochBatchIterator: a batched iterator over the + given dataset split + """ + # initialize the dataset with the correct starting epoch + assert isinstance(dataset, FairseqDataset) + if dataset in self.dataset_to_epoch_iter: + return self.dataset_to_epoch_iter[dataset] + if self.args.sampling_method == "RoundRobin": + batch_iter = super().get_batch_iterator( + dataset, + max_tokens=max_tokens, + max_sentences=max_sentences, + max_positions=max_positions, + ignore_invalid_inputs=ignore_invalid_inputs, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + data_buffer_size=data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=skip_remainder_batch, + update_epoch_batch_itr=update_epoch_batch_itr, + ) + self.dataset_to_epoch_iter[dataset] = batch_iter + return batch_iter + + construct_batch_sampler = self.create_batch_sampler_func( + max_positions, + ignore_invalid_inputs, + max_tokens, + max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + seed=seed, + ) + + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_sampler=construct_batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + ) + return epoch_iter diff --git a/fairseq/token_generation_constraints.py b/fairseq/token_generation_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..e708dc51bcb0ffb7b411496239c74d5e6f3c2448 --- /dev/null +++ b/fairseq/token_generation_constraints.py @@ -0,0 +1,506 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Implements tracking of constraints for a beam item. + +A list of constraints is given as a list of one or more token +sequences, each of length at least one token. For example, for an input sentence + +> Die maschinelle Übersetzung ist schwer zu kontrollieren. + +We could have the constraints: +* to influence +* hard + +There are two implementations: +* OrderedConstraintState: Tracks progress through an ordered list of multitoken constraints. +* UnorderedConstraintState: Tracks progress through an unordered list of multitoken constraints. + +The difference is that in the first, the constraints are assumed to be +in order; the algorithm will permit zero or more tokens between them. +In the second, the constraints are not ordered, so many orderings will +be explored. + +The same sequence can be present any number of times, and will appear +that many times in the output. +""" + +from collections import Counter +from typing import List, Optional, Set, Tuple + +import torch + + +class ConstraintState: + def __init__(self): + pass + + +def pack_constraints(batch_constraints: List[List[torch.Tensor]]) -> torch.Tensor: + """Takes a list of list of constraints in tensor form (a list of + tensor constraints for each sentence) and transforms it into a + packed Tensor. For example, here is a batch of size 3 with 3, 0, + and 1 constraints: + + [ [ [3 1 2], [3], [4 5 6 7], ] + [], + [ [1 8 9 10 1 4 11 12], ] + ] + + Its corresponding packed structure is: + + [ [ 3 3 1 2 0 3 0 4 5 6 7 0], + [ 0 0 0 0 0 0 0 0 0 0 0 0], + [ 1 1 8 9 10 1 4 11 12 0 0 0] ] + + The packed tensor has shape (batch size, maxlen), where + maxlen is defined below. Each row contains concatenated + constraint tokens for that sentence, with 0 appended after + each constraint. The first item in each row is the number + of constraints for that sentence. So maxlen is the maximum + of + + (number of constraints) + (sum length of constraints) + 1. + + across all sentences in the batch. + """ + # The maximum word length of concatenated constraints for any sentence + max_constraints_len = 1 + for sentence_constraints in batch_constraints: + if len(sentence_constraints): + # number of constraints, plus sum of constrain lens, plus a zero after each + constraints_len = ( + 1 + + sum([c.size(0) for c in sentence_constraints]) + + len(sentence_constraints) + ) + max_constraints_len = max(max_constraints_len, constraints_len) + + batch_size = len(batch_constraints) + constraints_tensor = torch.zeros((batch_size, max_constraints_len)).long() + for i, sentence_constraints in enumerate(batch_constraints): + constraints_tensor[i, 0] = len(sentence_constraints) + offset = 1 + for j, constraint in enumerate(sentence_constraints): + this_len = constraint.size(0) + constraints_tensor[i, offset : offset + this_len] = constraint + offset += this_len + 1 + + return constraints_tensor.long() + + +def unpack_constraints(constraint_tensor: torch.Tensor) -> List[torch.Tensor]: + """ + Transforms *one row* of a packed constraint tensor (e.g., for one + sentence in the batch) into a list of constraint tensors. + """ + constraint_list = [] + num_constraints = constraint_tensor[0] + constraints = constraint_tensor.tolist() + offset = 1 + for i in range(num_constraints): + where = constraints.index(0, offset) + constraint_list.append(constraint_tensor[offset:where]) + offset = where + 1 + + return constraint_list + + +class ConstraintNode: + """ + Represents a node in a trie managing unordered constraints. + """ + + def __init__(self, token: int = None, parent=None): + # The token associate with this node (None for the root) + self.token = int(token) if token is not None else None + # The parent (None at the root) + self.parent = parent + # Whether this node is a completed constraint + self.terminal = 0 + # List of child nodes + self.children = {} + + # The cumulative number of constraints from this point in the + # trie forward + self.num_constraints = 0 + + @property + def id(self): + return self.token + + def __str__(self): + term = self.terminal != 0 + return f"[{self.token}].{term}#{self.num_constraints}" + + def __getitem__(self, key: int): + return self.children.get(key, None) + + def next_tokens(self) -> Set[int]: + """The set of child labels.""" + return set(self.children.keys()) + + @staticmethod + def create(constraints: List[List[int]]): + root = ConstraintNode() + for sequence in constraints: + root.add_sequence(sequence) + + return root + + @staticmethod + def print_graph(node: "ConstraintNode"): + if len(node.children) == 0: + return str(node) + else: + s = f"({node}" + for child in node.children.values(): + s += " " + ConstraintNode.print_graph(child) + s += ")" + return s + + def token_counts(self) -> Counter: + """Returns a counter of the number of times each token is used + in a constraint. + """ + token_counts = Counter() + kids = list(self.children.values()) + while len(kids) > 0: + kid = kids.pop() + token_counts[kid.id] += kid.num_constraints + kids += list(kid.children.values()) + + return token_counts + + def tokens(self) -> Set[int]: + """Returns the set of tokens in constraints.""" + return set(self.token_counts().keys()) + + def add_sequence(self, sequence: List[int]): + """Adds a constraint, represented as a list of integers, to + the trie.""" + assert len(sequence) > 0 + + token = int(sequence[0]) + if token not in self.children: + self.children[token] = ConstraintNode(token, parent=self) + + node = self.children[token] + if len(sequence) == 1: + node.terminal += 1 + node.num_constraints += 1 + parent = node.parent + while parent is not None: + parent.num_constraints += 1 + parent = parent.parent + else: + node.add_sequence(sequence[1:]) + + +class UnorderedConstraintState(ConstraintState): + """ + Records progress through the set of constraints for each item in the beam + using a trie. + """ + + def __init__(self, node: ConstraintNode, copy_from: "ConstraintState" = None): + self.node = node + + if copy_from is None: + # The root node + self.root = node + # The set of states in the graph that have been completed + self.completed = Counter() + # The... + self.generated = Counter() + # The list of tokens we need to generate + self.needed_tokens = self.root.tokens() + else: + self.completed = Counter(copy_from.completed) + self.generated = Counter(copy_from.generated) + self.root = copy_from.root + + # Mark the node as generated + if self.node != self.root: + self.generated[node] += 1 + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + constraint_trie_root = ConstraintNode.create(constraint_list) + return UnorderedConstraintState(constraint_trie_root) + + def __str__(self): + gen_str = ",".join([str(node) for node in self.generated]) + return f"{self.name}/{self.bank}({gen_str})x{self.num_completed}" + + def __copy__(self): + copied_state = UnorderedConstraintState(self.node, copy_from=self) + return copied_state + + def copy(self): + return self.__copy__() + + @property + def name(self): + if self.node.id is None: + return "ROOT" + else: + return str(self.node.id) + + @property + def is_root(self): + return self.node == self.root + + @property + def bank(self): + return sum(self.generated.values()) + + @property + def num_completed(self): + """The number of constraints (not constraint tokens) that are completed. + In addition to the already-completed states, we need to account for the + current state, which might get marked as completed when another token + is generated. + """ + in_final = self.node.terminal and self.completed[self.node] < self.node.terminal + return sum(self.completed.values()) + in_final + + @property + def finished(self): + return self.root.num_constraints - self.num_completed == 0 + + @property + def token_counts(self): + return self.root.token_counts() + + @property + def tokens(self): + return self.root.tokens() + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + if self.node != self.root: + return self.root.next_tokens().union(self.node.next_tokens()) + else: + return self.root.next_tokens() + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + + next_state = None + child = self.node[token] + if child is not None and self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + + def rewind(): + """If we're mid-trie and an "illegal" token is chosen next, we need + to reset our state to the root state. However, along the way, we need + to check whether a prefix of the current trie state represents a state + we could mark as completed. + """ + node = self.node + while node != self.root: + if node.terminal and self.completed[node] < node.terminal: + next_state.completed[node] += 1 + return + + next_state.generated[node] -= 1 + node = node.parent + + # Fall off the graph, check the root + if next_state is None and token in self.root.next_tokens(): + child = self.root[token] + # We can only traverse this edge if it's not saturated + if self.generated[child] < child.num_constraints: + next_state = UnorderedConstraintState(child, copy_from=self) + else: + next_state = UnorderedConstraintState(self.root, copy_from=self) + + # Rewind + rewind() + + elif next_state is None: + next_state = UnorderedConstraintState(self.root, copy_from=self) + # Rewind + rewind() + + return next_state + + +class ConstraintSequence: + def __init__(self, sequences: List[List[int]]): + """Represents a set of possibly multitoken constraints by + concatenating them and internally recording the end points. + """ + self.sequences = [] + self.endpoints = [] + self.num_tokens = 0 + self.tokens = set() + for sequence in sequences: + for token in sequence: + self.tokens.add(token) + self.num_tokens += len(sequence) + self.endpoints += [False for x in range(len(sequence) - 1)] + [True] + self.sequences += sequence + + def __getitem__(self, key: int): + return self.sequences[key] + + def __len__(self): + return len(self.sequences) + + def __str__(self): + return str(self.sequences) + + +class OrderedConstraintState(ConstraintState): + """ + Records progress through the set of linear nonbranching constraints with gaps. + """ + + def __init__(self, sequence: ConstraintSequence, state: int = -1): + self.sequence = sequence + self.state = state + + @staticmethod + def create(constraint_tensor: torch.Tensor): + constraint_list = unpack_constraints(constraint_tensor) + return OrderedConstraintState(ConstraintSequence(constraint_list), -1) + + def __str__(self): + return f"{self.state}/{self.bank}x{self.num_completed}" + + def __copy__(self): + return OrderedConstraintState(self.sequence, self.state) + + def copy(self): + return self.__copy__() + + @property + def num_completed(self): + if self.state == -1: + return 0 + count = len( + list(filter(lambda x: x, self.sequence.endpoints[0 : self.state + 1])) + ) + return count + + @property + def is_root(self): + return self.state == -1 + + @property + def name(self): + if self.state == -1: + return "ROOT" + else: + return str(self.sequence[self.state]) + + @property + def bank(self) -> int: + return self.state + 1 + + @property + def finished(self): + return self.state + 1 == len(self.sequence) + + @property + def token_counts(self): + return self.sequence.token_counts() + + @property + def tokens(self): + return self.sequence.tokens + + @property + def num_constraint_tokens(self): + return sum(self.token_counts.values()) + + def next_tokens(self) -> Set[int]: + """Returns the list of tokens that could come next. + These are (a) all tokens extending the root state and, for + non-root states, additionally all tokens extending the current + state.""" + + tokens = set() + if self.state > 0: + tokens.add(self.sequence[0]) + if not self.finished: + tokens.add(self.sequence[self.state + 1]) + return tokens + + def advance(self, token: int): + """Reads in a token and advances the state. Here's how it works. + + We can advance to the next state if: + - there is a matching child + - its path isn't blocked + + A path is blocked when all constraints that are descendants of + that node have already been generated, in the current state. + + If we are not able to advance from the current state, we "fall + off the graph" and return to the root state. There, we again + try to advance, checking the same criteria. + + In any case, when falling off the graph, we need to do some + bookkeeping. We: + - check whether any constraints were met (all prefixes of + current state) + - if one is found, mark it as completed + - adjust visited nodes accordingly + """ + token = int(token) + # print(f"{self} ADVANCE({token}) {self.sequence} -> ", end="") + + if self.finished: + # Accept anything + next_state = self.copy() + + elif self.sequence[self.state + 1] == token: + # Advance to the next token + next_state = OrderedConstraintState(self.sequence, self.state + 1) + + elif self.sequence.endpoints[self.state]: + # Accept anything between constraints (*) + next_state = self.copy() + + elif token == self.sequence[0]: + # Start over having generated the first token + next_state = OrderedConstraintState(self.sequence, 0) + else: + # Start over from the root + next_state = OrderedConstraintState(self.sequence, -1) + + return next_state diff --git a/fairseq/tokenizer.py b/fairseq/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..42131f7b1d334020c3b48a6e44d4139f7c62ad28 --- /dev/null +++ b/fairseq/tokenizer.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import re + + +SPACE_NORMALIZER = re.compile(r"\s+") + + +def tokenize_line(line): + line = SPACE_NORMALIZER.sub(" ", line) + line = line.strip() + return line.split() diff --git a/fairseq/trainer.py b/fairseq/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..16b1b9169738269147a615e1ca52036205e74421 --- /dev/null +++ b/fairseq/trainer.py @@ -0,0 +1,1622 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. +""" + +import contextlib +import logging +import os +import sys +import time +from argparse import Namespace +from itertools import chain +from typing import Any, Dict, List + +import torch +from omegaconf import OmegaConf + +from fairseq import checkpoint_utils, models, optim, utils +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed import utils as distributed_utils +from fairseq.file_io import PathManager +from fairseq.logging import meters, metrics +from fairseq.models.ema import build_ema +from fairseq.nan_detector import NanDetector +from fairseq.optim import lr_scheduler +from fairseq.utils import safe_hasattr + +logger = logging.getLogger(__name__) + + +class Trainer(object): + """Main class for data parallel training. + + This class supports synchronous distributed data parallel training, + where multiple workers each have a full model replica and gradients + are accumulated across workers before each update. We use + :class:`~torch.nn.parallel.DistributedDataParallel` to handle + communication of the gradients across workers. + """ + + def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg + self.task = task + + # catalog shared parameters + shared_params = _catalog_shared_params(model) + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu + if self.cuda: + self.device = torch.device("cuda") + elif self.tpu: + self.device = utils.get_tpu_device() + else: + self.device = torch.device("cpu") + + if self.is_fsdp: + import fairscale + + if self.cfg.common.bf16: + raise ValueError( + "FullyShardedDataParallel is not compatible with --bf16 or " + "--memory-efficient-bf16" + ) + if self.cfg.distributed_training.zero_sharding != "none": + raise ValueError( + "FullyShardedDataParallel is not compatible with --zero-sharding " + "option (it's already built in)" + ) + if ( + max(self.cfg.optimization.update_freq) > 1 + and fairscale.__version__ < "0.4.0" + ): + raise RuntimeError( + "Please update to fairscale 0.4.0 or newer when combining " + "--update-freq with FullyShardedDataParallel" + ) + else: + if ( + hasattr(self.cfg.distributed_training, "cpu_offload") + and self.cfg.distributed_training.cpu_offload + ): + raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") + + # copy model and criterion to current device/dtype + self._criterion = criterion + self._model = model + if not self.is_fsdp: + if cfg.common.fp16: + assert not cfg.common.amp, "Cannot use fp16 and AMP together" + self._criterion = self._criterion.half() + self._model = self._model.half() + elif cfg.common.bf16: + self._criterion = self._criterion.to(dtype=torch.bfloat16) + self._model = self._model.to(dtype=torch.bfloat16) + elif cfg.common.amp: + self._amp_retries = 0 + if ( + not cfg.distributed_training.pipeline_model_parallel + # the DistributedFairseqModel wrapper will handle moving to device, + # so only handle cases which don't use the wrapper + and not self.use_distributed_wrapper + ): + self._criterion = self._criterion.to(device=self.device) + self._model = self._model.to(device=self.device) + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel + self.last_device = None + if self.cuda and self.pipeline_model_parallel: + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) + + # check that shared parameters are preserved after device transfer + for shared_param in shared_params: + ref = _get_module_by_path(self._model, shared_param[0]) + for path in shared_param[1:]: + logger.info( + "detected shared parameter: {} <- {}".format(shared_param[0], path) + ) + _set_module_by_path(self._model, path, ref) + + self._dummy_batch = None # indicates we don't have a dummy batch at first + self._lr_scheduler = None + self._num_updates = 0 + self._num_xla_compiles = 0 # for TPUs + self._optim_history = None + self._optimizer = None + self._warn_once = set() + self._wrapped_criterion = None + self._wrapped_model = None + self._ema = None + + # TODO(myleott): support tpu + if self.cuda and self.data_parallel_world_size > 1: + self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size) + else: + self._grad_norm_buf = None + + self.quantizer = quantizer + if self.quantizer is not None: + self.quantizer.set_trainer(self) + + # get detailed cuda environment + if self.cuda: + self.cuda_env = utils.CudaEnvironment() + if self.data_parallel_world_size > 1: + self.cuda_env_arr = distributed_utils.all_gather_list( + self.cuda_env, group=distributed_utils.get_global_group() + ) + else: + self.cuda_env_arr = [self.cuda_env] + if self.data_parallel_rank == 0: + utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) + else: + self.cuda_env = None + self.cuda_env_arr = None + + metrics.log_start_time("wall", priority=790, round=0) + + self._start_time = time.time() + self._previous_training_time = 0 + self._cumulative_training_time = None + + def reinitialize(self): + """Reinitialize the Trainer, typically after model params change.""" + self._lr_scheduler = None + self._optimizer = None + self._wrapped_criterion = None + self._wrapped_model = None + + @property + def data_parallel_world_size(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 1 + return distributed_utils.get_data_parallel_world_size() + + @property + def data_parallel_process_group(self): + return distributed_utils.get_data_parallel_group() + + @property + def data_parallel_rank(self): + if self.cfg.distributed_training.distributed_world_size == 1: + return 0 + return distributed_utils.get_data_parallel_rank() + + @property + def is_data_parallel_master(self): + # NOTE: this returns true for all model parallel replicas with data + # parallel rank 0 + return self.data_parallel_rank == 0 + + @property + def use_distributed_wrapper(self) -> bool: + return ( + self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf + ) or (self.is_fsdp and self.cfg.distributed_training.cpu_offload) + + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + if ( + self.is_fsdp and self.cfg.distributed_training.use_sharded_state + ) or getattr(self.cfg.model, "base_layers", 0) > 0: + return True + else: + return self.is_data_parallel_master + + @property + def always_call_state_dict_during_save_checkpoint(self) -> bool: + if self.is_fsdp and not self.cfg.distributed_training.use_sharded_state: + # FSDP calls communication collective when consolidating checkpoints + return True + else: + return False + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + if self.is_fsdp and self.cfg.distributed_training.use_sharded_state: + return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format( + self.data_parallel_rank + ) + else: + return self.cfg.checkpoint.checkpoint_suffix or "" + + @property + def criterion(self): + if self._wrapped_criterion is None: + if utils.has_parameters(self._criterion) and self.use_distributed_wrapper: + self._wrapped_criterion = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._criterion, + process_group=self.data_parallel_process_group, + device=self.device, + ) + else: + self._wrapped_criterion = self._criterion + return self._wrapped_criterion + + @property + def model(self): + if self._wrapped_model is None: + if self.use_distributed_wrapper: + self._wrapped_model = models.DistributedFairseqModel( + self.cfg.distributed_training, + self._model, + process_group=self.data_parallel_process_group, + device=self.device, + ) + else: + self._wrapped_model = self._model + return self._wrapped_model + + @property + def ema(self): + if self._ema is None: + self._build_ema() + return self._ema + + def _build_ema(self): + if self.cfg.ema.store_ema: + self._ema = build_ema(self._model, self.cfg.ema, self.device) + logger.info("Exponential Moving Average Shadow Model is initialized.") + + @property + def optimizer(self): + if self._optimizer is None: + self._build_optimizer() + return self._optimizer + + @property + def lr_scheduler(self): + if self._lr_scheduler is None: + self._build_optimizer() # this will initialize self._lr_scheduler + return self._lr_scheduler + + def _build_optimizer(self): + + if ( + self.cfg.optimization.debug_param_names + and self.cfg.common.fp16_no_flatten_grads + ): + params = [] + self.param_names = [] + + for n, p in chain( + self.model.named_parameters(), self.criterion.named_parameters() + ): + if p.requires_grad: + params.append(p) + self.param_names.append(n) + else: + params = list( + filter( + lambda p: p.requires_grad, + chain(self.model.parameters(), self.criterion.parameters()), + ) + ) + + if self.is_fsdp and self.cfg.common.fp16: + # FullyShardedDataParallel always uses MemoryEfficientFP16 wrapper, + # mostly for the grad scaling. But if we don't have the + # --memory-efficient-fp16 flag set, then we're effectively doing + # regular --fp16 and can allow the use of optimizers that would + # otherwise be unsupported by MemoryEfficientFP16Optimizer. + allow_unsupported = not self.cfg.common.memory_efficient_fp16 + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params, allow_unsupported=allow_unsupported + ) + elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp: + if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: + logger.info( + "NOTE: your device does NOT support faster training with --fp16 or --amp, " + "please switch to FP32 which is likely to be faster" + ) + if ( + self.cfg.common.memory_efficient_fp16 + or self.cfg.common.memory_efficient_bf16 + ): + self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( + self.cfg, params + ) + elif self.cfg.common.amp: + self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params) + else: + self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) + else: + if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: + logger.info( + "NOTE: your device may support faster training with --fp16 or --amp" + ) + self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) + + if self.is_fsdp: + assert ( + not self.cfg.optimization.use_bmuf + ), "--ddp-backend=fully_sharded is not compatible with BMUF" + assert self._optimizer.supports_flat_params, ( + "--ddp-backend=fully_sharded is only compatible with pointwise " + "optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " + "However, the sharding will result in slightly different results when " + "using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" + ) + + if self.cfg.optimization.use_bmuf: + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) + + if self.cfg.distributed_training.zero_sharding == "os": + if ( + self.cfg.common.fp16 + and not self.cfg.common.memory_efficient_fp16 + and not self.cfg.common.memory_efficient_bf16 + ) and not self.cfg.common.fp16_no_flatten_grads: + raise ValueError( + "ZeRO is incomptabile with fp16 and flattened grads. " + "Please use --fp16-no-flatten-grads" + ) + else: + optim.shard_(self._optimizer, self.data_parallel_process_group) + + # We should initialize the learning rate scheduler immediately after + # building the optimizer, so that the initial learning rate is set. + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) + self._lr_scheduler.step_update(0) + + @property + def is_fsdp(self): + return self.cfg.distributed_training.ddp_backend == "fully_sharded" + + def consolidate_optimizer(self): + """For OSS, we need to consolidate the state dict.""" + if self.cfg.checkpoint.no_save_optimizer_state: + return + self._gathered_optim_state = None + if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): + self.optimizer.optimizer.consolidate_state_dict() + elif self.is_fsdp and not self.model.use_sharded_state: + st = self.model.gather_full_optim_state_dict( + self.optimizer + ) # only returns on rank 0 + self._gathered_optim_state = st + + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True) + if OmegaConf.is_config(self.cfg) + else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) + else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + }, + } + if self.cfg.ema.store_ema: + # Save EMA model state as extra state + state_dict["extra_state"]["ema"] = self.ema.get_model().state_dict() + if self.cfg.ema.ema_fp32: + # Save EMA params in fp32 + state_dict["extra_state"]["ema_fp32_params"] = self.ema.fp32_params + if not self.cfg.checkpoint.no_save_optimizer_state: + if self._gathered_optim_state is not None: + state_dict["last_optimizer_state"] = self._gathered_optim_state + self._gathered_optim_state = None + else: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + if self.is_fsdp: + # save meta data for recombining checkpoint upon loading + state_dict["fsdp_metadata"] = self.model.local_metadata_dict() + return state_dict + + def save_checkpoint(self, filename, extra_state): + """Save all training state in a checkpoint file.""" + if self.should_save_checkpoint_on_current_rank: + + logger.info(f"Saving checkpoint to {os.path.abspath(filename)}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + + checkpoint_utils.torch_persistent_save( + state_dict, + filename, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, + ) + logger.info(f"Finished saving checkpoint to {os.path.abspath(filename)}") + return os.path.abspath(filename) + return None + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + optimizer_overrides=None, + reset_meters=False, + ): + """ + Load all training state from a checkpoint file. + rank = 0 will load the checkpoint, and then broadcast it to all + other ranks. + """ + extra_state, self._optim_history, last_optim_state = None, [], None + + logger.info(f"Preparing to load checkpoint {filename}") + is_distributed = self.data_parallel_world_size > 1 + bexists = PathManager.isfile(filename) + if bexists: + load_on_all_ranks = ( + self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks + # TPUs don't support broadcast yet, so load checkpoints + # on every worker for now + or self.tpu + # FSDP requires loading checkpoint shards on all ranks + or (self.is_fsdp and self.cfg.distributed_training.use_sharded_state) + or getattr(self.cfg.model, "base_layers", 0) > 0 + ) + + if load_on_all_ranks or self.data_parallel_rank == 0: + state = checkpoint_utils.load_checkpoint_to_cpu( + filename, load_on_all_ranks=load_on_all_ranks + ) + last_optim_state = state.get("last_optimizer_state", None) + + # If doing zero_sharding, do not broadcast global optimizer + # state. Later we will broadcast sharded states to each rank + # to avoid memory from exploding. + if ( + not load_on_all_ranks + and self.cfg.distributed_training.zero_sharding == "os" + and "last_optimizer_state" in state + and is_distributed + ): + state["last_optimizer_state"] = "SHARDED" + else: + last_optim_state = None + state = None + + if is_distributed and not load_on_all_ranks: + state = distributed_utils.broadcast_object( + state, + src_rank=0, + group=self.data_parallel_process_group, + dist_device=self.device, + ) + if self.data_parallel_rank > 0: + last_optim_state = state.get("last_optimizer_state", None) + + # load model parameters + try: + if ( + "optimizer_history" in state + and len(state["optimizer_history"]) > 0 + and "num_updates" in state["optimizer_history"][-1] + ): + self.model.set_num_updates( + state["optimizer_history"][-1]["num_updates"] + ) + + # this is the code related to AdaPrune + # In short, it removes redundant heads in multi-head attention module based on heads importance provided + # For more info, please refer to the paper: https://openreview.net/forum?id=_CMSV7FTzGI + # The idea of prune in mha can be summarized as + # Fine tune model (e.g. roberta encoder) on a certain datasets with regularization + # After the model is trained. User could use get_reserve_head_index and _adaptive_prune_heads functions to get the top X heads with most importance. + # Then user uses the rank to prune a new roberta encoder and save the pruned ckpt manually. + # User will fine tune the the new roberta encoder via the ckpt saved above + # To get rid of registering different pruned version of Roberta, I use the argument --mha-heads-to-keep to prune the Roberta model into a pruned version which matches the pruned ckpt. + if ( + safe_hasattr(self.model, "args") + and safe_hasattr(self.model.args, "mha_heads_to_keep") + and self.model.args.mha_heads_to_keep != -1 + ): + logger.info( + f"Prune model: keep {self.model.args.mha_heads_to_keep} heads for each multihead attention module" + ) + for layer in self.model.encoder.sentence_encoder.layers: + reserve_head_index = layer.self_attn._get_reserve_head_index( + num_heads_to_keep=self.model.args.mha_heads_to_keep + ) + layer.self_attn._adaptive_prune_heads( + reserve_head_index=reserve_head_index + ) + layer.self_attn._set_skip_embed_dim_check() + logger.info(self.model) + # this is the code related to AdaPrune + # In short, it removes redundant units in feedforward layer in each transformer layer based on importance + # For more info, please refer to the paper: https://openreview.net/forum?id=_CMSV7FTzGI + # The idea of prune in ffn can be summarized as + # Fine tune model (e.g. roberta encoder) on a certain datasets with regularization + # After the model is trained. User could use _get_fc_rank and _prune_fc_layer functions to get the top X units with most importance. + # Then user uses the rank to prune a new roberta encoder and save the pruned ckpt manually. + # User will fine tune the the new roberta encoder via the ckpt saved above + # To get rid of registering different pruned version of Roberta, I use the argument --ffn-blocks-to-remove to prune the Roberta model into a pruned version which matches the pruned ckpt. + if ( + safe_hasattr(self.model, "args") + and safe_hasattr(self.model.args, "ffn_blocks_to_remove") + and self.model.args.ffn_blocks_to_remove != -1 + ): + logger.info( + f"Prune model: remove {self.model.args.ffn_blocks_to_remove} ffn blocks for each transformer layer" + ) + for layer in self.model.encoder.sentence_encoder.layers: + remove_index = layer._get_fc_rank( + remove_num=self.model.args.ffn_blocks_to_remove + ) + layer._prune_fc_layer(remove_index=remove_index) + logger.info(self.model) + + self.model.load_state_dict( + state["model"], strict=True, model_cfg=self.cfg.model + ) + # save memory for later steps + del state["model"] + if utils.has_parameters(self.get_criterion()): + self.get_criterion().load_state_dict( + state["criterion"], strict=True + ) + del state["criterion"] + + except Exception: + raise Exception( + "Cannot load model parameters from checkpoint {}; " + "please ensure that the architectures match.".format(filename) + ) + extra_state = state["extra_state"] + self._optim_history = state["optimizer_history"] + + if last_optim_state is not None and not reset_optimizer: + # rebuild optimizer after loading model, since params may have changed + self._build_optimizer() + + # only reload optimizer and lr_scheduler if they match + last_optim = self._optim_history[-1] + assert ( + last_optim["criterion_name"] == self.get_criterion().__class__.__name__ + ), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" + assert ( + last_optim["optimizer_name"] == self.optimizer.__class__.__name__ + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" + + if not reset_lr_scheduler: + self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + + if self.is_fsdp and not self.model.use_sharded_state: + # if use_sharded_state, the last_optim_state is already sharded, skip this + last_optim_state = self.model.get_shard_from_optim_state_dict( + last_optim_state + ) + elif not load_on_all_ranks and is_distributed: + last_optim_state = self.optimizer.broadcast_global_state_dict( + last_optim_state + ) + + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + + self.set_num_updates(last_optim["num_updates"]) + + if extra_state is not None: + itr_state = extra_state["train_iterator"] + epoch = itr_state["epoch"] + + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() + + self.lr_step(epoch) + + if ( + itr_state.get("version", 1) >= 2 + and itr_state["iterations_in_epoch"] == 0 + ): + # reset meters at start of epoch + reset_meters = True + + if "metrics" in extra_state and not reset_meters: + metrics.load_state_dict(extra_state["metrics"]) + + # reset TimeMeters, since their start times don't make sense anymore + for meter in metrics.get_meters("default"): + if isinstance(meter, meters.TimeMeter): + meter.reset() + + if self.cfg.ema.store_ema: + if "ema" not in extra_state: + logger.warn( + "EMA not found in checkpoint. But store_ema is True. " + "EMA is re-initialized from checkpoint." + ) + self.ema.restore( + state["model"], build_fp32_params=self.cfg.ema.ema_fp32 + ) + else: + logger.info("Loading EMA from checkpoint") + self.ema.restore(extra_state["ema"], build_fp32_params=False) + + if self.cfg.ema.ema_fp32: + if "ema_fp32_params" in extra_state: + logger.info("Loading EMA fp32 params from checkpoint") + self.ema.build_fp32_params(extra_state["ema_fp32_params"]) + else: + logger.info( + "Building EMA fp32 params from EMA model in checkpoint" + ) + self.ema.build_fp32_params() + + logger.info( + "Loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + + else: + logger.info("No existing checkpoint found {}".format(filename)) + + return extra_state + + def get_train_iterator( + self, + epoch, + combine=True, + load_dataset=True, + data_selector=None, + shard_batch_itr=True, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over the training set for a given epoch.""" + if load_dataset: + logger.info("loading train data for epoch {}".format(epoch)) + self.task.load_dataset( + self.cfg.dataset.train_subset, + epoch=epoch, + combine=combine, + data_selector=data_selector, + tpu=self.tpu, + ) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + self.cfg.dataset.max_tokens, + ), + ignore_invalid_inputs=True, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=(self.cfg.common.seed + epoch) + if self.cfg.dataset.update_ordered_indices_seed + else self.cfg.common.seed, + num_shards=self.data_parallel_world_size if shard_batch_itr else 1, + shard_id=self.data_parallel_rank if shard_batch_itr else 0, + num_workers=self.cfg.dataset.num_workers, + epoch=epoch, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=self.cfg.optimization.skip_remainder_batch, + grouped_shuffling=self.cfg.dataset.grouped_shuffling, + update_epoch_batch_itr=self.cfg.dataset.update_epoch_batch_itr, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def get_valid_iterator( + self, + subset, + disable_iterator_cache=False, + ): + """Return an EpochBatchIterator over given validation subset for a given epoch.""" + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(subset), + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, + max_positions=utils.resolve_max_positions( + self.task.max_positions(), + self.model.max_positions(), + ), + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.cfg.dataset.num_workers, + # always pass a fixed "epoch" to keep validation data consistent + # across training epochs + epoch=1, + data_buffer_size=self.cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=False, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def begin_epoch(self, epoch): + """Called at the beginning of each epoch.""" + logger.info("begin training epoch {}".format(epoch)) + + self.lr_step_begin_epoch(epoch) + + if self.quantizer is not None: + self.quantizer.begin_epoch(epoch) + + # task specific setup per epoch + self.task.begin_epoch(epoch, self.get_model()) + + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("begin_epoch") # wait for all workers + xm.mark_step() + + def begin_valid_epoch(self, epoch): + """Called at the beginning of each validation epoch.""" + + # task specific setup per validation epoch + self.task.begin_valid_epoch(epoch, self.get_model()) + + def reset_dummy_batch(self, batch): + self._dummy_batch = batch + + @metrics.aggregate("train") + def train_step(self, samples, raise_oom=False): + """Do forward, backward and parameter update.""" + self._set_seed() + self.model.train() + self.criterion.train() + self.zero_grad() + + metrics.log_start_time("train_wall", priority=800, round=0) + + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + + has_oom = False + + # forward and backward pass + logging_outputs, sample_size, ooms = [], 0, 0 + for i, sample in enumerate(samples): # delayed update loop + sample, is_dummy_batch = self._prepare_sample(sample) + + def maybe_no_sync(): + """ + Whenever *samples* contains more than one mini-batch, we + want to accumulate gradients locally and only call + all-reduce in the last backwards pass. + """ + if ( + self.data_parallel_world_size > 1 + and hasattr(self.model, "no_sync") + and i < len(samples) - 1 + # The no_sync context manager results in increased memory + # usage with FSDP, since full-size gradients will be + # accumulated on each GPU. It's typically a better tradeoff + # to do the extra communication with FSDP. + and not self.is_fsdp + ): + return self.model.no_sync() + else: + return contextlib.ExitStack() # dummy contextmanager + + try: + with maybe_no_sync(): + # forward and backward + loss, sample_size_i, logging_output = self.task.train_step( + sample=sample, + model=self.model, + criterion=self.criterion, + optimizer=self.optimizer, + update_num=self.get_num_updates(), + ignore_grad=is_dummy_batch, + **extra_kwargs, + ) + del loss + + logging_outputs.append(logging_output) + sample_size += sample_size_i + + # emptying the CUDA cache after the first step can + # reduce the chance of OOM + if self.cuda and self.get_num_updates() == 0: + torch.cuda.empty_cache() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + has_oom = True + if raise_oom: + raise e + else: + raise e + except Exception: + self.consolidate_optimizer() + self.save_checkpoint( + os.path.join(self.cfg.checkpoint.save_dir, "crash.pt"), {} + ) + raise + + if has_oom: + logger.warning( + "attempting to recover from OOM in forward/backward pass" + ) + ooms += 1 + self.zero_grad() + if self.cuda: + torch.cuda.empty_cache() + + if self.cfg.distributed_training.distributed_world_size == 1: + return None + + if self.tpu and i < len(samples) - 1: + # tpu-comment: every XLA operation before marking step is + # appended to the IR graph, and processing too many batches + # before marking step can lead to OOM errors. + # To handle gradient accumulation use case, we explicitly + # mark step here for every forward pass without a backward pass + self._xla_markstep_and_send_to_cpu() + + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + if torch.is_tensor(sample_size): + sample_size = sample_size.float() + else: + sample_size = float(sample_size) + + # gather logging outputs from all replicas + if self._sync_stats(): + train_time = self._local_cumulative_training_time() + ( + logging_outputs, + ( + sample_size, + ooms, + total_train_time, + ), + ) = self._aggregate_logging_outputs( + logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch + ) + self._cumulative_training_time = ( + total_train_time / self.data_parallel_world_size + ) + + overflow = False + try: + with torch.autograd.profiler.record_function("reduce-grads"): + # reduce gradients across workers + self.optimizer.all_reduce_grads(self.model) + if utils.has_parameters(self.criterion): + self.optimizer.all_reduce_grads(self.criterion) + + with torch.autograd.profiler.record_function("multiply-grads"): + # multiply gradients by (data_parallel_size / sample_size) since + # DDP normalizes by the number of data parallel workers for + # improved fp16 precision. + # Thus we get (sum_of_gradients / sample_size) at the end. + # In case of fp16, this step also undoes loss scaling. + # (Debugging note: Some optimizers perform this scaling on the + # fly, so inspecting model.parameters() or optimizer.params may + # still show the original, unscaled gradients.) + numer = ( + self.data_parallel_world_size + if not self.cfg.optimization.use_bmuf or self._sync_stats() + else 1 + ) + self.optimizer.multiply_grads(numer / (sample_size or 1.0)) + # Note: (sample_size or 1.0) handles the case of a zero gradient, in a + # way that avoids CPU/device transfers in case sample_size is a GPU or + # TPU object. The assumption is that the gradient itself is also 0. + + with torch.autograd.profiler.record_function("clip-grads"): + # clip grads + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) + + # check that grad norms are consistent across workers + # on tpu check tensor is slow + if not self.tpu: + if ( + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.ddp_backend != "slowmo" + ): + self._check_grad_norms(grad_norm) + if not torch.isfinite(grad_norm).all(): + # in case of AMP, if gradients are Nan/Inf then + # optimizer step is still required + if self.cfg.common.amp: + overflow = True + else: + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") + + with torch.autograd.profiler.record_function("optimizer"): + # take an optimization step + self.task.optimizer_step( + self.optimizer, model=self.model, update_num=self.get_num_updates() + ) + if self.cfg.common.amp and overflow: + if self._amp_retries == self.cfg.common.amp_batch_retries: + logger.info("AMP: skipping this batch.") + self._amp_retries = 0 + else: + self._amp_retries += 1 + return self.train_step( + samples, raise_oom + ) # recursion to feed in same batch + + except FloatingPointError: + + self.consolidate_optimizer() + self.save_checkpoint( + os.path.join(self.cfg.checkpoint.save_dir, "crash.pt"), {} + ) + + # re-run the forward and backward pass with hooks attached to print + # out where it fails + self.zero_grad() + with NanDetector(self.get_model()): + for _, sample in enumerate(samples): + sample, _ = self._prepare_sample(sample) + self.task.train_step( + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + **extra_kwargs, + ) + raise + except OverflowError as e: + overflow = True + logger.info( + f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" + ) + + if hasattr(self, "param_names") and hasattr( + self.optimizer, "fp32_optimizer" + ): + for p, n in zip(self.optimizer.fp32_optimizer.params, self.param_names): + if torch.isinf(p.grad).any() or torch.isnan(p.grad).any(): + logger.info(f"overflow in param {n}") + + grad_norm = torch.tensor(0.0).cuda() + self.zero_grad() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + logger.error("OOM during optimization, irrecoverable") + raise e + + # Some distributed wrappers (e.g., SlowMo) need access to the optimizer + # after the step + if hasattr(self.model, "perform_slowmo"): + self.model.perform_slowmo( + self.optimizer.optimizer, getattr(self.optimizer, "fp32_params", None) + ) + + logging_output = None + if not overflow or self.cfg.distributed_training.ddp_backend == "slowmo": + self.set_num_updates(self.get_num_updates() + 1) + + if self.cfg.ema.store_ema: + # Step EMA forward with new model. + self.ema.step( + self.get_model(), + self.get_num_updates(), + ) + metrics.log_scalar( + "ema_decay", + self.ema.get_decay(), + priority=10000, + round=5, + weight=0, + ) + + if self.tpu: + import torch_xla.core.xla_model as xm + + # mark step on TPUs + self._xla_markstep_and_send_to_cpu() + + # only log stats every log_interval steps + # this causes wps to be misreported when log_interval > 1 + logging_output = {} + if self.get_num_updates() % self.cfg.common.log_interval == 0: + # log memory usage + mem_info = xm.get_memory_info(self.device) + gb_free = mem_info["kb_free"] / 1024 / 1024 + gb_total = mem_info["kb_total"] / 1024 / 1024 + metrics.log_scalar( + "gb_free", gb_free, priority=1500, round=1, weight=0 + ) + metrics.log_scalar( + "gb_total", gb_total, priority=1600, round=1, weight=0 + ) + logging_outputs = self._xla_markstep_and_send_to_cpu( + logging_outputs + ) + logging_output = self._reduce_and_log_stats( + logging_outputs, sample_size, grad_norm + ) + + # log whenever there's an XLA compilation, since these + # slow down training and may indicate opportunities for + # optimization + self._check_xla_compilation() + else: + if self.cuda and self.cuda_env is not None: + # log minimum free memory over the iteration + gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + torch.cuda.reset_peak_memory_stats() + gb_free = self.cuda_env.total_memory_in_GB - gb_used + metrics.log_scalar( + "gb_free", gb_free, priority=1500, round=1, weight=0 + ) + + # log stats + logging_output = self._reduce_and_log_stats( + logging_outputs, sample_size, grad_norm + ) + + # clear CUDA cache to reduce memory fragmentation + if ( + self.cuda + and self.cfg.common.empty_cache_freq > 0 + and ( + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq + ) + == 0 + ): + torch.cuda.empty_cache() + + if self.cfg.common.fp16 or self.cfg.common.amp: + metrics.log_scalar( + "loss_scale", + ( + self.optimizer.scaler.loss_scale + if self.cfg.common.fp16 + else self.optimizer.scaler.get_scale() + ), + priority=700, + round=4, + weight=0, + ) + + metrics.log_stop_time("train_wall") + return logging_output + + @metrics.aggregate("valid") + def valid_step(self, sample, raise_oom=False): + """Do forward pass in evaluation mode.""" + if self.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("valid_step") # wait for all workers + + # If EMA is enabled through store_ema=True + # and task.uses_ema is True, pass the EMA model as a keyword + # argument to the task. + extra_kwargs = {} + if self.cfg.ema.store_ema and getattr(self.task, "uses_ema", False): + extra_kwargs["ema_model"] = self.ema.get_model() + + with torch.no_grad(): + self.model.eval() + self.criterion.eval() + + sample, is_dummy_batch = self._prepare_sample(sample) + + try: + _loss, sample_size, logging_output = self.task.valid_step( + sample, self.model, self.criterion, **extra_kwargs + ) + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if not raise_oom: + logger.warning( + "ran out of memory in validation step, retrying batch" + ) + for p in self.model.parameters(): + if p.grad is not None: + p.grad = None # free some memory + if self.cuda: + torch.cuda.empty_cache() + return self.valid_step(sample, raise_oom=True) + raise e + + logging_outputs = [logging_output] + if is_dummy_batch: + if torch.is_tensor(sample_size): + sample_size.zero_() + else: + sample_size *= 0.0 + + # gather logging outputs from all replicas + if self.data_parallel_world_size > 1: + logging_outputs, (sample_size,) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ignore=is_dummy_batch, + ) + + # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) + logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) + + return logging_output + + def zero_grad(self): + self.optimizer.zero_grad() + + def lr_step_begin_epoch(self, epoch): + """Adjust the learning rate at the beginning of the epoch.""" + self.lr_scheduler.step_begin_epoch(epoch) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + + def lr_step(self, epoch, val_loss=None): + """Adjust the learning rate at the end of the epoch.""" + self.lr_scheduler.step(epoch, val_loss) + # prefer updating the LR based on the number of steps + return self.lr_step_update() + + def lr_step_update(self): + """Update the learning rate after each update.""" + new_lr = self.lr_scheduler.step_update(self.get_num_updates()) + if isinstance(new_lr, dict): + for k, v in new_lr.items(): + metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) + new_lr = new_lr.get("default", next(iter(new_lr.values()))) + else: + metrics.log_scalar("lr", new_lr, weight=0, priority=300) + return new_lr + + def get_lr(self): + """Get the current learning rate.""" + return self.optimizer.get_lr() + + def get_model(self): + """Get the (non-wrapped) model instance.""" + return self._model + + def get_criterion(self): + """Get the (non-wrapped) criterion instance.""" + return self._criterion + + def get_meter(self, name): + """[deprecated] Get a specific meter by name.""" + from fairseq import meters + + if "get_meter" not in self._warn_once: + self._warn_once.add("get_meter") + utils.deprecation_warning( + "Trainer.get_meter is deprecated. Please use fairseq.metrics instead." + ) + + train_meters = metrics.get_meters("train") + if train_meters is None: + train_meters = {} + + if name == "train_loss" and "loss" in train_meters: + return train_meters["loss"] + elif name == "train_nll_loss": + # support for legacy train.py, which assumed this meter is + # always initialized + m = train_meters.get("nll_loss", None) + return m or meters.AverageMeter() + elif name == "wall": + # support for legacy train.py, which assumed this meter is + # always initialized + m = metrics.get_meter("default", "wall") + return m or meters.TimeMeter() + elif name == "wps": + m = metrics.get_meter("train", "wps") + return m or meters.TimeMeter() + elif name in {"valid_loss", "valid_nll_loss"}: + # support for legacy train.py, which assumed these meters + # are always initialized + k = name[len("valid_") :] + m = metrics.get_meter("valid", k) + return m or meters.AverageMeter() + elif name == "oom": + return meters.AverageMeter() + elif name in train_meters: + return train_meters[name] + return None + + def get_num_updates(self): + """Get the number of parameters updates.""" + return self._num_updates + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + self._num_updates = num_updates + self.lr_step_update() + if self.quantizer: + self.quantizer.step_update(self._num_updates) + metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) + + def clip_grad_norm(self, clip_norm): + def agg_norm_fn(total_norm): + total_norm = total_norm.cuda().float() ** 2 + total_norm = distributed_utils.all_reduce( + total_norm, group=self.data_parallel_process_group + ) + return total_norm**0.5 + + should_agg_norm = self.is_fsdp and ( + self.data_parallel_process_group is not None + or torch.distributed.is_initialized() + ) + return self.optimizer.clip_grad_norm( + clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None + ) + + def cumulative_training_time(self): + if self._cumulative_training_time is None: + # single GPU + return self._local_cumulative_training_time() + else: + return self._cumulative_training_time + + def _local_cumulative_training_time(self): + """Aggregate training time in seconds.""" + return time.time() - self._start_time + self._previous_training_time + + def _fp_convert_sample(self, sample): + def apply_half(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.half) + return t + + def apply_bfloat16(t): + if t.dtype is torch.float32: + return t.to(dtype=torch.bfloat16) + return t + + if self.cfg.common.fp16: + sample = utils.apply_to_sample(apply_half, sample) + + if self.cfg.common.bf16: + sample = utils.apply_to_sample(apply_bfloat16, sample) + + return sample + + def _prepare_sample(self, sample, is_dummy=False): + if sample == "DUMMY": + raise Exception( + "Trying to use an uninitialized 'dummy' batch. This usually indicates " + "that the total number of batches is smaller than the number of " + "participating GPUs. Try reducing the batch size or using fewer GPUs." + ) + + if sample is None or len(sample) == 0: + assert ( + self._dummy_batch is not None and len(self._dummy_batch) > 0 + ), "Invalid dummy batch: {}".format(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) + return sample, True + + # Given that PCIe/NVLink bandwidth is significantly smaller than DRAM bandwidth + # it makes sense to do the format conversion on the CPU and then transfer + # a smaller buffer to the device. This also saves GPU memory capacity. + + if self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) + + if self.cuda: + if self.pipeline_model_parallel: + if "target" in sample: + sample["target"] = utils.move_to_cuda( + sample["target"], device=self.last_device + ) + else: + sample = utils.move_to_cuda(sample) + elif self.tpu and is_dummy: + # the dummy batch may not be on the appropriate device + sample = utils.move_to_cuda(sample, device=self.device) + + if not self.cfg.common.on_cpu_convert_precision: + sample = self._fp_convert_sample(sample) + + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + + return sample, False + + def _set_seed(self): + # Set seed based on args.seed and the update number so that we get + # reproducible results when resuming from checkpoints + seed = self.cfg.common.seed + self.get_num_updates() + utils.set_torch_seed(seed) + + def _sync_stats(self): + # Return True if it's using multiple GPUs and DDP or multiple GPUs with + # BMUF and it's a bmuf sync with warmup iterations completed before. + if self.data_parallel_world_size == 1: + return False + elif self.cfg.optimization.use_bmuf: + return ( + self.get_num_updates() + 1 + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations + else: + return True + + def _log_oom(self, exc): + msg = "OOM: Ran out of memory with exception: {}".format(exc) + logger.warning(msg) + if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): + for device_idx in range(torch.cuda.device_count()): + logger.warning(torch.cuda.memory_summary(device=device_idx)) + sys.stderr.flush() + + def _aggregate_logging_outputs( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): + return self._fast_stat_sync_sum( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + else: + return self._all_gather_list_sync( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + + def _all_gather_list_sync( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. all_gather_list_sync is + suitable when logging outputs are complex types. + """ + if self.tpu: + raise NotImplementedError + if ignore: + logging_outputs = [] + results = list( + zip( + *distributed_utils.all_gather_list( + [logging_outputs] + list(extra_stats_to_sum), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), + group=self.data_parallel_process_group, + ) + ) + ) + logging_outputs, extra_stats_to_sum = results[0], results[1:] + logging_outputs = list(chain.from_iterable(logging_outputs)) + extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] + return logging_outputs, extra_stats_to_sum + + def _fast_stat_sync_sum( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + Sync logging outputs across workers. fast_stat_sync_sum is + faster than all_gather_list_sync, but is only suitable when + logging outputs are scalars and can be summed. Note that + *logging_outputs* cannot contain any nested dicts/lists. + """ + data = {} + for i, stat in enumerate(extra_stats_to_sum): + data["extra_stats_" + str(i)] = stat + if len(logging_outputs) > 0: + log_keys = list(logging_outputs[0].keys()) + for k in log_keys: + if not ignore: + v = sum(log[k] for log in logging_outputs if k in log) + else: + v = logging_outputs[0][k] + v = torch.zeros_like(v) if torch.is_tensor(v) else 0 + data["logging_outputs_" + k] = v + else: + log_keys = None + + data = distributed_utils.all_reduce_dict( + data, device=self.device, group=self.data_parallel_process_group + ) + + extra_stats_to_sum = [ + data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) + ] + if log_keys is not None: + logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] + else: + logging_outputs = [] + return logging_outputs, extra_stats_to_sum + + def _check_grad_norms(self, grad_norm): + """Check that grad norms are consistent across workers.""" + if self._grad_norm_buf is not None: + self._grad_norm_buf.zero_() + self._grad_norm_buf[self.data_parallel_rank] = grad_norm + distributed_utils.all_reduce( + self._grad_norm_buf, group=self.data_parallel_process_group + ) + + def is_consistent(tensor): + max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) + return ( + ( + torch.isfinite(tensor).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) + or (self.cfg.common.amp and not torch.isfinite(tensor).all()) + # in case of amp non-finite grads are fine + ) + + if not is_consistent(self._grad_norm_buf): + pretty_detail = "\n".join( + "rank {:3d} = {:.8f}".format(r, n) + for r, n in enumerate(self._grad_norm_buf.tolist()) + ) + error_detail = "grad_norm across the workers:\n{}\n".format( + pretty_detail + ) + # use FloatingPointError to trigger NanDetector + raise FloatingPointError( + "Fatal error: gradients are inconsistent between workers. " + "Try --ddp-backend=legacy_ddp. " + "Or are you mixing up different generation of GPUs in training?" + + "\n" + + "-" * 80 + + "\n{}\n".format(error_detail) + + "-" * 80 + ) + + def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): + if grad_norm is not None and ( + not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm) + ): + metrics.log_speed("ups", 1.0, priority=100, round=2) + metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) + if self.cfg.optimization.clip_norm > 0: + metrics.log_scalar( + "clip", + torch.where( + grad_norm > self.cfg.optimization.clip_norm, + grad_norm.new_tensor(100), + grad_norm.new_tensor(0), + ), + priority=500, + round=1, + ) + + with metrics.aggregate() as agg: + if logging_outputs is not None: + self.task.reduce_metrics(logging_outputs, self.get_criterion()) + del logging_outputs + + # extra warning for criterions that don't properly log a loss value + if "loss" not in agg: + if "loss" not in self._warn_once: + self._warn_once.add("loss") + logger.warning( + "Criterion.reduce_metrics did not log a 'loss' value, " + "which may break some functionality" + ) + metrics.log_scalar("loss", -1) + + # support legacy interface + if self.tpu: + logging_output = {} + else: + logging_output = agg.get_smoothed_values() + logging_output["sample_size"] = sample_size + for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: + if key_to_delete in logging_output: + del logging_output[key_to_delete] + return logging_output + + def _check_xla_compilation(self): + import torch_xla.debug.metrics as met + + compile_stats = met.metric_data("CompileTime") + if compile_stats is None: + return + num_xla_compiles = compile_stats[0] + if num_xla_compiles > self._num_xla_compiles: + logger.warning( + "XLA compilation detected on device #{}; too many of these can lead " + "to slow training, but we expect a few in the beginning".format( + self.cfg.distributed_training.distributed_rank + ) + ) + self._num_xla_compiles = num_xla_compiles + + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + + return xla_device_to_cpu(data) + + +def _catalog_shared_params(module, memo=None, prefix=""): + if memo is None: + first_call = True + memo = {} + else: + first_call = False + for name, param in module._parameters.items(): + param_prefix = prefix + ("." if prefix else "") + name + if param not in memo: + memo[param] = [] + memo[param].append(param_prefix) + for name, m in module._modules.items(): + if m is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + _catalog_shared_params(m, memo, submodule_prefix) + if first_call: + return [x for x in memo.values() if len(x) > 1] + + +def _get_module_by_path(module, path): + path = path.split(".") + for name in path: + module = getattr(module, name) + return module + + +def _set_module_by_path(module, path, value): + path = path.split(".") + for name in path[:-1]: + module = getattr(module, name) + setattr(module, path[-1], value) diff --git a/fairseq/utils.py b/fairseq/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4b35052305e7b60aa958d6d9b88a7ce0201045 --- /dev/null +++ b/fairseq/utils.py @@ -0,0 +1,951 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import collections +import contextlib +import copy +import importlib +import logging +import os +import sys +import warnings +from itertools import accumulate +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor + +if TYPE_CHECKING: + from fairseq.modules.multihead_attention import MultiheadAttention + +try: + from amp_C import multi_tensor_l2norm + + multi_tensor_l2norm_available = True +except ImportError: + multi_tensor_l2norm_available = False + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +logger = logging.getLogger(__name__) + + +MANIFOLD_PATH_SEP = "|" + + +class FileContentsAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(FileContentsAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + from fairseq.file_io import PathManager + + if PathManager.isfile(values): + with PathManager.open(values) as f: + argument = f.read().strip() + else: + argument = values + setattr(namespace, self.dest, argument) + + +def split_paths(paths: str, separator=os.pathsep) -> List[str]: + return ( + paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP) + ) + + +def load_ensemble_for_inference(filenames, task, model_arg_overrides=None): + from fairseq import checkpoint_utils + + deprecation_warning( + "utils.load_ensemble_for_inference is deprecated. " + "Please use checkpoint_utils.load_model_ensemble instead." + ) + return checkpoint_utils.load_model_ensemble( + filenames, arg_overrides=model_arg_overrides, task=task + ) + + +def apply_to_sample(f, sample): + if hasattr(sample, "__len__") and len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, collections.OrderedDict): + # OrderedDict has attributes that needs to be preserved + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items() + ) + od.__dict__ = x.__dict__ + return od + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + elif isinstance(x, tuple): + return tuple(_apply(x) for x in x) + elif isinstance(x, set): + return {_apply(x) for x in x} + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample, device=None): + device = device or torch.cuda.current_device() + + def _move_to_cuda(tensor): + # non_blocking is ignored if tensor is not pinned, so we can always set + # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) + return tensor.to(device=device, non_blocking=True) + + return apply_to_sample(_move_to_cuda, sample) + + +def move_to_cpu(sample): + def _move_to_cpu(tensor): + # PyTorch has poor support for half tensors (float16) on CPU. + # Move any such tensors to float32. + if tensor.dtype in {torch.bfloat16, torch.float16}: + tensor = tensor.to(dtype=torch.float32) + return tensor.cpu() + + return apply_to_sample(_move_to_cpu, sample) + + +def move_to_tpu(sample): + + import torch_xla.core.xla_model as xm + + device = xm.xla_device() + + def _move_to_tpu(tensor): + return tensor.to(device) + + return apply_to_sample(_move_to_tpu, sample) + + +def get_incremental_state( + module: "MultiheadAttention", + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, +) -> Optional[Dict[str, Optional[Tensor]]]: + """Helper for getting incremental state for an nn.Module.""" + return module.get_incremental_state(incremental_state, key) + + +def set_incremental_state( + module: "MultiheadAttention", + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + key: str, + value: Dict[str, Optional[Tensor]], +) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]: + """Helper for setting incremental state for an nn.Module.""" + if incremental_state is not None: + result = module.set_incremental_state(incremental_state, key, value) + if result is not None: + incremental_state = result + return incremental_state + + +def load_align_dict(replace_unk): + if replace_unk is None: + align_dict = None + elif isinstance(replace_unk, str) and len(replace_unk) > 0: + # Load alignment dictionary for unknown word replacement if it was passed as an argument. + align_dict = {} + with open(replace_unk, "r") as f: + for line in f: + cols = line.split() + align_dict[cols[0]] = cols[1] + else: + # No alignment dictionary provided but we still want to perform unknown word replacement by copying the + # original source word. + align_dict = {} + return align_dict + + +def print_embed_overlap(embed_dict, vocab_dict): + embed_keys = set(embed_dict.keys()) + vocab_keys = set(vocab_dict.symbols) + overlap = len(embed_keys & vocab_keys) + logger.info("found {}/{} types in embedding file".format(overlap, len(vocab_dict))) + + +def parse_embedding(embed_path): + """Parse embedding text file into a dictionary of word and embedding tensors. + + The first line can have vocabulary size and dimension. The following lines + should contain word and embedding separated by spaces. + + Example: + 2 5 + the -0.0230 -0.0264 0.0287 0.0171 0.1403 + at -0.0395 -0.1286 0.0275 0.0254 -0.0932 + """ + embed_dict = {} + with open(embed_path) as f_embed: + next(f_embed) # skip header + for line in f_embed: + pieces = line.rstrip().split(" ") + embed_dict[pieces[0]] = torch.Tensor( + [float(weight) for weight in pieces[1:]] + ) + return embed_dict + + +def load_embedding(embed_dict, vocab, embedding): + for idx in range(len(vocab)): + token = vocab[idx] + if token in embed_dict: + embedding.weight.data[idx] = embed_dict[token] + return embedding + + +def replace_unk(hypo_str, src_str, alignment, align_dict, unk): + from fairseq import tokenizer + + # Tokens are strings here + hypo_tokens = tokenizer.tokenize_line(hypo_str) + # TODO: Very rare cases where the replacement is '' should be handled gracefully + src_tokens = tokenizer.tokenize_line(src_str) + [""] + for i, ht in enumerate(hypo_tokens): + if ht == unk: + src_token = src_tokens[alignment[i]] + # Either take the corresponding value in the aligned dictionary or just copy the original value. + hypo_tokens[i] = align_dict.get(src_token, src_token) + return " ".join(hypo_tokens) + + +def post_process_prediction( + hypo_tokens, + src_str, + alignment, + align_dict, + tgt_dict, + remove_bpe=None, + extra_symbols_to_ignore=None, +): + hypo_str = tgt_dict.string( + hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore + ) + if align_dict is not None: + hypo_str = replace_unk( + hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string() + ) + if align_dict is not None or remove_bpe is not None: + # Convert back to tokens for evaluating with unk replacement or without BPE + # Note that the dictionary can be modified inside the method. + hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True) + return hypo_tokens, hypo_str, alignment + + +def make_positions(tensor, padding_idx: int, onnx_trace: bool = False): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx + + +def strip_pad(tensor, pad): + return tensor[tensor.ne(pad)] + + +def buffered_arange(max, device="cpu"): + if not hasattr(buffered_arange, "buf"): + buffered_arange.buf = torch.LongTensor().to(device) + if max > buffered_arange.buf.numel(): + buffered_arange.buf.resize_(max) + torch.arange(max, out=buffered_arange.buf) + return buffered_arange.buf[:max] + + +def convert_padding_direction( + src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False +): + assert right_to_left ^ left_to_right + pad_mask = src_tokens.eq(padding_idx) + if not pad_mask.any(): + # no padding, return early + return src_tokens + if left_to_right and not pad_mask[:, 0].any(): + # already right padded + return src_tokens + if right_to_left and not pad_mask[:, -1].any(): + # already left padded + return src_tokens + max_len = src_tokens.size(1) + buffered = torch.empty(0).long() + if max_len > 0: + torch.arange(max_len, out=buffered) + range = buffered.type_as(src_tokens).expand_as(src_tokens) + num_pads = pad_mask.long().sum(dim=1, keepdim=True) + if right_to_left: + index = torch.remainder(range - num_pads, max_len) + else: + index = torch.remainder(range + num_pads, max_len) + return src_tokens.gather(1, index) + + +def item(tensor): + # tpu-comment: making this a no-op for xla devices. + if torch.is_tensor(tensor) and tensor.device.type == "xla": + return tensor.detach() + if hasattr(tensor, "item"): + return tensor.item() + if hasattr(tensor, "__getitem__"): + return tensor[0] + return tensor + + +def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: + per_device_grads = {} + norms = [] + for grad in grads: + device = grad.device + cur_device_grads = per_device_grads.get(device) + if cur_device_grads is None: + cur_device_grads = [] + per_device_grads[device] = cur_device_grads + cur_device_grads.append(grad) + for device in per_device_grads.keys(): + cur_device_grads = per_device_grads[device] + if device.type == "cuda": + # TODO(msb) return has_inf + has_inf = torch.zeros((1, 1), dtype=torch.int, device=device) + with torch.cuda.device(device): + norm = multi_tensor_l2norm( + chunk_size, has_inf, [cur_device_grads], False + ) + norms.append(norm[0].to(torch.cuda.current_device())) + else: + norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_device_grads] + total_norm = torch.norm(torch.stack(norms)) + return total_norm + + +@torch.no_grad() +def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: + def grad_exists(p): + return p is not None and getattr(p, "grad", None) is not None + + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + grads = [ + p.grad.detach() for p in params if grad_exists(p) and not hasattr(p, "expert") + ] + expert_grads = [ + p.grad.detach() for p in params if grad_exists(p) and hasattr(p, "expert") + ] + + if len(grads) == 0: + if len(params) > 0: + return params[0].new_tensor(0.0) + else: + return torch.tensor(0.0) + + if len(grads) == 1: + total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) + else: + if multi_tensor_l2norm_available: + total_norm = multi_tensor_total_norm(grads) + else: + if torch.cuda.is_available(): + warnings.warn( + "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; " + "you may get better performance by installing NVIDIA's apex library" + ) + device = torch.cuda.current_device() + elif grads[0].device.type == "xla": + device = grads[0].device + else: + device = torch.device("cpu") + total_norm = torch.norm( + torch.stack( + [torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads] + ) + ) + + if aggregate_norm_fn is not None: + total_norm = aggregate_norm_fn(total_norm) + + if max_norm > 0: + max_norm = float(max_norm) + clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) + torch._foreach_mul_(grads + expert_grads, clip_coef) + + return total_norm + + +def fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + + +def _match_types(arg1, arg2): + """Convert the numerical argument to the same type as the other argument""" + + def upgrade(arg_number, arg_structure): + if isinstance(arg_structure, tuple): + return tuple([arg_number] * len(arg_structure)) + elif isinstance(arg_structure, dict): + arg = copy.deepcopy(arg_structure) + for k in arg: + arg[k] = upgrade(arg_number, arg_structure[k]) + return arg + else: + return arg_number + + if isinstance(arg1, float) or isinstance(arg1, int): + return upgrade(arg1, arg2), arg2 + elif isinstance(arg2, float) or isinstance(arg2, int): + return arg1, upgrade(arg2, arg1) + + return arg1, arg2 + + +def resolve_max_positions(*args): + """Resolve max position constraints from multiple sources.""" + + def map_value_update(d1, d2): + updated_value = copy.deepcopy(d1) + for key in d2: + if key not in updated_value: + updated_value[key] = d2[key] + else: + updated_value[key] = min(d1[key], d2[key]) + return updated_value + + def nullsafe_min(l): + minim = None + for item in l: + if minim is None: + minim = item + elif item is not None and item < minim: + minim = item + return minim + + max_positions = None + for arg in args: + if max_positions is None: + max_positions = arg + elif arg is not None: + max_positions, arg = _match_types(max_positions, arg) + if isinstance(arg, float) or isinstance(arg, int): + max_positions = min(max_positions, arg) + elif isinstance(arg, dict): + max_positions = map_value_update(max_positions, arg) + else: + max_positions = tuple(map(nullsafe_min, zip(max_positions, arg))) + + return max_positions + + +def import_user_module(args): + module_path = getattr(args, "user_dir", None) + if module_path is not None: + module_path = os.path.abspath(args.user_dir) + if not os.path.exists(module_path) and not os.path.isfile( + os.path.dirname(module_path) + ): + fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + fairseq_rel_path = os.path.join( + os.path.dirname(__file__), "..", args.user_dir + ) + if os.path.exists(fairseq_rel_path): + module_path = fairseq_rel_path + else: + raise FileNotFoundError(module_path) + + # ensure that user modules are only imported once + import_user_module.memo = getattr(import_user_module, "memo", set()) + if module_path not in import_user_module.memo: + import_user_module.memo.add(module_path) + + module_parent, module_name = os.path.split(module_path) + if module_name not in sys.modules: + sys.path.insert(0, module_parent) + importlib.import_module(module_name) + + tasks_path = os.path.join(module_path, "tasks") + if os.path.exists(tasks_path): + from fairseq.tasks import import_tasks + + import_tasks(tasks_path, f"{module_name}.tasks") + + models_path = os.path.join(module_path, "models") + if os.path.exists(models_path): + from fairseq.models import import_models + + import_models(models_path, f"{module_name}.models") + elif module_path in sys.modules[module_name].__path__: + logger.info(f"--user-dir={module_path} has already been imported.") + else: + raise ImportError( + "Failed to import --user-dir={} because the corresponding module name " + "({}) is not globally unique. Please rename the directory to " + "something unique and try again.".format(module_path, module_name) + ) + + +def softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.softmax(x.float(), dim=dim) + else: + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def log_softmax(x, dim: int, onnx_trace: bool = False): + if onnx_trace: + return F.log_softmax(x.float(), dim=dim) + else: + return F.log_softmax(x, dim=dim, dtype=torch.float32) + + +def get_perplexity(loss, round=2, base=2): + from fairseq.logging.meters import safe_round + + if loss is None: + return 0.0 + try: + return safe_round(base**loss, round) + except OverflowError: + return float("inf") + + +def deprecation_warning(message, stacklevel=3): + # don't use DeprecationWarning, since it's ignored by default + warnings.warn(message, stacklevel=stacklevel) + + +def relu_squared(x: torch.Tensor): + return F.relu(x).pow(2) + + +def get_activation_fn(activation: str) -> Callable: + """Returns the activation function corresponding to `activation`""" + from fairseq.modules import gelu, gelu_accurate + + if activation == "relu": + return F.relu + elif activation == "relu_squared": + return relu_squared + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + deprecation_warning( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "swish": + return torch.nn.SiLU + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "gelu_fast", # deprecated + "gelu_accurate", + "tanh", + "linear", + ] + + +@contextlib.contextmanager +def model_eval(model): + is_training = model.training + model.eval() + yield + model.train(is_training) + + +def has_parameters(module): + try: + next(module.parameters()) + return True + except StopIteration: + return False + + +def get_rng_state(): + state = {"torch_rng_state": torch.get_rng_state()} + if xm is not None: + state["xla_rng_state"] = xm.get_rng_state() + if torch.cuda.is_available(): + state["cuda_rng_state"] = torch.cuda.get_rng_state() + return state + + +def set_rng_state(state): + torch.set_rng_state(state["torch_rng_state"]) + if xm is not None: + xm.set_rng_state(state["xla_rng_state"]) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +class set_torch_seed(object): + def __init__(self, seed): + assert isinstance(seed, int) + self.rng_state = get_rng_state() + + torch.manual_seed(seed) + if xm is not None: + xm.set_rng_state(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + def __enter__(self): + return self + + def __exit__(self, *exc): + set_rng_state(self.rng_state) + + +def parse_alignment(line): + """ + Parses a single line from the alingment file. + + Args: + line (str): String containing the alignment of the format: + - - .. + -. All indices are 0 indexed. + + Returns: + torch.IntTensor: packed alignments of shape (2 * m). + """ + alignments = line.strip().split() + parsed_alignment = torch.IntTensor(2 * len(alignments)) + for idx, alignment in enumerate(alignments): + src_idx, tgt_idx = alignment.split("-") + parsed_alignment[2 * idx] = int(src_idx) + parsed_alignment[2 * idx + 1] = int(tgt_idx) + return parsed_alignment + + +def get_token_to_word_mapping(tokens, exclude_list): + n = len(tokens) + word_start = [int(token not in exclude_list) for token in tokens] + word_idx = list(accumulate(word_start)) + token_to_word = {i: word_idx[i] for i in range(n)} + return token_to_word + + +def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ( + ((tgt_sent != pad) & (tgt_sent != eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_invalid = ( + ((src_sent == pad) | (src_sent == eos)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad]) + tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad]) + alignment = [] + if len(tgt_valid) != 0 and len(src_invalid) < len(src_sent): + attn_valid = attn[tgt_valid] + attn_valid[:, src_invalid] = float("-inf") + _, src_indices = attn_valid.max(dim=1) + for tgt_idx, src_idx in zip(tgt_valid, src_indices): + alignment.append( + ( + src_token_to_word[src_idx.item()] - 1, + tgt_token_to_word[tgt_idx.item()] - 1, + ) + ) + return alignment + + +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ((tgt_sent != pad)).nonzero(as_tuple=False) + src_valid = ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [ + ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid + ] + return alignment + + +def new_arange(x, *size): + """ + Return a Tensor of `size` filled with a range function on the device of x. + If size is empty, using the size of the variable x. + """ + if len(size) == 0: + size = x.size() + return torch.arange(size[-1], device=x.device).expand(*size).contiguous() + + +def get_tpu_device(): + return xm.xla_device() + + +def tpu_data_loader(itr): + import torch_xla.core.xla_model as xm + import torch_xla.distributed.parallel_loader as pl + + from fairseq.data import iterators + + xm.rendezvous("tpu_data_loader") # wait for all workers + xm.mark_step() + device = xm.xla_device() + return iterators.CountingIterator( + pl.ParallelLoader(itr, [device]).per_device_loader(device), + start=getattr(itr, "n", 0), + total=len(itr), + ) + + +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == "xla" + + +def index_put(tensor, indices, value): + if is_xla_tensor(tensor): + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + + return xm._maybe_convert_to_cpu(dat) + + +class CudaEnvironment(object): + def __init__(self): + cur_device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) + self.name = prop.name + self.major = prop.major + self.minor = prop.minor + self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 + + @staticmethod + def pretty_print_cuda_env_list(cuda_env_list): + """ + Given a list of CudaEnviorments, pretty print them + """ + num_workers = len(cuda_env_list) + center = "CUDA enviroments for all {} workers".format(num_workers) + banner_len = 40 - len(center) // 2 + first_line = "*" * banner_len + center + "*" * banner_len + logger.info(first_line) + for r, env in enumerate(cuda_env_list): + logger.info( + "rank {:3d}: ".format(r) + + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) + + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) + + "name = {:40s}".format(env.name) + ) + logger.info(first_line) + + +def csv_str_list(x): + return x.split(",") + + +def eval_str_list(x, type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(type, x)) + except TypeError: + return [type(x)] + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def eval_bool(x, default=False): + if x is None: + return default + try: + return bool(eval(x)) + except TypeError: + return default + + +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +def safe_getattr(obj, k, default=None): + """Returns obj[k] if it exists and is not None, otherwise returns default.""" + from omegaconf import OmegaConf + + if OmegaConf.is_config(obj): + return obj[k] if k in obj and obj[k] is not None else default + + return getattr(obj, k, default) + + +def safe_hasattr(obj, k): + """Returns True if the given key exists and is not None.""" + return getattr(obj, k, None) is not None + + +def hotreload_function(name=None): + """ + Decorator to function to enable hot-reload for debugging. + It allows you to debug a function without having reloading all heavy models, dataset loading and + preprocessing, allow faster debugging. + If you want to change model or dataset loading, consider relaunching your code + ----------------------------------- + This will run the decorated function func: + if func run successful: + It will pause, allow user to edit code, and prompt user to: + Press enter to re-run the function with updated code + Type "done" to finish the function, return output + Type "disable" to stop pausing this function and let code continue without pause + Ctril + C to terminal + if func raise error: + it will prompt user to + 1. Edit code, and press enter to retry + 2. Ctrl + C to terminate + 3. Type "raise" to raise that exception + * Requirements: + 0. Fairseq was installed with `pip install --editable .` + 1. pip install jurigged[develoop] + 2. set environment HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 + 3. Run on only 1 GPU (no distributed) + * How to use: + 1. in python, import and decorate the top-level function to be re-run after code edits: + ```python + from fairseq.utils import hotreload_function + .... + @hotreload_function("train_step") + def train_step(self, sample ....): + .... + .... + ``` + 2. in bash run scripts: + ```bash + watch_dir=/fairseq-py/fairseq/tasks # directory to watch for file changes + export CUDA_VISIBLE_DEVICES=0 # single-gpu + HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 python -m jurigged -w ${watch_dir} --poll 2 -v train.py ...... + ``` + * NOTE: + 1. -w ${watch_dir} specify all the files to be watched for changes + once functions, class, ... code are changed, all instances in the process will get updated (hot-reload) + * Limitation: + * Currently distributed debugging not working + * Need to launch train.py locally (cannot submit jobs) + """ + try: + import jurigged + except ImportError as e: + logger.warning("Please install jurigged: pip install jurigged[develoop]") + raise e + from fairseq.distributed import utils as distributed_utils + import traceback + + def hotreload_decorator(func): + assert callable(func), f"not callable: {func}" + jname = name or func.__name__ + logger.info(f"jurigged-hotreload:Apply jurigged on {jname}:{func.__name__}") + HOTRELOAD_PAUSE = bool(os.environ.get("HOTRELOAD_PAUSE", 0)) + cublk = bool(os.environ.get("CUDA_LAUNCH_BLOCKING", 0)) + prefix = f"HOTRELOAD:{jname}:[cublk={cublk}]" + hot_reload_state = {"disable": False} + + def func_wrapper(*args, **kwargs): + if not HOTRELOAD_PAUSE or hot_reload_state["disable"]: + return func(*args, **kwargs) + world_size = distributed_utils.get_global_world_size() + assert ( + world_size <= 1 + ), f"HOTRELOAD_PAUSE:{jname} currently cannot do distributed training" + success = False + while not success: + try: + output = func(*args, **kwargs) + # success = True + end_action = input( + f"{prefix}: PAUSE, you may edit code now. Enter to re-run, ctrl+C to terminate, " + f'type "done" to continue (function still being watched), or type "disable" to stop pausing this function :' + ) + if end_action.strip().lower() in ["disable", "done"]: + success = True + else: + logger.warning( + f"{prefix}: action={end_action} function will re-run now." + ) + except Exception as e: + action = input( + f"{prefix}:ERROR: \n{traceback.format_exc()}\n" + f'Edit code to try again: enter to continue, ctrl+C to terminate, or type "raise" to raise the exception: ' + ) + if action.strip().lower() == "raise": + raise e + + if end_action.strip().lower() == "disable": + logger.warning( + f"{prefix}: Stop pausing {jname}. The function is still being watched and newly editted code will take effect " + f"if the {jname} is called again later." + f' "unset HOTRELOAD_PAUSE" before relaunch to disable hotreload and' + f" remove @hotreload_function decorator in the code." + ) + hot_reload_state["disable"] = True + return output + + return func_wrapper + + return hotreload_decorator diff --git a/fairseq/version.py b/fairseq/version.py new file mode 100644 index 0000000000000000000000000000000000000000..76da4a9882c63454f7f915ed547854f52ae38e8f --- /dev/null +++ b/fairseq/version.py @@ -0,0 +1 @@ +__version__ = "0.12.2" diff --git a/fairseq/version.txt b/fairseq/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..26acbf080be051b441bc144e358859396d9133cc --- /dev/null +++ b/fairseq/version.txt @@ -0,0 +1 @@ +0.12.2 diff --git a/requirements.txt b/requirements.txt index 37dc80bf465ea056589a343d12e46c72efb9623f..2d9ff3cce7ce8df6caa68df822720b3b56472763 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -fairseq @ git+https://github.com/pytorch/fairseq.git@main absl-py==2.3.0 accelerate==1.2.1 alias-free-torch==0.0.6 @@ -33,7 +32,7 @@ GitPython==3.1.44 grpcio==1.73.0 h5py==3.10.0 huggingface-hub==0.30.2 -hydra-core>=1.0.7,<1.1 +hydra-core==1.3.2 hypothesis==6.70.0 imageio==2.37.0 importlib_metadata==8.5.0 @@ -71,7 +70,7 @@ nvidia-ml-py==12.575.51 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.4.127 -omegaconf>=2.0.6,<2.2 +omegaconf==2.3.0 packaging==23.2 pandas==2.2.0 pathlib==1.0.1