Spaces:
Runtime error
Runtime error
Commit
·
06b4245
1
Parent(s):
994c238
add chinese models
Browse files
model.py
CHANGED
|
@@ -192,7 +192,9 @@ def get_vad() -> sherpa_onnx.VoiceActivityDetector:
|
|
| 192 |
|
| 193 |
@lru_cache(maxsize=10)
|
| 194 |
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
| 195 |
-
if repo_id in
|
|
|
|
|
|
|
| 196 |
return english_models[repo_id](repo_id)
|
| 197 |
elif repo_id in chinese_english_mixed_models:
|
| 198 |
return chinese_english_mixed_models[repo_id](repo_id)
|
|
@@ -202,6 +204,49 @@ def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
|
| 202 |
raise ValueError(f"Unsupported repo_id: {repo_id}")
|
| 203 |
|
| 204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
english_models = {
|
| 206 |
"whisper-tiny.en": _get_whisper_model,
|
| 207 |
"whisper-base.en": _get_whisper_model,
|
|
@@ -218,6 +263,7 @@ russian_models = {
|
|
| 218 |
}
|
| 219 |
|
| 220 |
language_to_models = {
|
|
|
|
| 221 |
"English": list(english_models.keys()),
|
| 222 |
"Chinese+English": list(chinese_english_mixed_models.keys()),
|
| 223 |
"Russian": list(russian_models.keys()),
|
|
|
|
| 192 |
|
| 193 |
@lru_cache(maxsize=10)
|
| 194 |
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
| 195 |
+
if repo_id in chinese_models:
|
| 196 |
+
return chinese_models[repo_id](repo_id)
|
| 197 |
+
elif repo_id in english_models:
|
| 198 |
return english_models[repo_id](repo_id)
|
| 199 |
elif repo_id in chinese_english_mixed_models:
|
| 200 |
return chinese_english_mixed_models[repo_id](repo_id)
|
|
|
|
| 204 |
raise ValueError(f"Unsupported repo_id: {repo_id}")
|
| 205 |
|
| 206 |
|
| 207 |
+
def _get_wenetspeech_pre_trained_model(repo_id):
|
| 208 |
+
assert repo_id in (
|
| 209 |
+
"csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23",
|
| 210 |
+
), repo_id
|
| 211 |
+
|
| 212 |
+
encoder_model = _get_nn_model_filename(
|
| 213 |
+
repo_id=repo_id,
|
| 214 |
+
filename="encoder-epoch-99-avg-1.onnx",
|
| 215 |
+
subfolder=".",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
decoder_model = _get_nn_model_filename(
|
| 219 |
+
repo_id=repo_id,
|
| 220 |
+
filename="decoder-epoch-99-avg-1.onnx",
|
| 221 |
+
subfolder=".",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
joiner_model = _get_nn_model_filename(
|
| 225 |
+
repo_id=repo_id,
|
| 226 |
+
filename="joiner-epoch-99-avg-1.onnx",
|
| 227 |
+
subfolder=".",
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
tokens = _get_token_filename(repo_id=repo_id, subfolder=".")
|
| 231 |
+
|
| 232 |
+
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
| 233 |
+
tokens=tokens,
|
| 234 |
+
encoder=encoder_model,
|
| 235 |
+
decoder=decoder_model,
|
| 236 |
+
joiner=joiner_model,
|
| 237 |
+
num_threads=2,
|
| 238 |
+
sample_rate=16000,
|
| 239 |
+
feature_dim=80,
|
| 240 |
+
decoding_method="greedy_search",
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
return recognizer
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
chinese_models = {
|
| 247 |
+
"csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23": _get_wenetspeech_pre_trained_model, # noqa
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
english_models = {
|
| 251 |
"whisper-tiny.en": _get_whisper_model,
|
| 252 |
"whisper-base.en": _get_whisper_model,
|
|
|
|
| 263 |
}
|
| 264 |
|
| 265 |
language_to_models = {
|
| 266 |
+
"Chinese": list(chinese_models),
|
| 267 |
"English": list(english_models.keys()),
|
| 268 |
"Chinese+English": list(chinese_english_mixed_models.keys()),
|
| 269 |
"Russian": list(russian_models.keys()),
|