Diarization: uses an rx Observer instead of diart's attach_hooks method
Changed files:
- diarization/diarization_online.py   +98 -35
- timed_objects.py                     +1 -0
- whisper_fastapi_online_server.py     +8 -8
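The gist of the change: instead of registering a plain callback on diart's StreamingInference with attach_hooks, the pipeline output is now consumed by an RxPY observer registered with attach_observers, which also gets explicit error and completion handling. A minimal sketch of that wiring, using only the diart/RxPY calls that appear in this diff; MinimalObserver and its print statements are illustrative, not code from the repo:

    from rx.core import Observer

    class MinimalObserver(Observer):
        """Receives the (annotation, audio) tuples emitted by the diarization pipeline."""

        def on_next(self, value):
            annotation, audio = value                      # same payload the old hook received
            print(f"chunk up to {audio.extent.end:.2f}s, {len(annotation._labels)} speaker(s)")

        def on_error(self, error):
            print(f"diarization stream error: {error}")    # errors now get an explicit handler

        def on_completed(self):
            print("diarization stream completed")

    # before: inference.attach_hooks(self._diar_hook)        # plain callable, results only
    # after:  inference.attach_observers(MinimalObserver())  # rx observer with error/completion hooks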
diarization/diarization_online.py  CHANGED

@@ -2,16 +2,79 @@ import asyncio
 import re
 import threading
 import numpy as np
+import logging
+
 
 from diart import SpeakerDiarization
 from diart.inference import StreamingInference
 from diart.sources import AudioSource
 from timed_objects import SpeakerSegment
+from diart.sources import MicrophoneAudioSource
+from rx.core import Observer
+from typing import Tuple, Any, List
+from pyannote.core import Annotation
+
+logger = logging.getLogger(__name__)
 
 def extract_number(s: str) -> int:
     m = re.search(r'\d+', s)
     return int(m.group()) if m else None
 
+class DiarizationObserver(Observer):
+    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
+
+    def __init__(self):
+        self.speaker_segments = []
+        self.processed_time = 0
+        self.segment_lock = threading.Lock()
+
+    def on_next(self, value: Tuple[Annotation, Any]):
+        annotation, audio = value
+
+        logger.debug("\n--- New Diarization Result ---")
+
+        duration = audio.extent.end - audio.extent.start
+        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
+        logger.debug(f"Audio shape: {audio.data.shape}")
+
+        with self.segment_lock:
+            if audio.extent.end > self.processed_time:
+                self.processed_time = audio.extent.end
+            if annotation and len(annotation._labels) > 0:
+                logger.debug("\nSpeaker segments:")
+                for speaker, label in annotation._labels.items():
+                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
+                        print(f"  {speaker}: {start:.2f}s-{end:.2f}s")
+                        self.speaker_segments.append(SpeakerSegment(
+                            speaker=speaker,
+                            start=start,
+                            end=end
+                        ))
+            else:
+                logger.debug("\nNo speakers detected in this segment")
+
+    def get_segments(self) -> List[SpeakerSegment]:
+        """Get a copy of the current speaker segments."""
+        with self.segment_lock:
+            return self.speaker_segments.copy()
+
+    def clear_old_segments(self, older_than: float = 30.0):
+        """Clear segments older than the specified time."""
+        with self.segment_lock:
+            current_time = self.processed_time
+            self.speaker_segments = [
+                segment for segment in self.speaker_segments
+                if current_time - segment.end < older_than
+            ]
+
+    def on_error(self, error):
+        """Handle an error in the stream."""
+        logger.debug(f"Error in diarization stream: {error}")
+
+    def on_completed(self):
+        """Handle the completion of the stream."""
+        logger.debug("Diarization stream completed")
+
 
 class WebSocketAudioSource(AudioSource):
     """

@@ -34,57 +97,57 @@ class WebSocketAudioSource(AudioSource):
 
     def push_audio(self, chunk: np.ndarray):
        if not self._closed:
-            …
+            new_audio = np.expand_dims(chunk, axis=0)
+            logger.debug('Add new chunk with shape:', new_audio.shape)
+            self.stream.on_next(new_audio)
 
 
 class DiartDiarization:
-    def __init__(self, sample_rate: int):
-        self.…
-        self.…
-        …
+    def __init__(self, sample_rate: int, use_microphone: bool = False):
+        self.pipeline = SpeakerDiarization()
+        self.observer = DiarizationObserver()
+
+        if use_microphone:
+            self.source = MicrophoneAudioSource()
+            self.custom_source = None
+        else:
+            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.source = self.custom_source
+
         self.inference = StreamingInference(
             pipeline=self.pipeline,
             source=self.source,
             do_plot=False,
             show_progress=False,
         )
-        self.inference.attach_hooks(self._diar_hook)
+        self.inference.attach_observers(self.observer)
         asyncio.get_event_loop().run_in_executor(None, self.inference)
 
-    def _diar_hook(self, result):
-        annotation, audio = result
-        if annotation._labels:
-            for speaker, label in annotation._labels.items():
-                start = label.segments_boundaries_[0]
-                end = label.segments_boundaries_[-1]
-                if end > self.processed_time:
-                    self.processed_time = end
-                asyncio.create_task(self.speakers_queue.put(SpeakerSegment(
-                    speaker=speaker,
-                    start=start,
-                    end=end,
-                )))
-        else:
-            dur = audio.extent.end
-            if dur > self.processed_time:
-                self.processed_time = dur
-
     async def diarize(self, pcm_array: np.ndarray):
-        …
+        """
+        Process audio data for diarization.
+        Only used when working with WebSocketAudioSource.
+        """
+        if self.custom_source:
+            self.custom_source.push_audio(pcm_array)
+        self.observer.clear_old_segments()
+        return self.observer.get_segments()
 
     def close(self):
-        …
+        """Close the audio source."""
+        if self.custom_source:
+            self.custom_source.close()
 
-    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> …
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+        """
+        Assign speakers to tokens based on timing overlap with speaker segments.
+        Uses the segments collected by the observer.
+        """
+        segments = self.observer.get_segments()
+
         for token in tokens:
-            for segment in …
+            for segment in segments:
                 if not (segment.end <= token.start or segment.start >= token.end):
                     token.speaker = extract_number(segment.speaker) + 1
                     end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
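For context, a rough usage sketch of the reworked DiartDiarization on the websocket path. The 16 kHz sample rate, the chunk loop, and the pre-existing tokens list are assumptions for illustration, not part of the diff:

    import numpy as np
    from diarization.diarization_online import DiartDiarization

    async def demo(pcm_chunks, tokens):
        diar = DiartDiarization(sample_rate=16000)       # websocket source by default
        end_attributed_speaker = 0.0
        for chunk in pcm_chunks:                         # chunk: 1-D float32 PCM (assumed)
            await diar.diarize(chunk)                    # push audio, prune old segments, return current ones
            end_attributed_speaker = diar.assign_speakers_to_tokens(end_attributed_speaker, tokens)
        diar.close()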
timed_objects.py  CHANGED

@@ -8,6 +8,7 @@ class TimedText:
     text: Optional[str] = ''
     speaker: Optional[int] = -1
     probability: Optional[float] = None
+    is_dummy: Optional[bool] = False
 
 @dataclass
 class ASRToken(TimedText):
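The added field slots into the TimedText dataclass roughly as below; only text, speaker, probability, and the new is_dummy appear in this hunk, so the earlier timing fields are an assumption based on how ASRToken is constructed elsewhere in the diff:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TimedText:
        # start/end timing fields precede these in the real file (not shown in this hunk)
        text: Optional[str] = ''
        speaker: Optional[int] = -1
        probability: Optional[float] = None
        is_dummy: Optional[bool] = False   # new: marks placeholder tokens injected when only diarization runs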
whisper_fastapi_online_server.py  CHANGED

@@ -49,7 +49,7 @@ parser.add_argument(
 parser.add_argument(
     "--confidence-validation",
     type=bool,
-    default=…,
+    default=False,
     help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
 )
 

@@ -110,9 +110,10 @@ class SharedState:
         current_time = time() - self.beg_loop
         dummy_token = ASRToken(
             start=current_time,
-            end=current_time + …,
-            text="",
-            speaker=-1
+            end=current_time + 1,
+            text=".",
+            speaker=-1,
+            is_dummy=True
         )
         self.tokens.append(dummy_token)
 

@@ -275,14 +276,13 @@ async def results_formatter(shared_state, websocket):
         sep = state["sep"]
 
         # If diarization is enabled but no transcription, add dummy tokens periodically
-        if not tokens and not args.transcription and args.diarization:
+        if (not tokens or tokens[-1].is_dummy) and not args.transcription and args.diarization:
             await shared_state.add_dummy_token()
-            …
+            sleep(0.5)
             state = await shared_state.get_current_state()
             tokens = state["tokens"]
-
         # Process tokens to create response
-        previous_speaker = -…
+        previous_speaker = -1
         lines = []
         last_end_diarized = 0
         undiarized_text = []
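Putting the three files together: when transcription is disabled, the server periodically injects a placeholder token and the diarization side fills in its speaker. A toy illustration; the timing values and the 16 kHz rate are made up, only the class and method names come from this commit:

    token = ASRToken(start=12.0, end=13.0, text=".", speaker=-1, is_dummy=True)

    diar = DiartDiarization(sample_rate=16000)
    # ... audio is pushed via diar.diarize(...) while the pipeline runs ...
    # assign_speakers_to_tokens() overlaps the observer's SpeakerSegment list with the token:
    end_attributed_speaker = diar.assign_speakers_to_tokens(0.0, [token])
    print(token.speaker)   # e.g. 1 if "speaker0" was active between 12.0 s and 13.0 s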