Enhance diarization logic to improve speaker attribution: corrects several bugs
Browse files
- whisper_fastapi_online_server.py +19 -10
whisper_fastapi_online_server.py
CHANGED
|
@@ -208,6 +208,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 208 |
"beg": transcription.start,
|
| 209 |
"end": transcription.end,
|
| 210 |
"text": transcription.text,
|
|
|
|
| 211 |
})
|
| 212 |
full_transcription += transcription.text if transcription else ""
|
| 213 |
buffer = online.get_buffer()
|
|
@@ -218,23 +219,32 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 218 |
"beg": time() - beg_loop,
|
| 219 |
"end": time() - beg_loop + 1,
|
| 220 |
"text": '',
|
|
|
|
| 221 |
})
|
| 222 |
sleep(1)
|
| 223 |
buffer = ''
|
| 224 |
|
| 225 |
if args.diarization:
|
| 226 |
await diarization.diarize(pcm_array)
|
| 227 |
-
diarization.assign_speakers_to_chunks(chunk_history)
|
| 228 |
|
| 229 |
|
| 230 |
-
current_speaker =
|
| 231 |
lines = []
|
| 232 |
last_end_diarized = 0
|
|
|
|
| 233 |
for ind, ch in enumerate(chunk_history):
|
| 234 |
-
speaker = ch.get("speaker"
|
| 235 |
-
if
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
lines.append(
|
| 239 |
{
|
| 240 |
"speaker": speaker,
|
|
@@ -245,12 +255,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 245 |
}
|
| 246 |
)
|
| 247 |
current_speaker = speaker
|
| 248 |
-
|
| 249 |
lines[-1]["text"] += ch['text']
|
| 250 |
lines[-1]["end"] = format_time(ch['end'])
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
response = {"lines": lines, "buffer": buffer}
|
| 255 |
await websocket.send_json(response)
|
| 256 |
|
|
|
|
| 208 |
"beg": transcription.start,
|
| 209 |
"end": transcription.end,
|
| 210 |
"text": transcription.text,
|
| 211 |
+
"speaker": -1
|
| 212 |
})
|
| 213 |
full_transcription += transcription.text if transcription else ""
|
| 214 |
buffer = online.get_buffer()
|
|
|
|
| 219 |
"beg": time() - beg_loop,
|
| 220 |
"end": time() - beg_loop + 1,
|
| 221 |
"text": '',
|
| 222 |
+
"speaker": -1
|
| 223 |
})
|
| 224 |
sleep(1)
|
| 225 |
buffer = ''
|
| 226 |
|
| 227 |
if args.diarization:
|
| 228 |
await diarization.diarize(pcm_array)
|
| 229 |
+
end_attributed_speaker = diarization.assign_speakers_to_chunks(chunk_history)
|
| 230 |
|
| 231 |
|
| 232 |
+
current_speaker = -10
|
| 233 |
lines = []
|
| 234 |
last_end_diarized = 0
|
| 235 |
+
previous_speaker = -1
|
| 236 |
for ind, ch in enumerate(chunk_history):
|
| 237 |
+
speaker = ch.get("speaker")
|
| 238 |
+
if args.diarization:
|
| 239 |
+
if speaker == -1 or speaker == 0:
|
| 240 |
+
if ch['end'] < end_attributed_speaker:
|
| 241 |
+
speaker = previous_speaker
|
| 242 |
+
else:
|
| 243 |
+
speaker = 0
|
| 244 |
+
else:
|
| 245 |
+
last_end_diarized = max(ch['end'], last_end_diarized)
|
| 246 |
+
|
| 247 |
+
if speaker != current_speaker:
|
| 248 |
lines.append(
|
| 249 |
{
|
| 250 |
"speaker": speaker,
|
|
|
|
| 255 |
}
|
| 256 |
)
|
| 257 |
current_speaker = speaker
|
| 258 |
+
else:
|
| 259 |
lines[-1]["text"] += ch['text']
|
| 260 |
lines[-1]["end"] = format_time(ch['end'])
|
| 261 |
+
lines[-1]["diff"] = round(ch['end'] - last_end_diarized, 2)
|
| 262 |
+
|
|
|
|
| 263 |
response = {"lines": lines, "buffer": buffer}
|
| 264 |
await websocket.send_json(response)
|
| 265 |
|