Diarization now works at word level, not chunk level!
src/diarization/diarization_online.py  CHANGED

@@ -81,11 +81,10 @@ class DiartDiarization:
     def close(self):
         self.source.close()
 
-    def …
-    …
-        for chunk in chunks:
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> list:
+        for token in tokens:
             for segment in self.segment_speakers:
-                if not (segment["end"] <= …
-                    …
-                end_attributed_speaker = …
+                if not (segment["end"] <= token.start or segment["beg"] >= token.end):
+                    token.speaker = extract_number(segment["speaker"]) + 1
+                    end_attributed_speaker = max(token.end, end_attributed_speaker)
         return end_attributed_speaker
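The new method attributes speakers per word by testing each token's time span against every diarization segment. A self-contained sketch of that overlap logic (Token is a stand-in for ASRToken, and the "speaker0"-style labels plus the `extract_number` behaviour are assumptions about the diart output):

```python
import re
from dataclasses import dataclass

def extract_number(label: str) -> int:
    """Assumed behaviour: pull the integer out of a diart label like 'speaker0'."""
    match = re.search(r"\d+", label)
    return int(match.group()) if match else -1

@dataclass
class Token:          # stand-in for ASRToken, just for this sketch
    start: float
    end: float
    text: str
    speaker: int = -1

# One diarization segment covering 1.0-2.5 s, attributed to 'speaker0'.
segments = [{"beg": 1.0, "end": 2.5, "speaker": "speaker0"}]
tokens = [Token(0.2, 0.6, "hello"), Token(1.1, 1.4, "world")]

end_attributed_speaker = 0.0
for token in tokens:
    for segment in segments:
        # A token inherits the segment's speaker whenever the two spans overlap at all.
        if not (segment["end"] <= token.start or segment["beg"] >= token.end):
            token.speaker = extract_number(segment["speaker"]) + 1
            end_attributed_speaker = max(token.end, end_attributed_speaker)

print([t.speaker for t in tokens])   # [-1, 1] -- only "world" overlaps the segment
print(end_attributed_speaker)        # 1.4
```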
src/whisper_streaming/online_asr.py  CHANGED

@@ -202,7 +202,7 @@ class OnlineASRProcessor:
         logger.debug(
             f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
         )
-        return
+        return committed_tokens
 
     def chunk_completed_sentence(self):
         """
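The `return committed_tokens` change means `process_iter()` now hands the newly committed words back to the caller instead of returning nothing. A minimal consumer-side sketch under that assumption (the processor and audio source are whatever the caller already has):

```python
from src.whisper_streaming.timed_objects import ASRToken  # import path as used in this commit

def drain(processor, audio_chunks):
    """Feed audio chunks and collect every token the processor commits.

    Relies on the new contract: process_iter() returns the ASRTokens committed
    in that iteration rather than None.
    """
    committed: list[ASRToken] = []
    for chunk in audio_chunks:
        processor.insert_audio_chunk(chunk)
        committed.extend(processor.process_iter())
    return committed
```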
src/whisper_streaming/timed_objects.py  CHANGED

@@ -5,7 +5,8 @@ from typing import Optional
 class TimedText:
     start: Optional[float]
     end: Optional[float]
-    text: str
+    text: Optional[str] = ''
+    speaker: Optional[int] = -1
 
 @dataclass
 class ASRToken(TimedText):
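With defaults on `text` and `speaker`, a token can be constructed from timestamps alone and starts out unattributed until diarization claims it. A small sketch of the resulting behaviour (the `@dataclass` decorator on `TimedText` itself is assumed, and `ASRToken`'s extra fields are omitted):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TimedText:                      # assumed to be a dataclass, like ASRToken below
    start: Optional[float]
    end: Optional[float]
    text: Optional[str] = ''
    speaker: Optional[int] = -1

@dataclass
class ASRToken(TimedText):
    pass                              # extra fields/methods omitted in this sketch

# A placeholder token built from timestamps only: text defaults to '' and
# speaker to -1 (unattributed), which the server loop below relies on.
silence = ASRToken(start=12.0, end=12.5)
print(silence.text, silence.speaker)  # -> '' -1
```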
whisper_fastapi_online_server.py  CHANGED

@@ -11,6 +11,7 @@ from fastapi.responses import HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
 
 from src.whisper_streaming.whisper_online import backend_factory, online_factory, add_shared_args
+from src.whisper_streaming.timed_objects import ASRToken
 
 import math
 import logging

@@ -47,7 +48,7 @@ parser.add_argument(
 parser.add_argument(
     "--diarization",
     type=bool,
-    default=…
+    default=True,
     help="Whether to enable speaker diarization.",
 )
 

@@ -157,7 +158,9 @@ async def websocket_endpoint(websocket: WebSocket):
     full_transcription = ""
     beg = time()
     beg_loop = time()
-    …
+    tokens = []
+    end_attributed_speaker = 0
+    sep = online.asr.sep
 
     while True:
         try:

@@ -177,7 +180,6 @@ async def websocket_endpoint(websocket: WebSocket):
                 logger.warning("FFmpeg read timeout. Restarting...")
                 await restart_ffmpeg()
                 full_transcription = ""
-                chunk_history = []
                 beg = time()
                 continue # Skip processing and read from new process
 

@@ -202,63 +204,53 @@ async def websocket_endpoint(websocket: WebSocket):
             if args.transcription:
                 logger.info(f"{len(online.audio_buffer) / online.SAMPLING_RATE} seconds of audio will be processed by the model.")
                 online.insert_audio_chunk(pcm_array)
-                …
-                …
-                …
-                        "beg": transcription.start,
-                        "end": transcription.end,
-                        "text": transcription.text,
-                        "speaker": -1
-                    })
-                full_transcription += transcription.text if transcription else ""
+                new_tokens = online.process_iter()
+                tokens.extend(new_tokens)
+                full_transcription += sep.join([t.text for t in new_tokens])
                 buffer = online.get_buffer()
                 if buffer in full_transcription: # With VAC, the buffer is not updated until the next chunk is processed
                     buffer = ""
             else:
-                …
-                …
-                …
-                …
-                …
-                    })
-                sleep(1)
+                tokens.append(
+                    ASRToken(
+                        start = time() - beg_loop,
+                        end = time() - beg_loop + 0.5))
+                sleep(0.5)
                 buffer = ''
 
             if args.diarization:
                 await diarization.diarize(pcm_array)
-                end_attributed_speaker = diarization.…
-                …
+                end_attributed_speaker = diarization.assign_speakers_to_tokens(end_attributed_speaker, tokens)
 
-            …
+            previous_speaker = -10
             lines = []
             last_end_diarized = 0
-            …
-            …
-                speaker = ch.get("speaker")
+            for token in tokens:
+                speaker = token.speaker
                 if args.diarization:
                     if speaker == -1 or speaker == 0:
-                        if …
+                        if token.end < end_attributed_speaker:
                             speaker = previous_speaker
                         else:
                             speaker = 0
                     else:
-                        last_end_diarized = max(…
+                        last_end_diarized = max(token.end, last_end_diarized)
 
-                if speaker != …
+                if speaker != previous_speaker:
                     lines.append(
                         {
                             "speaker": speaker,
-                            "text": …
-                            "beg": format_time(…
-                            "end": format_time(…
-                            "diff": round(…
+                            "text": token.text,
+                            "beg": format_time(token.start),
+                            "end": format_time(token.end),
+                            "diff": round(token.end - last_end_diarized, 2)
                         }
                     )
-                    …
+                    previous_speaker = speaker
                 else:
-                    lines[-1]["text"] += …
-                    lines[-1]["end"] = format_time(…
-                    lines[-1]["diff"] = round(…
+                    lines[-1]["text"] += sep + token.text
+                    lines[-1]["end"] = format_time(token.end)
+                    lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
 
             response = {"lines": lines, "buffer": buffer}
             await websocket.send_json(response)
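After diarization, the websocket loop walks the word-level tokens and merges consecutive tokens from the same speaker into one line. A self-contained sketch of that grouping (Token and format_time are simplified stand-ins for ASRToken and the server's helper; the parameter defaults are assumptions for illustration):

```python
from dataclasses import dataclass

@dataclass
class Token:                      # stand-in for ASRToken in this sketch
    start: float
    end: float
    text: str = ''
    speaker: int = -1

def format_time(s: float) -> str:           # simplified stand-in for the server helper
    return f"{s:.2f}s"

def tokens_to_lines(tokens, sep=" ", diarization=True, end_attributed_speaker=0.0):
    """Group word-level tokens into per-speaker lines, mirroring the loop in the diff."""
    previous_speaker = -10
    lines = []
    last_end_diarized = 0
    for token in tokens:
        speaker = token.speaker
        if diarization:
            if speaker == -1 or speaker == 0:
                # Not attributed yet: reuse the previous speaker only if diarization
                # has already processed audio past this token's end.
                speaker = previous_speaker if token.end < end_attributed_speaker else 0
            else:
                last_end_diarized = max(token.end, last_end_diarized)
        if speaker != previous_speaker:
            lines.append({"speaker": speaker, "text": token.text,
                          "beg": format_time(token.start), "end": format_time(token.end),
                          "diff": round(token.end - last_end_diarized, 2)})
            previous_speaker = speaker
        else:
            lines[-1]["text"] += sep + token.text
            lines[-1]["end"] = format_time(token.end)
            lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
    return lines

# Two speakers, word-level attribution: a new line starts on each speaker change.
tokens = [Token(0.0, 0.4, "hello", 1), Token(0.5, 0.9, "there", 1),
          Token(1.0, 1.3, "hi", 2), Token(1.4, 1.9, "again", 2)]
print(tokens_to_lines(tokens, end_attributed_speaker=2.0))
```

With diarization enabled, an unattributed token (speaker -1 or 0) keeps the previous speaker only while diarization has already covered its end time; otherwise it is shown as speaker 0 until attribution catches up.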