# Originally from Microsoft Corporation.
# Licensed under the MIT License.
""" Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import List

import torch
from torch import nn

try:
    from fairseq import ngram_repeat_block_cuda

    EXTENSION_BUILT = True
except ImportError:
    EXTENSION_BUILT = False
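
# NOTE: EXTENSION_BUILT only records whether the optional CUDA kernel imported;
# NGramRepeatBlock below falls back to its pure-Python implementation whenever
# the extension is missing or unusable.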


def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly"""
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
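    # Smoke test: build tiny dummy inputs and run one kernel call; if the
    # extension was built incorrectly, the call (or using its output) raises
    # RuntimeError and we warn that it must be rebuilt.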
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
    lprobs = torch.rand((8, 12), device="cuda")
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        warnings.warn(
            "NGramRepeatBlock extension must be rebuilt. "
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False


class NGramRepeatBlock(nn.Module):
    """Wrapper class for calling ngram_repeat_block cuda extension"""

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        self.use_extension = is_cuda_extension_usable() if use_extension else False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        pass
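
    # Note: the CUDA kernel expects its arguments in the order
    # (tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size), i.e. step
    # before beam_size, which differs from the Python-side signatures below.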

    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        return ngram_repeat_block_cuda.forward(
            tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size
        )

    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """
        Args:
            tokens (Tensor): input tokens, shape (bsz * beam_size, seq_len)
            lprobs (Tensor): next-token log probabilities, shape
                (bsz * beam_size, vocab_size); expected to be updated in place
            bsz (int): batch size
            beam_size (int): beam size
            step (int): current decoding step

        The n-gram size itself comes from self.no_repeat_ngram_size (set in __init__).
        """
        msg = f"expected {bsz * beam_size} got"
        assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}"
        assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}"
        if self.use_extension:
            return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)
        else:
            return self._no_repeat_ngram(
                tokens,
                lprobs,
                bsz,
                beam_size,
                step,
            )

    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
        banned_tokens = [
            torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
        ]
        if step + 2 - self.no_repeat_ngram_size >= 0:
            cpu_tokens: List[List[int]] = tokens.cpu().tolist()
            check_start_pos = step + 2 - self.no_repeat_ngram_size
            for bbsz_idx in range(bsz * beam_size):
                ngram_to_check = cpu_tokens[bbsz_idx][
                    -(self.no_repeat_ngram_size - 1) :
                ]
                for i in range(check_start_pos):
                    if (
                        ngram_to_check
                        == cpu_tokens[bbsz_idx][i : i + self.no_repeat_ngram_size - 1]
                    ):
                        banned_tokens[bbsz_idx].append(
                            cpu_tokens[bbsz_idx][i + self.no_repeat_ngram_size - 1]
                        )
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
            ] = torch.tensor(-math.inf).to(lprobs)
        return lprobs
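

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Assumes a toy setup: bsz=1, beam_size=2, vocab_size=6, no_repeat_ngram_size=2,
# with made-up token histories, and forces the pure-Python fallback path by
# passing use_extension=False so it also runs on CPU-only machines.
if __name__ == "__main__":
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)
    bsz, beam_size, vocab_size, step = 1, 2, 6, 3
    # Hypothesis 0 already produced the bigram (4, 5); its current suffix ends in 4,
    # so token 5 must be banned. Hypothesis 1 has no repeated bigram prefix.
    tokens = torch.tensor([[2, 4, 5, 4], [2, 3, 4, 5]], dtype=torch.long)
    lprobs = torch.zeros(bsz * beam_size, vocab_size)
    lprobs = blocker(tokens, lprobs, bsz=bsz, beam_size=beam_size, step=step)
    print(lprobs)  # expect -inf at position [0, 5]; all other entries unchanged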