clean diart audiosource class
Browse files
src/diarization/diarization_online.py
CHANGED
|
@@ -1,26 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from diart import SpeakerDiarization
|
| 2 |
from diart.inference import StreamingInference
|
| 3 |
from diart.sources import AudioSource
|
| 4 |
-
from rx.subject import Subject
|
| 5 |
-
import threading
|
| 6 |
-
import numpy as np
|
| 7 |
-
import asyncio
|
| 8 |
-
import re
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class WebSocketAudioSource(AudioSource):
|
| 15 |
"""
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
push_audio() is used to inject new PCM chunks.
|
| 19 |
"""
|
| 20 |
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
|
| 21 |
super().__init__(uri, sample_rate)
|
| 22 |
-
self._close_event = threading.Event()
|
| 23 |
self._closed = False
|
|
|
|
| 24 |
|
| 25 |
def read(self):
|
| 26 |
self._close_event.wait()
|
|
@@ -32,99 +33,59 @@ class WebSocketAudioSource(AudioSource):
|
|
| 32 |
self._close_event.set()
|
| 33 |
|
| 34 |
def push_audio(self, chunk: np.ndarray):
|
| 35 |
-
chunk = np.expand_dims(chunk, axis=0)
|
| 36 |
if not self._closed:
|
| 37 |
-
self.stream.on_next(chunk)
|
| 38 |
-
|
| 39 |
|
| 40 |
-
def create_pipeline(SAMPLE_RATE):
    """Assemble a diart streaming diarization pipeline.

    Builds a SpeakerDiarization pipeline fed by a WebSocketAudioSource and
    wraps both in a StreamingInference (no plotting, no progress bar).

    Returns:
        (StreamingInference, WebSocketAudioSource): the inference object to
        run and the source the caller pushes PCM chunks into.
    """
    pipeline = SpeakerDiarization()
    source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
    streaming = StreamingInference(
        pipeline=pipeline,
        source=source,
        do_plot=False,
        show_progress=False,
    )
    return streaming, source
|
| 50 |
|
| 51 |
-
|
| 52 |
-
def
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
pipeline=
|
| 57 |
-
source=
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
"""
|
| 70 |
annotation, audio = result
|
| 71 |
if annotation._labels:
|
| 72 |
-
for speaker in annotation._labels:
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
if
|
| 76 |
-
|
| 77 |
-
asyncio.create_task(
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
| 80 |
else:
|
| 81 |
-
|
| 82 |
-
if
|
| 83 |
-
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
class DiartDiarization:
|
| 91 |
-
def __init__(self, SAMPLE_RATE):
    """Set up online diarization; audio is consumed via diarize()."""
    # Seconds of audio the pipeline has already processed.
    self.processed_time = 0
    # init_diart wires this instance into the diart inference and hands
    # back the inference object, the queue of detected speaker segments,
    # and the audio source this wrapper pushes PCM into.
    # NOTE(review): init_diart is defined elsewhere in this file — its
    # exact side effects on `self` are not visible here; confirm before
    # reordering these statements.
    self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE, self)
    # Speaker segments drained from the queue on the last diarize() call.
    self.segment_speakers = []
|
| 95 |
-
|
| 96 |
-
async def diarize(self, pcm_array):
    """Feed one PCM chunk to the audio source and drain finished segments.

    Pushes *pcm_array* into the websocket source, then replaces
    ``self.segment_speakers`` with whatever speaker segments the diart
    hook has queued since the previous call.
    """
    self.ws_source.push_audio(pcm_array)
    drained = []
    while not self.l_speakers_queue.empty():
        drained.append(await self.l_speakers_queue.get())
    self.segment_speakers = drained
|
| 101 |
|
| 102 |
def close(self):
|
| 103 |
-
self.
|
| 104 |
-
|
| 105 |
-
def assign_speakers_to_chunks(self, chunks):
    """Attach speaker ids to transcript chunks overlapping diarized segments.

    Every chunk first gets a default speaker of -1. A chunk that overlaps
    a diarized segment receives that segment's numeric speaker id + 1.
    Chunks that end inside already-processed audio and still carry -1 are
    downgraded to -2 (processed, but no speaker found).

    Mutates *chunks* in place and returns it.
    """
    for chunk in chunks:
        chunk["speaker"] = chunk.get("speaker", -1)

    for segment in self.segment_speakers:
        for chunk in chunks:
            # Overlap test: segment and chunk intervals intersect.
            overlaps = segment["beg"] < chunk["end"] and segment["end"] > chunk["beg"]
            if overlaps:
                # NOTE(review): assumes speaker labels embed a digit
                # (e.g. "speaker0"); extract_number returns None otherwise,
                # which would make this addition raise — confirm upstream.
                chunk["speaker"] = extract_number(segment["speaker"]) + 1

    if self.processed_time > 0:
        for chunk in chunks:
            if chunk["end"] <= self.processed_time and chunk["speaker"] == -1:
                chunk["speaker"] = -2

    return chunks
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import re
|
| 3 |
+
import threading
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
from diart import SpeakerDiarization
|
| 7 |
from diart.inference import StreamingInference
|
| 8 |
from diart.sources import AudioSource
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
|
| 11 |
+
def extract_number(s: str) -> "int | None":
    """Return the first run of digits in *s* as an int.

    Returns None when *s* contains no digit. (The previous ``-> int``
    annotation was wrong for that case; callers doing arithmetic on the
    result, e.g. ``extract_number(label) + 1``, must handle None.)
    """
    m = re.search(r"\d+", s)
    return int(m.group()) if m else None
|
| 14 |
+
|
| 15 |
|
| 16 |
class WebSocketAudioSource(AudioSource):
|
| 17 |
"""
|
| 18 |
+
Custom AudioSource that blocks in read() until close() is called.
|
| 19 |
+
Use push_audio() to inject PCM chunks.
|
|
|
|
| 20 |
"""
|
| 21 |
def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
    """Create the source; it stays open until close() is called."""
    super().__init__(uri, sample_rate)
    # Event that close() sets so read() can return.
    self._close_event = threading.Event()
    # Once True, push_audio() drops incoming chunks.
    self._closed = False
|
| 25 |
|
| 26 |
def read(self):
    """Block the calling thread until the source has been closed.

    Audio frames are delivered through push_audio(), not here; this
    method only parks the reader until close() sets the event.
    """
    self._close_event.wait()
|
|
|
|
| 33 |
self._close_event.set()
|
| 34 |
|
| 35 |
def push_audio(self, chunk: np.ndarray):
    """Emit *chunk* into the stream as a single-frame batch.

    The PCM array gains a leading batch axis of size 1 before being
    pushed; chunks arriving after the source is closed are dropped.
    """
    if self._closed:
        return
    framed = np.expand_dims(chunk, axis=0)
    self.stream.on_next(framed)
|
|
|
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
class DiartDiarization:
    """Online speaker diarization built on diart's StreamingInference.

    Audio is injected with diarize(); the inference runs in a background
    executor thread and reports speaker segments through a hook, which
    this class drains into ``segment_speakers``.
    """

    def __init__(self, sample_rate: int):
        # Seconds of audio the pipeline has already processed.
        self.processed_time = 0
        # Segments drained from the queue on the last diarize() call.
        self.segment_speakers = []
        self.speakers_queue = asyncio.Queue()
        self.pipeline = SpeakerDiarization()
        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        # Keep a handle on the loop: the hook below runs in the executor
        # thread and must hand results back to this loop thread-safely.
        self._loop = asyncio.get_event_loop()
        # Attach the hook and start inference in the background.
        self.inference.attach_hooks(self._diar_hook)
        self._loop.run_in_executor(None, self.inference)

    def _diar_hook(self, result):
        """Collect (speaker, beg, end) from each inference result.

        Runs in the executor thread, NOT on the event loop, so the queue
        must be fed via run_coroutine_threadsafe — asyncio.create_task
        here would raise ``RuntimeError: no running event loop``.
        """
        annotation, audio = result
        if annotation._labels:
            for speaker, label in annotation._labels.items():
                beg = label.segments_boundaries_[0]
                end = label.segments_boundaries_[-1]
                if end > self.processed_time:
                    self.processed_time = end
                asyncio.run_coroutine_threadsafe(
                    self.speakers_queue.put({"speaker": speaker, "beg": beg, "end": end}),
                    self._loop,
                )
        else:
            # No speakers found: still advance the processed-time marker.
            dur = audio.extent.end
            if dur > self.processed_time:
                self.processed_time = dur

    async def diarize(self, pcm_array: np.ndarray):
        """Push one PCM chunk and drain pending speaker segments."""
        self.source.push_audio(pcm_array)
        self.segment_speakers.clear()
        while not self.speakers_queue.empty():
            self.segment_speakers.append(await self.speakers_queue.get())

    def close(self):
        """Close the underlying audio source, unblocking its read()."""
        self.source.close()

    def assign_speakers_to_chunks(self, chunks: list) -> float:
        """Label chunks that overlap a diarized segment with its speaker.

        Mutates *chunks* in place (sets ``chunk["speaker"]`` to the
        segment's numeric id + 1) and returns the end time of the last
        chunk that received a speaker, or 0 if none did. (The previous
        ``-> list`` annotation was wrong: a time is returned, not a list.)
        """
        end_attributed_speaker = 0
        for chunk in chunks:
            for segment in self.segment_speakers:
                if not (segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]):
                    # NOTE(review): assumes labels embed a digit (e.g.
                    # "speaker0"); extract_number returns None otherwise.
                    chunk["speaker"] = extract_number(segment["speaker"]) + 1
                    end_attributed_speaker = chunk["end"]
        return end_attributed_speaker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|