Tijs Zwinkels
commited on
Commit
·
bccbb15
1
Parent(s):
006de3e
Move creation of OnlineASRProcessor inside the factory method
Browse filesPreventing more code duplication between whisper_online.py and whisper_online_server.py
- whisper_online.py +18 -17
- whisper_online_server.py +1 -16
whisper_online.py
CHANGED
|
@@ -551,7 +551,7 @@ def add_shared_args(parser):
|
|
| 551 |
|
| 552 |
def asr_factory(args, logfile=sys.stderr):
|
| 553 |
"""
|
| 554 |
-
Creates and configures an ASR instance based on the specified backend and arguments.
|
| 555 |
"""
|
| 556 |
backend = args.backend
|
| 557 |
if backend == "openai-api":
|
|
@@ -576,8 +576,23 @@ def asr_factory(args, logfile=sys.stderr):
|
|
| 576 |
print("Setting VAD filter", file=logfile)
|
| 577 |
asr.use_vad()
|
| 578 |
|
| 579 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
## main:
|
| 582 |
|
| 583 |
if __name__ == "__main__":
|
|
@@ -605,22 +620,8 @@ if __name__ == "__main__":
|
|
| 605 |
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
| 606 |
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
| 607 |
|
| 608 |
-
asr = asr_factory(args, logfile=logfile)
|
| 609 |
-
language = args.lan
|
| 610 |
-
if args.task == "translate":
|
| 611 |
-
asr.set_translate_task()
|
| 612 |
-
tgt_language = "en" # Whisper translates into English
|
| 613 |
-
else:
|
| 614 |
-
tgt_language = language # Whisper transcribes in this language
|
| 615 |
-
|
| 616 |
-
|
| 617 |
min_chunk = args.min_chunk_size
|
| 618 |
-
if args.buffer_trimming == "sentence":
|
| 619 |
-
tokenizer = create_tokenizer(tgt_language)
|
| 620 |
-
else:
|
| 621 |
-
tokenizer = None
|
| 622 |
-
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
| 623 |
-
|
| 624 |
|
| 625 |
# load the audio into the LRU cache before we start the timer
|
| 626 |
a = load_audio_chunk(audio_path,0,1)
|
|
|
|
| 551 |
|
| 552 |
def asr_factory(args, logfile=sys.stderr):
|
| 553 |
"""
|
| 554 |
+
Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
|
| 555 |
"""
|
| 556 |
backend = args.backend
|
| 557 |
if backend == "openai-api":
|
|
|
|
| 576 |
print("Setting VAD filter", file=logfile)
|
| 577 |
asr.use_vad()
|
| 578 |
|
| 579 |
+
language = args.lan
|
| 580 |
+
if args.task == "translate":
|
| 581 |
+
asr.set_translate_task()
|
| 582 |
+
tgt_language = "en" # Whisper translates into English
|
| 583 |
+
else:
|
| 584 |
+
tgt_language = language # Whisper transcribes in this language
|
| 585 |
+
|
| 586 |
+
# Create the tokenizer
|
| 587 |
+
if args.buffer_trimming == "sentence":
|
| 588 |
+
tokenizer = create_tokenizer(tgt_language)
|
| 589 |
+
else:
|
| 590 |
+
tokenizer = None
|
| 591 |
|
| 592 |
+
# Create the OnlineASRProcessor
|
| 593 |
+
online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
| 594 |
+
|
| 595 |
+
return asr, online
|
| 596 |
## main:
|
| 597 |
|
| 598 |
if __name__ == "__main__":
|
|
|
|
| 620 |
duration = len(load_audio(audio_path))/SAMPLING_RATE
|
| 621 |
print("Audio duration is: %2.2f seconds" % duration, file=logfile)
|
| 622 |
|
| 623 |
+
asr, online = asr_factory(args, logfile=logfile)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
min_chunk = args.min_chunk_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
|
| 626 |
# load the audio into the LRU cache before we start the timer
|
| 627 |
a = load_audio_chunk(audio_path,0,1)
|
whisper_online_server.py
CHANGED
|
@@ -23,24 +23,9 @@ SAMPLING_RATE = 16000
|
|
| 23 |
|
| 24 |
size = args.model
|
| 25 |
language = args.lan
|
| 26 |
-
|
| 27 |
-
asr = asr_factory(args)
|
| 28 |
-
if args.task == "translate":
|
| 29 |
-
asr.set_translate_task()
|
| 30 |
-
tgt_language = "en"
|
| 31 |
-
else:
|
| 32 |
-
tgt_language = language
|
| 33 |
-
|
| 34 |
min_chunk = args.min_chunk_size
|
| 35 |
|
| 36 |
-
if args.buffer_trimming == "sentence":
|
| 37 |
-
tokenizer = create_tokenizer(tgt_language)
|
| 38 |
-
else:
|
| 39 |
-
tokenizer = None
|
| 40 |
-
online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
demo_audio_path = "cs-maji-2.16k.wav"
|
| 45 |
if os.path.exists(demo_audio_path):
|
| 46 |
# load the audio into the LRU cache before we start the timer
|
|
|
|
| 23 |
|
| 24 |
size = args.model
|
| 25 |
language = args.lan
|
| 26 |
+
asr, online = asr_factory(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
min_chunk = args.min_chunk_size
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
demo_audio_path = "cs-maji-2.16k.wav"
|
| 30 |
if os.path.exists(demo_audio_path):
|
| 31 |
# load the audio into the LRU cache before we start the timer
|