Spaces:
Build error
Build error
Commit
·
09ae8c3
1
Parent(s):
1420134
add punctuations
Browse files- app.py +28 -1
- model.py +15 -0
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -32,7 +32,13 @@ import torch
|
|
| 32 |
import torchaudio
|
| 33 |
|
| 34 |
from examples import examples
|
| 35 |
-
from model import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
languages = list(language_to_models.keys())
|
| 38 |
|
|
@@ -65,6 +71,7 @@ def process_url(
|
|
| 65 |
repo_id: str,
|
| 66 |
decoding_method: str,
|
| 67 |
num_active_paths: int,
|
|
|
|
| 68 |
url: str,
|
| 69 |
):
|
| 70 |
logging.info(f"Processing URL: {url}")
|
|
@@ -78,6 +85,7 @@ def process_url(
|
|
| 78 |
repo_id=repo_id,
|
| 79 |
decoding_method=decoding_method,
|
| 80 |
num_active_paths=num_active_paths,
|
|
|
|
| 81 |
)
|
| 82 |
except Exception as e:
|
| 83 |
logging.info(str(e))
|
|
@@ -89,6 +97,7 @@ def process_uploaded_file(
|
|
| 89 |
repo_id: str,
|
| 90 |
decoding_method: str,
|
| 91 |
num_active_paths: int,
|
|
|
|
| 92 |
in_filename: str,
|
| 93 |
):
|
| 94 |
if in_filename is None or in_filename == "":
|
|
@@ -106,6 +115,7 @@ def process_uploaded_file(
|
|
| 106 |
repo_id=repo_id,
|
| 107 |
decoding_method=decoding_method,
|
| 108 |
num_active_paths=num_active_paths,
|
|
|
|
| 109 |
)
|
| 110 |
except Exception as e:
|
| 111 |
logging.info(str(e))
|
|
@@ -117,6 +127,7 @@ def process_microphone(
|
|
| 117 |
repo_id: str,
|
| 118 |
decoding_method: str,
|
| 119 |
num_active_paths: int,
|
|
|
|
| 120 |
in_filename: str,
|
| 121 |
):
|
| 122 |
if in_filename is None or in_filename == "":
|
|
@@ -135,6 +146,7 @@ def process_microphone(
|
|
| 135 |
repo_id=repo_id,
|
| 136 |
decoding_method=decoding_method,
|
| 137 |
num_active_paths=num_active_paths,
|
|
|
|
| 138 |
)
|
| 139 |
except Exception as e:
|
| 140 |
logging.info(str(e))
|
|
@@ -147,6 +159,7 @@ def process(
|
|
| 147 |
repo_id: str,
|
| 148 |
decoding_method: str,
|
| 149 |
num_active_paths: int,
|
|
|
|
| 150 |
in_filename: str,
|
| 151 |
):
|
| 152 |
logging.info(f"language: {language}")
|
|
@@ -170,6 +183,9 @@ def process(
|
|
| 170 |
)
|
| 171 |
|
| 172 |
text = decode(recognizer, filename)
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
|
| 175 |
end = time.time()
|
|
@@ -277,6 +293,12 @@ with demo:
|
|
| 277 |
label="Number of active paths for modified_beam_search",
|
| 278 |
)
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
with gr.Tabs():
|
| 281 |
with gr.TabItem("Upload from disk"):
|
| 282 |
uploaded_file = gr.Audio(
|
|
@@ -295,6 +317,7 @@ with demo:
|
|
| 295 |
model_dropdown,
|
| 296 |
decoding_method_radio,
|
| 297 |
num_active_paths_slider,
|
|
|
|
| 298 |
uploaded_file,
|
| 299 |
],
|
| 300 |
outputs=[uploaded_output, uploaded_html_info],
|
|
@@ -319,6 +342,7 @@ with demo:
|
|
| 319 |
model_dropdown,
|
| 320 |
decoding_method_radio,
|
| 321 |
num_active_paths_slider,
|
|
|
|
| 322 |
microphone,
|
| 323 |
],
|
| 324 |
outputs=[recorded_output, recorded_html_info],
|
|
@@ -344,6 +368,7 @@ with demo:
|
|
| 344 |
model_dropdown,
|
| 345 |
decoding_method_radio,
|
| 346 |
num_active_paths_slider,
|
|
|
|
| 347 |
uploaded_file,
|
| 348 |
],
|
| 349 |
outputs=[uploaded_output, uploaded_html_info],
|
|
@@ -356,6 +381,7 @@ with demo:
|
|
| 356 |
model_dropdown,
|
| 357 |
decoding_method_radio,
|
| 358 |
num_active_paths_slider,
|
|
|
|
| 359 |
microphone,
|
| 360 |
],
|
| 361 |
outputs=[recorded_output, recorded_html_info],
|
|
@@ -368,6 +394,7 @@ with demo:
|
|
| 368 |
model_dropdown,
|
| 369 |
decoding_method_radio,
|
| 370 |
num_active_paths_slider,
|
|
|
|
| 371 |
url_textbox,
|
| 372 |
],
|
| 373 |
outputs=[url_output, url_html_info],
|
|
|
|
| 32 |
import torchaudio
|
| 33 |
|
| 34 |
from examples import examples
|
| 35 |
+
from model import (
|
| 36 |
+
decode,
|
| 37 |
+
get_pretrained_model,
|
| 38 |
+
get_punct_model,
|
| 39 |
+
language_to_models,
|
| 40 |
+
sample_rate,
|
| 41 |
+
)
|
| 42 |
|
| 43 |
languages = list(language_to_models.keys())
|
| 44 |
|
|
|
|
| 71 |
repo_id: str,
|
| 72 |
decoding_method: str,
|
| 73 |
num_active_paths: int,
|
| 74 |
+
add_punct: str,
|
| 75 |
url: str,
|
| 76 |
):
|
| 77 |
logging.info(f"Processing URL: {url}")
|
|
|
|
| 85 |
repo_id=repo_id,
|
| 86 |
decoding_method=decoding_method,
|
| 87 |
num_active_paths=num_active_paths,
|
| 88 |
+
add_punct=add_punct,
|
| 89 |
)
|
| 90 |
except Exception as e:
|
| 91 |
logging.info(str(e))
|
|
|
|
| 97 |
repo_id: str,
|
| 98 |
decoding_method: str,
|
| 99 |
num_active_paths: int,
|
| 100 |
+
add_punct: str,
|
| 101 |
in_filename: str,
|
| 102 |
):
|
| 103 |
if in_filename is None or in_filename == "":
|
|
|
|
| 115 |
repo_id=repo_id,
|
| 116 |
decoding_method=decoding_method,
|
| 117 |
num_active_paths=num_active_paths,
|
| 118 |
+
add_punct=add_punct,
|
| 119 |
)
|
| 120 |
except Exception as e:
|
| 121 |
logging.info(str(e))
|
|
|
|
| 127 |
repo_id: str,
|
| 128 |
decoding_method: str,
|
| 129 |
num_active_paths: int,
|
| 130 |
+
add_punct: str,
|
| 131 |
in_filename: str,
|
| 132 |
):
|
| 133 |
if in_filename is None or in_filename == "":
|
|
|
|
| 146 |
repo_id=repo_id,
|
| 147 |
decoding_method=decoding_method,
|
| 148 |
num_active_paths=num_active_paths,
|
| 149 |
+
add_punct=add_punct,
|
| 150 |
)
|
| 151 |
except Exception as e:
|
| 152 |
logging.info(str(e))
|
|
|
|
| 159 |
repo_id: str,
|
| 160 |
decoding_method: str,
|
| 161 |
num_active_paths: int,
|
| 162 |
+
add_punct: str,
|
| 163 |
in_filename: str,
|
| 164 |
):
|
| 165 |
logging.info(f"language: {language}")
|
|
|
|
| 183 |
)
|
| 184 |
|
| 185 |
text = decode(recognizer, filename)
|
| 186 |
+
if add_punct == "Yes":
|
| 187 |
+
punct = get_punct_model()
|
| 188 |
+
text = punct.add_punctuation(text)
|
| 189 |
|
| 190 |
date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
|
| 191 |
end = time.time()
|
|
|
|
| 293 |
label="Number of active paths for modified_beam_search",
|
| 294 |
)
|
| 295 |
|
| 296 |
+
punct_radio = gr.Radio(
|
| 297 |
+
label="Whether to add punctuation (Only for Chinese and English)",
|
| 298 |
+
choices=["Yes", "No"],
|
| 299 |
+
value="Yes",
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
with gr.Tabs():
|
| 303 |
with gr.TabItem("Upload from disk"):
|
| 304 |
uploaded_file = gr.Audio(
|
|
|
|
| 317 |
model_dropdown,
|
| 318 |
decoding_method_radio,
|
| 319 |
num_active_paths_slider,
|
| 320 |
+
punct_radio,
|
| 321 |
uploaded_file,
|
| 322 |
],
|
| 323 |
outputs=[uploaded_output, uploaded_html_info],
|
|
|
|
| 342 |
model_dropdown,
|
| 343 |
decoding_method_radio,
|
| 344 |
num_active_paths_slider,
|
| 345 |
+
punct_radio,
|
| 346 |
microphone,
|
| 347 |
],
|
| 348 |
outputs=[recorded_output, recorded_html_info],
|
|
|
|
| 368 |
model_dropdown,
|
| 369 |
decoding_method_radio,
|
| 370 |
num_active_paths_slider,
|
| 371 |
+
punct_radio,
|
| 372 |
uploaded_file,
|
| 373 |
],
|
| 374 |
outputs=[uploaded_output, uploaded_html_info],
|
|
|
|
| 381 |
model_dropdown,
|
| 382 |
decoding_method_radio,
|
| 383 |
num_active_paths_slider,
|
| 384 |
+
punct_radio,
|
| 385 |
microphone,
|
| 386 |
],
|
| 387 |
outputs=[recorded_output, recorded_html_info],
|
|
|
|
| 394 |
model_dropdown,
|
| 395 |
decoding_method_radio,
|
| 396 |
num_active_paths_slider,
|
| 397 |
+
punct_radio,
|
| 398 |
url_textbox,
|
| 399 |
],
|
| 400 |
outputs=[url_output, url_html_info],
|
model.py
CHANGED
|
@@ -1182,6 +1182,21 @@ def _get_aishell_pre_trained_model(
|
|
| 1182 |
return recognizer
|
| 1183 |
|
| 1184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1185 |
def _get_multi_zh_hans_pre_trained_model(
|
| 1186 |
repo_id: str,
|
| 1187 |
decoding_method: str,
|
|
|
|
| 1182 |
return recognizer
|
| 1183 |
|
| 1184 |
|
| 1185 |
+
@lru_cache(maxsize=2)
|
| 1186 |
+
def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
|
| 1187 |
+
model = _get_nn_model_filename(
|
| 1188 |
+
repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
|
| 1189 |
+
filename="model.onnx",
|
| 1190 |
+
subfolder=".",
|
| 1191 |
+
)
|
| 1192 |
+
config = sherpa_onnx.OfflinePunctuationConfig(
|
| 1193 |
+
model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
punct = sherpa_onnx.OfflinePunctuation(config)
|
| 1197 |
+
return punct
|
| 1198 |
+
|
| 1199 |
+
|
| 1200 |
def _get_multi_zh_hans_pre_trained_model(
|
| 1201 |
repo_id: str,
|
| 1202 |
decoding_method: str,
|
requirements.txt
CHANGED
|
@@ -9,4 +9,4 @@ sentencepiece>=0.1.96
|
|
| 9 |
numpy
|
| 10 |
|
| 11 |
huggingface_hub
|
| 12 |
-
sherpa-onnx
|
|
|
|
| 9 |
numpy
|
| 10 |
|
| 11 |
huggingface_hub
|
| 12 |
+
sherpa-onnx>=1.9.19
|