Refactor DiartDiarization initialization and streamline WebSocket audio processing
Browse files- audio.py +82 -54
- diarization/diarization_online.py +1 -1
- whisper_fastapi_online_server.py +6 -74
audio.py
CHANGED
|
@@ -1,25 +1,15 @@
|
|
| 1 |
-
import io
|
| 2 |
-
import argparse
|
| 3 |
import asyncio
|
| 4 |
import numpy as np
|
| 5 |
import ffmpeg
|
| 6 |
from time import time, sleep
|
| 7 |
-
from contextlib import asynccontextmanager
|
| 8 |
|
| 9 |
-
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 10 |
-
from fastapi.responses import HTMLResponse
|
| 11 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
-
|
| 13 |
-
from whisper_streaming_custom.whisper_online import backend_factory, online_factory, add_shared_args, warmup_asr
|
| 14 |
-
from timed_objects import ASRToken
|
| 15 |
|
|
|
|
| 16 |
import math
|
| 17 |
import logging
|
| 18 |
-
from datetime import timedelta
|
| 19 |
import traceback
|
| 20 |
from state import SharedState
|
| 21 |
from formatters import format_time
|
| 22 |
-
from parse_args import parse_args
|
| 23 |
|
| 24 |
|
| 25 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
@@ -27,7 +17,6 @@ logging.getLogger().setLevel(logging.WARNING)
|
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
logger.setLevel(logging.DEBUG)
|
| 29 |
|
| 30 |
-
|
| 31 |
class AudioProcessor:
|
| 32 |
|
| 33 |
def __init__(self, args, asr, tokenizer):
|
|
@@ -38,9 +27,22 @@ class AudioProcessor:
|
|
| 38 |
self.bytes_per_sample = 2
|
| 39 |
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
| 40 |
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
|
|
|
|
|
|
| 41 |
self.shared_state = SharedState()
|
| 42 |
self.asr = asr
|
| 43 |
self.tokenizer = tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def convert_pcm_to_float(self, pcm_buffer):
|
| 46 |
"""
|
|
@@ -70,26 +72,17 @@ class AudioProcessor:
|
|
| 70 |
)
|
| 71 |
return process
|
| 72 |
|
| 73 |
-
async def restart_ffmpeg(self
|
| 74 |
-
if ffmpeg_process:
|
| 75 |
try:
|
| 76 |
-
ffmpeg_process.kill()
|
| 77 |
-
await asyncio.get_event_loop().run_in_executor(None, ffmpeg_process.wait)
|
| 78 |
except Exception as e:
|
| 79 |
logger.warning(f"Error killing FFmpeg process: {e}")
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
if self.args.transcription:
|
| 84 |
-
online = online_factory(self.args, self.asr, self.tokenizer)
|
| 85 |
-
|
| 86 |
-
await self.shared_state.reset()
|
| 87 |
-
logger.info("FFmpeg process started.")
|
| 88 |
-
return ffmpeg_process, online, pcm_buffer
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
-
async def ffmpeg_stdout_reader(self
|
| 93 |
loop = asyncio.get_event_loop()
|
| 94 |
beg = time()
|
| 95 |
|
|
@@ -103,36 +96,36 @@ class AudioProcessor:
|
|
| 103 |
try:
|
| 104 |
chunk = await asyncio.wait_for(
|
| 105 |
loop.run_in_executor(
|
| 106 |
-
None, ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
| 107 |
),
|
| 108 |
timeout=15.0
|
| 109 |
)
|
| 110 |
except asyncio.TimeoutError:
|
| 111 |
logger.warning("FFmpeg read timeout. Restarting...")
|
| 112 |
-
|
| 113 |
beg = time()
|
| 114 |
continue # Skip processing and read from new process
|
| 115 |
|
| 116 |
if not chunk:
|
| 117 |
logger.info("FFmpeg stdout closed.")
|
| 118 |
break
|
| 119 |
-
pcm_buffer.extend(chunk)
|
| 120 |
|
| 121 |
-
if self.args.diarization and diarization_queue:
|
| 122 |
-
await diarization_queue.put(self.convert_pcm_to_float(pcm_buffer).copy())
|
| 123 |
|
| 124 |
-
if len(pcm_buffer) >= self.bytes_per_sec:
|
| 125 |
-
if len(pcm_buffer) > self.max_bytes_per_sec:
|
| 126 |
logger.warning(
|
| 127 |
-
f"""Audio buffer is too large: {len(pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
| 128 |
The model probably struggles to keep up. Consider using a smaller model.
|
| 129 |
""")
|
| 130 |
|
| 131 |
-
pcm_array = self.convert_pcm_to_float(pcm_buffer[:self.max_bytes_per_sec])
|
| 132 |
-
pcm_buffer = pcm_buffer[self.max_bytes_per_sec:]
|
| 133 |
|
| 134 |
-
if self.args.transcription and transcription_queue:
|
| 135 |
-
await transcription_queue.put(pcm_array.copy())
|
| 136 |
|
| 137 |
|
| 138 |
if not self.args.transcription and not self.args.diarization:
|
|
@@ -144,27 +137,24 @@ class AudioProcessor:
|
|
| 144 |
break
|
| 145 |
logger.info("Exiting ffmpeg_stdout_reader...")
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
async def transcription_processor(self, pcm_queue, online):
|
| 151 |
full_transcription = ""
|
| 152 |
-
sep = online.asr.sep
|
| 153 |
|
| 154 |
while True:
|
| 155 |
try:
|
| 156 |
-
pcm_array = await
|
| 157 |
|
| 158 |
-
logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
| 159 |
|
| 160 |
# Process transcription
|
| 161 |
-
online.insert_audio_chunk(pcm_array)
|
| 162 |
-
new_tokens = online.process_iter()
|
| 163 |
|
| 164 |
if new_tokens:
|
| 165 |
full_transcription += sep.join([t.text for t in new_tokens])
|
| 166 |
|
| 167 |
-
_buffer = online.get_buffer()
|
| 168 |
buffer = _buffer.text
|
| 169 |
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
| 170 |
|
|
@@ -178,14 +168,15 @@ class AudioProcessor:
|
|
| 178 |
logger.warning(f"Exception in transcription_processor: {e}")
|
| 179 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 180 |
finally:
|
| 181 |
-
|
| 182 |
|
| 183 |
-
|
|
|
|
| 184 |
buffer_diarization = ""
|
| 185 |
|
| 186 |
while True:
|
| 187 |
try:
|
| 188 |
-
pcm_array = await
|
| 189 |
|
| 190 |
# Process diarization
|
| 191 |
await diarization_obj.diarize(pcm_array)
|
|
@@ -205,7 +196,7 @@ class AudioProcessor:
|
|
| 205 |
logger.warning(f"Exception in diarization_processor: {e}")
|
| 206 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 207 |
finally:
|
| 208 |
-
|
| 209 |
|
| 210 |
async def results_formatter(self, websocket):
|
| 211 |
while True:
|
|
@@ -304,3 +295,40 @@ class AudioProcessor:
|
|
| 304 |
logger.warning(f"Exception in results_formatter: {e}")
|
| 305 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 306 |
await asyncio.sleep(0.5) # Back off on error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import numpy as np
|
| 3 |
import ffmpeg
|
| 4 |
from time import time, sleep
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
from whisper_streaming_custom.whisper_online import online_factory
|
| 8 |
import math
|
| 9 |
import logging
|
|
|
|
| 10 |
import traceback
|
| 11 |
from state import SharedState
|
| 12 |
from formatters import format_time
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
|
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
logger.setLevel(logging.DEBUG)
|
| 19 |
|
|
|
|
| 20 |
class AudioProcessor:
|
| 21 |
|
| 22 |
def __init__(self, args, asr, tokenizer):
|
|
|
|
| 27 |
self.bytes_per_sample = 2
|
| 28 |
self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
|
| 29 |
self.max_bytes_per_sec = 32000 * 5 # 5 seconds of audio at 32 kHz
|
| 30 |
+
|
| 31 |
+
|
| 32 |
self.shared_state = SharedState()
|
| 33 |
self.asr = asr
|
| 34 |
self.tokenizer = tokenizer
|
| 35 |
+
|
| 36 |
+
self.ffmpeg_process = self.start_ffmpeg_decoder()
|
| 37 |
+
|
| 38 |
+
self.transcription_queue = asyncio.Queue() if self.args.transcription else None
|
| 39 |
+
self.diarization_queue = asyncio.Queue() if self.args.diarization else None
|
| 40 |
+
|
| 41 |
+
self.pcm_buffer = bytearray()
|
| 42 |
+
if self.args.transcription:
|
| 43 |
+
self.online = online_factory(self.args, self.asr, self.tokenizer)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
|
| 47 |
def convert_pcm_to_float(self, pcm_buffer):
|
| 48 |
"""
|
|
|
|
| 72 |
)
|
| 73 |
return process
|
| 74 |
|
| 75 |
+
async def restart_ffmpeg(self):
|
| 76 |
+
if self.ffmpeg_process:
|
| 77 |
try:
|
| 78 |
+
self.ffmpeg_process.kill()
|
| 79 |
+
await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
|
| 80 |
except Exception as e:
|
| 81 |
logger.warning(f"Error killing FFmpeg process: {e}")
|
| 82 |
+
self.ffmpeg_process = await self.start_ffmpeg_decoder()
|
| 83 |
+
self.pcm_buffer = bytearray()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
+
async def ffmpeg_stdout_reader(self):
|
| 86 |
loop = asyncio.get_event_loop()
|
| 87 |
beg = time()
|
| 88 |
|
|
|
|
| 96 |
try:
|
| 97 |
chunk = await asyncio.wait_for(
|
| 98 |
loop.run_in_executor(
|
| 99 |
+
None, self.ffmpeg_process.stdout.read, ffmpeg_buffer_from_duration
|
| 100 |
),
|
| 101 |
timeout=15.0
|
| 102 |
)
|
| 103 |
except asyncio.TimeoutError:
|
| 104 |
logger.warning("FFmpeg read timeout. Restarting...")
|
| 105 |
+
await self.restart_ffmpeg()
|
| 106 |
beg = time()
|
| 107 |
continue # Skip processing and read from new process
|
| 108 |
|
| 109 |
if not chunk:
|
| 110 |
logger.info("FFmpeg stdout closed.")
|
| 111 |
break
|
| 112 |
+
self.pcm_buffer.extend(chunk)
|
| 113 |
|
| 114 |
+
if self.args.diarization and self.diarization_queue:
|
| 115 |
+
await self.diarization_queue.put(self.convert_pcm_to_float(self.pcm_buffer).copy())
|
| 116 |
|
| 117 |
+
if len(self.pcm_buffer) >= self.bytes_per_sec:
|
| 118 |
+
if len(self.pcm_buffer) > self.max_bytes_per_sec:
|
| 119 |
logger.warning(
|
| 120 |
+
f"""Audio buffer is too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f} seconds.
|
| 121 |
The model probably struggles to keep up. Consider using a smaller model.
|
| 122 |
""")
|
| 123 |
|
| 124 |
+
pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
|
| 125 |
+
self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
|
| 126 |
|
| 127 |
+
if self.args.transcription and self.transcription_queue:
|
| 128 |
+
await self.transcription_queue.put(pcm_array.copy())
|
| 129 |
|
| 130 |
|
| 131 |
if not self.args.transcription and not self.args.diarization:
|
|
|
|
| 137 |
break
|
| 138 |
logger.info("Exiting ffmpeg_stdout_reader...")
|
| 139 |
|
| 140 |
+
async def transcription_processor(self):
|
|
|
|
|
|
|
|
|
|
| 141 |
full_transcription = ""
|
| 142 |
+
sep = self.online.asr.sep
|
| 143 |
|
| 144 |
while True:
|
| 145 |
try:
|
| 146 |
+
pcm_array = await self.transcription_queue.get()
|
| 147 |
|
| 148 |
+
logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio will be processed by the model.")
|
| 149 |
|
| 150 |
# Process transcription
|
| 151 |
+
self.online.insert_audio_chunk(pcm_array)
|
| 152 |
+
new_tokens = self.online.process_iter()
|
| 153 |
|
| 154 |
if new_tokens:
|
| 155 |
full_transcription += sep.join([t.text for t in new_tokens])
|
| 156 |
|
| 157 |
+
_buffer = self.online.get_buffer()
|
| 158 |
buffer = _buffer.text
|
| 159 |
end_buffer = _buffer.end if _buffer.end else (new_tokens[-1].end if new_tokens else 0)
|
| 160 |
|
|
|
|
| 168 |
logger.warning(f"Exception in transcription_processor: {e}")
|
| 169 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 170 |
finally:
|
| 171 |
+
self.transcription_queue.task_done()
|
| 172 |
|
| 173 |
+
|
| 174 |
+
async def diarization_processor(self, diarization_obj):
|
| 175 |
buffer_diarization = ""
|
| 176 |
|
| 177 |
while True:
|
| 178 |
try:
|
| 179 |
+
pcm_array = await self.diarization_queue.get()
|
| 180 |
|
| 181 |
# Process diarization
|
| 182 |
await diarization_obj.diarize(pcm_array)
|
|
|
|
| 196 |
logger.warning(f"Exception in diarization_processor: {e}")
|
| 197 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 198 |
finally:
|
| 199 |
+
self.diarization_queue.task_done()
|
| 200 |
|
| 201 |
async def results_formatter(self, websocket):
|
| 202 |
while True:
|
|
|
|
| 295 |
logger.warning(f"Exception in results_formatter: {e}")
|
| 296 |
logger.warning(f"Traceback: {traceback.format_exc()}")
|
| 297 |
await asyncio.sleep(0.5) # Back off on error
|
| 298 |
+
|
| 299 |
+
async def create_tasks(self, websocket, diarization):
|
| 300 |
+
tasks = []
|
| 301 |
+
if self.args.transcription and self.online:
|
| 302 |
+
tasks.append(asyncio.create_task(self.transcription_processor()))
|
| 303 |
+
if self.args.diarization and diarization:
|
| 304 |
+
tasks.append(asyncio.create_task(self.diarization_processor(diarization)))
|
| 305 |
+
formatter_task = asyncio.create_task(self.results_formatter(websocket))
|
| 306 |
+
tasks.append(formatter_task)
|
| 307 |
+
stdout_reader_task = asyncio.create_task(self.ffmpeg_stdout_reader())
|
| 308 |
+
tasks.append(stdout_reader_task)
|
| 309 |
+
self.tasks = tasks
|
| 310 |
+
self.diarization = diarization
|
| 311 |
+
|
| 312 |
+
async def cleanup(self):
|
| 313 |
+
for task in self.tasks:
|
| 314 |
+
task.cancel()
|
| 315 |
+
try:
|
| 316 |
+
await asyncio.gather(*self.tasks, return_exceptions=True)
|
| 317 |
+
self.ffmpeg_process.stdin.close()
|
| 318 |
+
self.ffmpeg_process.wait()
|
| 319 |
+
except Exception as e:
|
| 320 |
+
logger.warning(f"Error during cleanup: {e}")
|
| 321 |
+
if self.args.diarization and self.diarization:
|
| 322 |
+
self.diarization.close()
|
| 323 |
+
|
| 324 |
+
async def process_audio(self, message):
|
| 325 |
+
try:
|
| 326 |
+
self.ffmpeg_process.stdin.write(message)
|
| 327 |
+
self.ffmpeg_process.stdin.flush()
|
| 328 |
+
except (BrokenPipeError, AttributeError) as e:
|
| 329 |
+
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
| 330 |
+
await self.restart_ffmpeg()
|
| 331 |
+
self.ffmpeg_process.stdin.write(message)
|
| 332 |
+
self.ffmpeg_process.stdin.flush()
|
| 333 |
+
|
| 334 |
+
|
diarization/diarization_online.py
CHANGED
|
@@ -103,7 +103,7 @@ class WebSocketAudioSource(AudioSource):
|
|
| 103 |
|
| 104 |
|
| 105 |
class DiartDiarization:
|
| 106 |
-
def __init__(self, sample_rate: int, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
| 107 |
self.pipeline = SpeakerDiarization(config=config)
|
| 108 |
self.observer = DiarizationObserver()
|
| 109 |
|
|
|
|
| 103 |
|
| 104 |
|
| 105 |
class DiartDiarization:
|
| 106 |
+
def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
|
| 107 |
self.pipeline = SpeakerDiarization(config=config)
|
| 108 |
self.observer = DiarizationObserver()
|
| 109 |
|
whisper_fastapi_online_server.py
CHANGED
|
@@ -1,24 +1,11 @@
|
|
| 1 |
-
import io
|
| 2 |
-
import argparse
|
| 3 |
-
import asyncio
|
| 4 |
-
import numpy as np
|
| 5 |
-
import ffmpeg
|
| 6 |
-
from time import time, sleep
|
| 7 |
from contextlib import asynccontextmanager
|
| 8 |
|
| 9 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 10 |
from fastapi.responses import HTMLResponse
|
| 11 |
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
|
| 13 |
-
from whisper_streaming_custom.whisper_online import backend_factory,
|
| 14 |
-
from timed_objects import ASRToken
|
| 15 |
-
|
| 16 |
-
import math
|
| 17 |
import logging
|
| 18 |
-
from datetime import timedelta
|
| 19 |
-
import traceback
|
| 20 |
-
from state import SharedState
|
| 21 |
-
from formatters import format_time
|
| 22 |
from parse_args import parse_args
|
| 23 |
from audio import AudioProcessor
|
| 24 |
|
|
@@ -27,19 +14,8 @@ logging.getLogger().setLevel(logging.WARNING)
|
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
logger.setLevel(logging.DEBUG)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
args = parse_args()
|
| 33 |
|
| 34 |
-
SAMPLE_RATE = 16000
|
| 35 |
-
# CHANNELS = 1
|
| 36 |
-
# SAMPLES_PER_SEC = int(SAMPLE_RATE * args.min_chunk_size)
|
| 37 |
-
# BYTES_PER_SAMPLE = 2 # s16le = 2 bytes per sample
|
| 38 |
-
# BYTES_PER_SEC = SAMPLES_PER_SEC * BYTES_PER_SAMPLE
|
| 39 |
-
# MAX_BYTES_PER_SEC = 32000 * 5 # 5 seconds of audio at 32 kHz
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
##### LOAD APP #####
|
| 43 |
|
| 44 |
@asynccontextmanager
|
| 45 |
async def lifespan(app: FastAPI):
|
|
@@ -52,7 +28,7 @@ async def lifespan(app: FastAPI):
|
|
| 52 |
|
| 53 |
if args.diarization:
|
| 54 |
from diarization.diarization_online import DiartDiarization
|
| 55 |
-
diarization = DiartDiarization(
|
| 56 |
else :
|
| 57 |
diarization = None
|
| 58 |
yield
|
|
@@ -75,66 +51,22 @@ with open("web/live_transcription.html", "r", encoding="utf-8") as f:
|
|
| 75 |
async def get():
|
| 76 |
return HTMLResponse(html)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
@app.websocket("/asr")
|
| 86 |
async def websocket_endpoint(websocket: WebSocket):
|
| 87 |
audio_processor = AudioProcessor(args, asr, tokenizer)
|
| 88 |
|
| 89 |
await websocket.accept()
|
| 90 |
logger.info("WebSocket connection opened.")
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
pcm_buffer = bytearray()
|
| 94 |
-
|
| 95 |
-
transcription_queue = asyncio.Queue() if args.transcription else None
|
| 96 |
-
diarization_queue = asyncio.Queue() if args.diarization else None
|
| 97 |
-
|
| 98 |
-
online = None
|
| 99 |
-
|
| 100 |
-
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
| 101 |
-
tasks = []
|
| 102 |
-
if args.transcription and online:
|
| 103 |
-
tasks.append(asyncio.create_task(
|
| 104 |
-
audio_processor.transcription_processor(transcription_queue, online)))
|
| 105 |
-
if args.diarization and diarization:
|
| 106 |
-
tasks.append(asyncio.create_task(
|
| 107 |
-
audio_processor.diarization_processor(diarization_queue, diarization)))
|
| 108 |
-
formatter_task = asyncio.create_task(audio_processor.results_formatter(websocket))
|
| 109 |
-
tasks.append(formatter_task)
|
| 110 |
-
stdout_reader_task = asyncio.create_task(audio_processor.ffmpeg_stdout_reader(ffmpeg_process, pcm_buffer, diarization_queue, transcription_queue))
|
| 111 |
-
tasks.append(stdout_reader_task)
|
| 112 |
-
|
| 113 |
try:
|
| 114 |
while True:
|
| 115 |
-
# Receive incoming WebM audio chunks from the client
|
| 116 |
message = await websocket.receive_bytes()
|
| 117 |
-
|
| 118 |
-
ffmpeg_process.stdin.write(message)
|
| 119 |
-
ffmpeg_process.stdin.flush()
|
| 120 |
-
except (BrokenPipeError, AttributeError) as e:
|
| 121 |
-
logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
|
| 122 |
-
ffmpeg_process, online, pcm_buffer = await audio_processor.restart_ffmpeg(ffmpeg_process, online, pcm_buffer)
|
| 123 |
-
ffmpeg_process.stdin.write(message)
|
| 124 |
-
ffmpeg_process.stdin.flush()
|
| 125 |
except WebSocketDisconnect:
|
| 126 |
logger.warning("WebSocket disconnected.")
|
| 127 |
finally:
|
| 128 |
-
|
| 129 |
-
task.cancel()
|
| 130 |
-
try:
|
| 131 |
-
await asyncio.gather(*tasks, return_exceptions=True)
|
| 132 |
-
ffmpeg_process.stdin.close()
|
| 133 |
-
ffmpeg_process.wait()
|
| 134 |
-
except Exception as e:
|
| 135 |
-
logger.warning(f"Error during cleanup: {e}")
|
| 136 |
-
if args.diarization and diarization:
|
| 137 |
-
diarization.close()
|
| 138 |
logger.info("WebSocket endpoint cleaned up.")
|
| 139 |
|
| 140 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from contextlib import asynccontextmanager
|
| 2 |
|
| 3 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 4 |
from fastapi.responses import HTMLResponse
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
|
| 7 |
+
from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
|
|
|
|
|
|
|
|
|
|
| 8 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from parse_args import parse_args
|
| 10 |
from audio import AudioProcessor
|
| 11 |
|
|
|
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
logger.setLevel(logging.DEBUG)
|
| 16 |
|
|
|
|
|
|
|
| 17 |
args = parse_args()
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
@asynccontextmanager
|
| 21 |
async def lifespan(app: FastAPI):
|
|
|
|
| 28 |
|
| 29 |
if args.diarization:
|
| 30 |
from diarization.diarization_online import DiartDiarization
|
| 31 |
+
diarization = DiartDiarization()
|
| 32 |
else :
|
| 33 |
diarization = None
|
| 34 |
yield
|
|
|
|
| 51 |
async def get():
|
| 52 |
return HTMLResponse(html)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
@app.websocket("/asr")
|
| 55 |
async def websocket_endpoint(websocket: WebSocket):
|
| 56 |
audio_processor = AudioProcessor(args, asr, tokenizer)
|
| 57 |
|
| 58 |
await websocket.accept()
|
| 59 |
logger.info("WebSocket connection opened.")
|
| 60 |
+
|
| 61 |
+
await audio_processor.create_tasks(websocket, diarization)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
try:
|
| 63 |
while True:
|
|
|
|
| 64 |
message = await websocket.receive_bytes()
|
| 65 |
+
audio_processor.process_audio(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
except WebSocketDisconnect:
|
| 67 |
logger.warning("WebSocket disconnected.")
|
| 68 |
finally:
|
| 69 |
+
audio_processor.cleanup()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
logger.info("WebSocket endpoint cleaned up.")
|
| 71 |
|
| 72 |
if __name__ == "__main__":
|