Hugging Face Spaces page header (Space status: Sleeping) — scrape artifact, not part of the source file.
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from enum import Enum | |
| from functools import lru_cache | |
| import os | |
| from pathlib import Path | |
| import huggingface_hub | |
| import sherpa | |
| import sherpa_onnx | |
class EnumDecodingMethod(Enum):
    """Decoding strategies accepted by the offline recognizer loaders.

    Each member's value is the exact string passed through as the
    ``decoding_method`` argument to sherpa / sherpa-onnx.
    """

    greedy_search = "greedy_search"
    modified_beam_search = "modified_beam_search"
# Registry of downloadable ASR models, keyed by display language.
# Each entry describes one Hugging Face Hub model repo:
#   repo_id       - Hub repository to download from
#   nn_model_file - neural-network model filename inside the repo
#   tokens_file   - token/vocabulary filename inside the repo
#   sub_folder    - repo subfolder holding both files ("." = repo root)
#   loader        - name of the loader function in this module used to
#                   build a recognizer from the downloaded files
model_map = {
    "Chinese": [
        {
            "repo_id": "csukuangfj/wenet-chinese-model",
            "nn_model_file": "final.zip",
            "tokens_file": "units.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer",
        },
        {
            "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
            "nn_model_file": "model.int8.onnx",
            "tokens_file": "tokens.txt",
            "sub_folder": ".",
            "loader": "load_sherpa_offline_recognizer_from_paraformer",
        }
    ]
}
def download_model(repo_id: str,
                   nn_model_file: str,
                   tokens_file: str,
                   sub_folder: str,
                   local_model_dir: str,
                   ):
    """Fetch a model file and its tokens file from the Hugging Face Hub.

    Both files come from *sub_folder* of *repo_id* and are cached under
    *local_model_dir*.

    :return: tuple ``(nn_model_path, tokens_path)`` of resolved local paths.
    """
    local_paths = [
        huggingface_hub.hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=sub_folder,
            local_dir=local_model_dir,
        )
        for filename in (nn_model_file, tokens_file)
    ]
    return tuple(local_paths)
def load_sherpa_offline_recognizer(nn_model_file: str,
                                   tokens_file: str,
                                   sample_rate: int = 16000,
                                   num_active_paths: int = 2,
                                   decoding_method: str = "greedy_search",
                                   num_mel_bins: int = 80,
                                   frame_dither: int = 0,
                                   ):
    """Build a CPU sherpa offline recognizer from a model + tokens file.

    :param nn_model_file: path to the neural-network model file.
    :param tokens_file: path to the token/vocabulary file.
    :param sample_rate: expected audio sample rate (Hz).
    :param num_active_paths: beam width for modified_beam_search.
    :param decoding_method: "greedy_search" or "modified_beam_search".
    :param num_mel_bins: number of mel filterbank bins.
    :param frame_dither: fbank dither amount (0 disables dithering).
    :return: a ``sherpa.OfflineRecognizer`` instance.
    """
    # Feature extraction config: un-normalized samples, fbank features.
    feat_config = sherpa.FeatureConfig(normalize_samples=False)
    frame_opts = feat_config.fbank_opts.frame_opts
    frame_opts.samp_freq = sample_rate
    frame_opts.dither = frame_dither
    feat_config.fbank_opts.mel_opts.num_bins = num_mel_bins

    recognizer_config = sherpa.OfflineRecognizerConfig(
        nn_model=nn_model_file,
        tokens=tokens_file,
        use_gpu=False,
        feat_config=feat_config,
        decoding_method=decoding_method,
        num_active_paths=num_active_paths,
    )
    return sherpa.OfflineRecognizer(recognizer_config)
def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
                                                   tokens_file: str,
                                                   sample_rate: int = 16000,
                                                   decoding_method: str = "greedy_search",
                                                   feature_dim: int = 80,
                                                   num_threads: int = 2,
                                                   ):
    """Build a sherpa-onnx offline recognizer from a Paraformer ONNX model.

    :param nn_model_file: path to the Paraformer ONNX model file.
    :param tokens_file: path to the token/vocabulary file.
    :param sample_rate: expected audio sample rate (Hz).
    :param decoding_method: decoding strategy name (e.g. "greedy_search").
    :param feature_dim: acoustic feature dimension.
    :param num_threads: ONNX Runtime intra-op thread count.
    :return: a ``sherpa_onnx.OfflineRecognizer`` instance.
    """
    return sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=nn_model_file,
        tokens=tokens_file,
        num_threads=num_threads,
        sample_rate=sample_rate,
        feature_dim=feature_dim,
        decoding_method=decoding_method,
        debug=False,
    )
def load_recognizer(repo_id: str,
                    nn_model_file: str,
                    tokens_file: str,
                    sub_folder: str,
                    local_model_dir: Path,
                    loader: str,
                    decoding_method: str = "greedy_search",
                    num_active_paths: int = 4,
                    ):
    """Download model files if needed and construct an offline recognizer.

    :param repo_id: Hugging Face Hub repo to fetch the model from.
    :param nn_model_file: model filename inside the repo.
    :param tokens_file: tokens filename inside the repo.
    :param sub_folder: repo subfolder holding both files ("." = repo root).
    :param local_model_dir: local directory the files are cached in.
    :param loader: name of the loader function used to build the recognizer.
    :param decoding_method: decoding strategy passed to the loader.
    :param num_active_paths: beam width; only used by the sherpa loader.
    :raises NotImplementedError: if *loader* is not a supported loader name.
    :return: the recognizer built by the selected loader.
    """
    # Fail fast on an unsupported loader *before* spending time on a
    # (potentially large) model download; original only checked at the end.
    supported_loaders = ("load_sherpa_offline_recognizer",
                         "load_sherpa_offline_recognizer_from_paraformer")
    if loader not in supported_loaders:
        raise NotImplementedError("loader not support: {}".format(loader))

    nn_model_path = local_model_dir / nn_model_file
    tokens_path = local_model_dir / tokens_file
    # Download when either target file is missing. The previous dir-only
    # existence check skipped the download for a partially populated cache
    # directory and then failed when loading.
    # NOTE(review): path layout assumes sub_folder == "." as in model_map;
    # with a real subfolder hf_hub_download nests files under it — confirm
    # before adding entries with sub_folder != ".".
    if not (nn_model_path.exists() and tokens_path.exists()):
        download_model(
            repo_id=repo_id,
            nn_model_file=nn_model_file,
            tokens_file=tokens_file,
            sub_folder=sub_folder,
            local_model_dir=local_model_dir.as_posix(),
        )

    if loader == "load_sherpa_offline_recognizer":
        return load_sherpa_offline_recognizer(
            nn_model_file=nn_model_path.as_posix(),
            tokens_file=tokens_path.as_posix(),
            decoding_method=decoding_method,
            num_active_paths=num_active_paths,
        )
    return load_sherpa_offline_recognizer_from_paraformer(
        nn_model_file=nn_model_path.as_posix(),
        tokens_file=tokens_path.as_posix(),
        decoding_method=decoding_method,
    )
if __name__ == "__main__":
    # Module is import-only for now; no CLI entry point is defined.
    pass