qfuxa committed on
Commit
de779de
·
1 Parent(s): b3a32cd

clean diart audiosource class

Browse files
Files changed (1) hide show
  1. src/diarization/diarization_online.py +60 -99
src/diarization/diarization_online.py CHANGED
@@ -1,26 +1,27 @@
 
 
 
 
 
1
  from diart import SpeakerDiarization
2
  from diart.inference import StreamingInference
3
  from diart.sources import AudioSource
4
- from rx.subject import Subject
5
- import threading
6
- import numpy as np
7
- import asyncio
8
- import re
9
 
10
- def extract_number(s):
11
- match = re.search(r'\d+', s)
12
- return int(match.group()) if match else None
 
 
13
 
14
  class WebSocketAudioSource(AudioSource):
15
  """
16
- Simple custom AudioSource that blocks in read()
17
- until close() is called.
18
- push_audio() is used to inject new PCM chunks.
19
  """
20
  def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
21
  super().__init__(uri, sample_rate)
22
- self._close_event = threading.Event()
23
  self._closed = False
 
24
 
25
  def read(self):
26
  self._close_event.wait()
@@ -32,99 +33,59 @@ class WebSocketAudioSource(AudioSource):
32
  self._close_event.set()
33
 
34
  def push_audio(self, chunk: np.ndarray):
35
- chunk = np.expand_dims(chunk, axis=0)
36
  if not self._closed:
37
- self.stream.on_next(chunk)
38
-
39
 
40
- def create_pipeline(SAMPLE_RATE):
41
- diar_pipeline = SpeakerDiarization()
42
- ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
43
- inference = StreamingInference(
44
- pipeline=diar_pipeline,
45
- source=ws_source,
46
- do_plot=False,
47
- show_progress=False,
48
- )
49
- return inference, ws_source
50
 
51
-
52
- def init_diart(SAMPLE_RATE, diar_instance):
53
- diar_pipeline = SpeakerDiarization()
54
- ws_source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
55
- inference = StreamingInference(
56
- pipeline=diar_pipeline,
57
- source=ws_source,
58
- do_plot=False,
59
- show_progress=False,
60
- )
61
-
62
- l_speakers_queue = asyncio.Queue()
63
-
64
- def diar_hook(result):
65
- """
66
- Hook called each time Diart processes a chunk.
67
- result is (annotation, audio).
68
- For each detected speaker segment, push its info to the queue and update processed_time.
69
- """
70
  annotation, audio = result
71
  if annotation._labels:
72
- for speaker in annotation._labels:
73
- segments_beg = annotation._labels[speaker].segments_boundaries_[0]
74
- segments_end = annotation._labels[speaker].segments_boundaries_[-1]
75
- if segments_end > diar_instance.processed_time:
76
- diar_instance.processed_time = segments_end
77
- asyncio.create_task(
78
- l_speakers_queue.put({"speaker": speaker, "beg": segments_beg, "end": segments_end})
79
- )
 
 
80
  else:
81
- audio_duration = audio.extent.end
82
- if audio_duration > diar_instance.processed_time:
83
- diar_instance.processed_time = audio_duration
84
 
85
- inference.attach_hooks(diar_hook)
86
- loop = asyncio.get_event_loop()
87
- diar_future = loop.run_in_executor(None, inference)
88
- return inference, l_speakers_queue, ws_source
89
-
90
- class DiartDiarization:
91
- def __init__(self, SAMPLE_RATE):
92
- self.processed_time = 0
93
- self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE, self)
94
- self.segment_speakers = []
95
-
96
- async def diarize(self, pcm_array):
97
- self.ws_source.push_audio(pcm_array)
98
- self.segment_speakers = []
99
- while not self.l_speakers_queue.empty():
100
- self.segment_speakers.append(await self.l_speakers_queue.get())
101
 
102
  def close(self):
103
- self.ws_source.close()
104
-
105
- def assign_speakers_to_chunks(self, chunks):
106
- """
107
- For each chunk (a dict with keys "beg" and "end"), assign a speaker label.
108
-
109
- - If a chunk overlaps with a detected speaker segment, assign that label.
110
- - If the chunk's end time is within the processed time and no speaker was assigned,
111
- mark it as "No speaker".
112
- - If the chunk's time hasn't been fully processed yet, leave it (or mark as "Processing").
113
- """
114
- for ch in chunks:
115
- ch["speaker"] = ch.get("speaker", -1)
116
-
117
- for segment in self.segment_speakers:
118
- seg_beg = segment["beg"]
119
- seg_end = segment["end"]
120
- speaker = segment["speaker"]
121
- for ch in chunks:
122
- if seg_end <= ch["beg"] or seg_beg >= ch["end"]:
123
- continue
124
- ch["speaker"] = extract_number(speaker) + 1
125
- if self.processed_time > 0:
126
- for ch in chunks:
127
- if ch["end"] <= self.processed_time and ch["speaker"] == -1:
128
- ch["speaker"] = -2
129
-
130
- return chunks
 
1
+ import asyncio
2
+ import re
3
+ import threading
4
+ import numpy as np
5
+
6
  from diart import SpeakerDiarization
7
  from diart.inference import StreamingInference
8
  from diart.sources import AudioSource
 
 
 
 
 
9
 
10
+
11
+ def extract_number(s: str) -> int:
12
+ m = re.search(r'\d+', s)
13
+ return int(m.group()) if m else None
14
+
15
 
16
  class WebSocketAudioSource(AudioSource):
17
  """
18
+ Custom AudioSource that blocks in read() until close() is called.
19
+ Use push_audio() to inject PCM chunks.
 
20
  """
21
  def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
22
  super().__init__(uri, sample_rate)
 
23
  self._closed = False
24
+ self._close_event = threading.Event()
25
 
26
  def read(self):
27
  self._close_event.wait()
 
33
  self._close_event.set()
34
 
35
  def push_audio(self, chunk: np.ndarray):
 
36
  if not self._closed:
37
+ self.stream.on_next(np.expand_dims(chunk, axis=0))
 
38
 
 
 
 
 
 
 
 
 
 
 
39
 
40
class DiartDiarization:
    """Streaming speaker diarization wrapper around diart's StreamingInference.

    PCM audio is injected via :meth:`diarize`; detected speaker segments are
    collected from an asyncio queue that the inference hook fills.
    """

    def __init__(self, sample_rate: int):
        # Timestamp (seconds) up to which the pipeline has processed audio.
        self.processed_time = 0
        # Segments drained from the queue by the most recent diarize() call.
        self.segment_speakers = []
        self.speakers_queue = asyncio.Queue()
        self.pipeline = SpeakerDiarization()
        self.source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
        self.inference = StreamingInference(
            pipeline=self.pipeline,
            source=self.source,
            do_plot=False,
            show_progress=False,
        )
        # Keep the loop so the hook (running in a worker thread) can hand
        # results back to it thread-safely.
        self.loop = asyncio.get_event_loop()
        # Attach the hook and start inference in the background.
        self.inference.attach_hooks(self._diar_hook)
        self.loop.run_in_executor(None, self.inference)

    def _diar_hook(self, result):
        """Hook called by diart for each processed chunk.

        ``result`` is an ``(annotation, audio)`` pair.  Each detected speaker
        segment is pushed onto ``speakers_queue`` and ``processed_time`` is
        advanced.  This runs in the executor's worker thread, where there is
        no running event loop, so ``asyncio.create_task()`` would raise
        RuntimeError; the put is marshalled onto the captured loop instead
        (``put_nowait`` cannot fail on an unbounded Queue).
        """
        annotation, audio = result
        if annotation._labels:
            for speaker, label in annotation._labels.items():
                beg = label.segments_boundaries_[0]
                end = label.segments_boundaries_[-1]
                if end > self.processed_time:
                    self.processed_time = end
                self.loop.call_soon_threadsafe(
                    self.speakers_queue.put_nowait,
                    {"speaker": speaker, "beg": beg, "end": end},
                )
        else:
            # No speakers detected: still advance the processed-time marker.
            dur = audio.extent.end
            if dur > self.processed_time:
                self.processed_time = dur

    async def diarize(self, pcm_array: np.ndarray):
        """Feed a PCM chunk to the pipeline and drain pending speaker segments."""
        self.source.push_audio(pcm_array)
        self.segment_speakers.clear()
        while not self.speakers_queue.empty():
            self.segment_speakers.append(await self.speakers_queue.get())

    def close(self):
        """Close the audio source, unblocking the background inference."""
        self.source.close()

    def assign_speakers_to_chunks(self, chunks: list) -> float:
        """Label chunks that overlap a detected speaker segment.

        Each chunk is a dict with "beg"/"end" times; overlapping chunks get a
        1-based "speaker" key.  Returns the end time of the last chunk that
        received a speaker (the previous ``-> list`` annotation was wrong:
        callers receive a timestamp, not the chunk list).
        """
        end_attributed_speaker = 0
        for chunk in chunks:
            for segment in self.segment_speakers:
                # Skip segment/chunk pairs that do not overlap in time.
                if segment["end"] <= chunk["beg"] or segment["beg"] >= chunk["end"]:
                    continue
                # diart labels look like "speaker0"; expose them 1-based.
                chunk["speaker"] = extract_number(segment["speaker"]) + 1
                end_attributed_speaker = chunk["end"]
        return end_attributed_speaker