| """ | |
| TODO: | |
| 1. add more language | |
| 2. check space count of bert | |
| 3. add token_impl | |
| 4. | |
| """ | |
import os
import json
import numpy as np
import pandas as pd
from collections import Counter, defaultdict
from vocab import tokenizer_factory
from typing import Optional, Union, Literal
from utils.log_util import logger
from utils.text_util import contains_digit, get_space_count
from utils.lang_util import detect_language_by_unicode, language_ranges

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

default_columns = ["digit", "zh"]


def _to_unicode(text):
    return ''.join(r'\u{:04X}'.format(ord(ch)) for ch in text)


def _to_unicode_decimal(text):
    return [ord(ch) for ch in text]
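
# For illustration, both helpers work per character, e.g.
#   _to_unicode("中")          -> r"\u4E2D"
#   _to_unicode_decimal("中")  -> [20013]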


def _get_coding_length(tokenizer, vocab, filter=None):
    """
    OOV characters may be tokenized into more than one token.
    """
    all_length = []
    for word in vocab:
        if len(word) > 1:
            continue
        if filter is not None and filter(word):
            continue
        try:
            tokens = tokenizer.encode(word)
        except Exception as e:
            logger.error(f"encode {word!r} failed: {e}")
            continue
        all_length.append(len(tokens))
        # if len(tokens.ids) > 1:
        # if len(tokens) > 3:
        #     print(word, tokens)
    dist_length = Counter(all_length)
    mean_length = round(sum(all_length) / len(all_length), 2)
    return dist_length, mean_length
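
# Usage sketch (illustrative values; assumes `tokenizer` exposes an `encode`
# method, as HuggingFace tokenizers do):
#   dist_length, mean_length = _get_coding_length(tokenizer, vocab=["中", "文", "你"])
#   # e.g. dist_length == Counter({1: 2, 2: 1}) and mean_length == 1.33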


cache = {}


def _dist(token_lens):
    """
    :param token_lens:
    :return: min,median,max of token_lens
    """
    if not token_lens:
        return "-"
    return f"{min(token_lens)},{round(np.median(token_lens))},{max(token_lens)}"


def _decode_bytes_to_str(token: bytes) -> str:
    try:
        token = token.decode("utf-8", errors="strict")
    except UnicodeDecodeError:
        try:
            # for a single byte, such as b'\xa1'
            token = token.decode('latin-1')
        except UnicodeDecodeError:
            logger.warning(f"token {token} decode failed")
            token = token.decode("utf-8", errors="ignore")
    return token
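
# For illustration:
#   _decode_bytes_to_str(b"\xe4\xb8\xad") -> "中"   (valid UTF-8)
#   _decode_bytes_to_str(b"\xa1")         -> "¡"    (invalid UTF-8, falls back to latin-1)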


def iter_vocab(
        tokenizer_name: str,
        from_cache: bool = True,
        cache_dir: str = "stats",
) -> Union[pd.DataFrame, dict]:
    """
    :param tokenizer_name: name of a tokenizer registered in tokenizer_factory
    :param from_cache: return the cached result if one exists
    :param cache_dir: directory (relative to this file) for the cache files
    :return: per-tokenizer character statistics
    """
    tokenizer_config = tokenizer_factory.get_tokenizer_config(tokenizer_name)
    cache_dir = os.path.join(CURRENT_DIR, cache_dir)
    os.makedirs(cache_dir, exist_ok=True)

    # load from cache
    cache_path = os.path.join(cache_dir, "character_stats.json")
    if not cache and os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f_tmp:
            cache.update(json.load(f_tmp))
    if from_cache and tokenizer_name in cache:
        # logger.info(f"load {tokenizer_config.name_or_path} from cache")
        return cache[tokenizer_name]

    tokenizer = tokenizer_factory.get_tokenizer(tokenizer_name)

    tokens_by_lang = {lang[1]: [] for lang in language_ranges.keys()}
    digit_tokens = []
    space_tokens = []
    byte_tokens = []

    buffer = []
    for token_id in range(tokenizer.vocab_size):
        # for token_id in tokenizer.get_vocab():
        # for token_id in range(len(tokenizer)):
        decode_str = tokenizer.decode([token_id], skip_special_tokens=False)
        token = tokenizer.convert_ids_to_tokens([token_id], skip_special_tokens=False)[0]
        tags = []
        if token is None:  # some vocabularies have empty ids (non-contiguous)
            continue
        if isinstance(token, bytes):  # convert bytes to string
            token = _decode_bytes_to_str(token)
        if hasattr(tokenizer, "sp_model"):  # tokenizers built on the sentencepiece package
            if tokenizer.sp_model.is_byte(token_id):
                tags.append("is_byte")
                byte_tokens.append(token)

        language_tags = detect_language_by_unicode(decode_str)
        for language in language_tags:
            tokens_by_lang[language[1]].append(decode_str)

        if contains_digit(decode_str):
            tags.append("digit")
            digit_tokens.append(decode_str)

        space_count = get_space_count(decode_str)
        if space_count > 0:
            space_tokens.append(decode_str)

        buffer.append(json.dumps(
            {
                "id": token_id,
                "token": token,
                "token_decode": decode_str,
                "token_dumps": json.dumps(token),
                # unicode: https://en.wikipedia.org/wiki/List_of_Unicode_characters
                "token_unicode": _to_unicode(token),
                "token_unicode_decimal": _to_unicode_decimal(token),  # decimal code points
                # "token_utf8_bytes": "",
                "token_len": len(decode_str),
            },
            ensure_ascii=False) + "\n")

    result = {
        "tokenizer": tokenizer_factory.get_name_with_hyperlink(tokenizer_name),
        "organization": tokenizer_config.org,
        # "impl": str(tokenizer.__class__),
        # "vocab_size-": tokenizer.vocab_size,  # vocab_size without added tokens
        "vocab_size": len(tokenizer),
        # "mean coding length of Chinese characters": mean_length,  # not reported: a vocab covering many Chinese characters generally implies a short coding length for Chinese
        # "coding length distribution of Chinese characters": json.dumps(dist_length),
        "num(digit)": len(digit_tokens),
        "len(digit)": _dist([len(token) for token in digit_tokens]),
        "num(space)": len(space_tokens),
        "len(space)": _dist([len(token) for token in space_tokens]),
        # "num(byte)": len(byte_tokens)
    }
    for lang, tokens in tokens_by_lang.items():
        result[f"num({lang})"] = len(tokens)
        result["len(" + lang + ")"] = _dist([len(token) for token in tokens])

    out_path = os.path.join(cache_dir, f"iter_vocab/{tokenizer_name.replace('/', '_')}.vocab.jsonl")
    os.makedirs(os.path.dirname(out_path), exist_ok=True)  # ensure the iter_vocab/ subdirectory exists
    with open(out_path, "w", encoding="utf-8") as f_out:
        for line in buffer:
            f_out.write(line)

    len_before = len(cache)
    cache[tokenizer_name] = result
    len_after = len(cache)
    logger.info(f"saving {tokenizer_name} to memory and file cache: {len_before}->{len_after}")
    with open(cache_path, "w", encoding="utf-8") as f_out:
        f_out.write(json.dumps(cache, ensure_ascii=False, indent=2))
    return result
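
# The per-tokenizer result is a flat dict; illustrative shape (values are made up):
#   {"tokenizer": "...", "organization": "...", "vocab_size": 32000,
#    "num(digit)": 500, "len(digit)": "1,1,3", "num(zh)": 700, "len(zh)": "1,1,2", ...}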


def to_dataframe(stats, columns):
    table = []
    for stat in stats.values():
        filtered_stat = {}
        for k, v in stat.items():
            if not k.startswith("num") and not k.startswith("len"):
                filtered_stat[k] = v
            if any(column in k for column in columns):
                k = k.replace("ja-kana", "kana")
                filtered_stat[k] = v
        table.append(filtered_stat)
    df = pd.DataFrame(table)
    return df


def get_character_table(
        tokenizer_filter: Optional[str] = None,
        columns: Optional[list] = None,
        return_type: Optional[Literal["dict", "dataframe"]] = "dataframe"
) -> Union[pd.DataFrame, dict]:
    """
    Build the character statistics table for all tokenizers, optionally filtered by name.
    """
    logger.info(f"columns: {columns}, tokenizer_filter: {tokenizer_filter}")
    stats = {}
    if columns is None:
        columns = default_columns
    if tokenizer_filter is not None:
        tokenizer_names = [tokenizer_config.name_or_path for tokenizer_config in tokenizer_factory.all_tokenizer_configs
                           if tokenizer_filter.lower() in tokenizer_config.name_or_path.lower()]
    else:
        tokenizer_names = tokenizer_factory.all_tokenizer_names
    for tokenizer_name in tokenizer_names:
        stat = iter_vocab(tokenizer_name)
        stats[tokenizer_name] = stat
    if return_type == "dataframe":
        stats = to_dataframe(stats, columns)
    return stats


if __name__ == "__main__":
    # aa = get_character_table(tokenizer_filter="baichuan")
    # iter_vocab("openai/gpt-4o", from_cache=False)
    # iter_vocab("openai/gpt-oss-20b", from_cache=False)
    iter_vocab("NousResearch/Llama-2-7b-chat-hf", from_cache=False)
    # df = get_character_table()
    # logger.info(f"\n{df.to_markdown(index=False)}")