from tiktoken import Encoding
from utils.log_util import logger


def decode(self, tokens, errors="replace", skip_special_tokens=False):
    """
    The default decode may raise; see decode_test.py for details.
    skip_special_tokens is accepted only for compatibility with hf_tokenizer.
    errors:
        decoded bytes are not guaranteed to be valid UTF-8.
        "strict"            raise UnicodeError
        "ignore"            ignore and continue
        "replace"           replace with the replacement character (U+FFFD)
        "backslashreplace"  replace with a backslashed escape sequence
        "xmlcharrefreplace" replace with an XML character reference
        "namereplace"       replace with a \\N{...} escape sequence
    """
    try:
        decode_str = self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
    except Exception as e:  # cannot catch PyO3PanicException here
        logger.error(f"{e} for {tokens} -> return 'null'")
        decode_str = "null"
    except:  # last resort: also catches exceptions not derived from Exception
        logger.error(f"unknown exception for {tokens} -> return 'null'")
        decode_str = "null"
    return decode_str
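# Decode sketch (illustrative only; assumes the patch at the bottom of this file has been
# applied and `enc` is a tiktoken Encoding such as cl100k_base):
#   enc.decode(enc.encode("hello world"))   # valid UTF-8 round-trips normally
#   enc.decode([some_id_with_partial_utf8]) # with errors="replace" the bad bytes become
#                                           # U+FFFD instead of raising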
def convert_ids_to_tokens(self, tokens, skip_special_tokens=False):
    """
    tiktoken has no convert_ids_to_tokens; this one is added for hf_tokenizer compatibility.
    """
    try:
        return self.decode_tokens_bytes(tokens)
    except Exception as e:  # cannot catch PyO3PanicException here
        # Why return None? See zh_util.py.
        # 16 unused ids: 100256, 100261-100275
        logger.error(f"{e} for {tokens} -> return None")
        return [None for _ in tokens]
    except:  # last resort: also catches exceptions not derived from Exception
        logger.error(f"unknown exception for {tokens} -> return None")
        return [None for _ in tokens]
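# Sketch (illustrative only): an id that maps to no mergeable rank or special token, e.g. one
# of the unused ids listed in the comment above, comes back as None instead of raising:
#   enc.convert_ids_to_tokens([100261])  # -> [None] (assumption, based on the unused-id
#                                        #    range noted above for cl100k_base)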
def get_vocab(self, token_type="str"):
    """Returns the vocab as a dict (token -> id).
    :param token_type: one of ["str", "byte"]; with "str", keys are UTF-8 decoded where
        possible and fall back to raw bytes for tokens that are not valid UTF-8.
    :return: dict
    """
    vocab = {}
    key_error_list = []
    unicode_decode_error_list = []
    for i in range(self.vocab_size):
        try:
            token_byte = self.convert_ids_to_tokens([i])[0]
            if token_byte is None:
                continue
            if token_type == "str":
                token_byte = token_byte.decode("utf-8")
            vocab[token_byte] = i
        except UnicodeDecodeError:  # 773 tokens hit UnicodeDecodeError
            unicode_decode_error_list.append((i, str(token_byte)))
            vocab[token_byte] = i  # keep the raw bytes as the key
    # vocab.update(self.added_tokens_encoder)
    logger.info(f"{self.name} {len(key_error_list)} KeyError: {key_error_list}")
    logger.info(f"{self.name} {len(unicode_decode_error_list)} UnicodeDecodeError: {unicode_decode_error_list[:5]}")
    return vocab
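# get_vocab sketch (illustrative only; assumes the patch below has been applied):
#   vocab = enc.get_vocab()              # str keys where the token bytes are valid UTF-8
#   vocab_bytes = enc.get_vocab("byte")  # raw bytes keys for every token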
def vocab_size(self):
    """Returns vocab size without special tokens"""
    return len(self._mergeable_ranks)
def encode(self, *args, **kwargs):
    """
    add_special_tokens is accepted (and dropped) for hf_tokenizer compatibility;
    all special tokens are allowed through to tiktoken.
    """
    kwargs.pop("add_special_tokens", None)
    kwargs["allowed_special"] = "all"
    return self._encode(*args, **kwargs)
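# Usage note (sketch): after patching, HF-style calls such as
#   enc.encode("some text", add_special_tokens=False)
# no longer fail on the unexpected keyword; the kwarg is discarded and special tokens
# like "<|endoftext|>" are encoded rather than rejected.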
def __len__(self):
    return self.n_vocab
# tiktoken patch
Encoding._encode = Encoding.encode
Encoding.encode = encode
Encoding.decode = decode
Encoding.convert_ids_to_tokens = convert_ids_to_tokens
Encoding.get_vocab = get_vocab
Encoding.vocab_size = property(vocab_size)  # property, so `self.vocab_size` reads like hf_tokenizer's
Encoding.__len__ = __len__
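# Minimal end-to-end sketch of the patched Encoding (not part of the patch itself;
# assumes tiktoken is installed and "cl100k_base" can be loaded):
if __name__ == "__main__":
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode("hello world", add_special_tokens=False)  # HF-style kwarg is ignored
    print(ids)
    print(enc.decode(ids))                 # -> "hello world"
    print(enc.convert_ids_to_tokens(ids))  # token bytes, one entry per id
    print(len(enc), enc.vocab_size)        # n_vocab vs. size without special tokens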