qfuxa committed
Commit b81e509 · 1 Parent(s): 9076fea

all text-related classes now share a common TimedText base class

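Since ASRToken, Sentence, and Transcript now all inherit from a single TimedText dataclass, downstream code that only needs the shared start/end/text fields can handle any of them uniformly. A minimal sketch of that idea, using the classes exactly as introduced in timed_objects.py below (the format_timed helper is hypothetical and not part of this commit):

from src.whisper_streaming.timed_objects import TimedText, ASRToken, Sentence, Transcript

def format_timed(item: TimedText) -> str:
    # Hypothetical helper: accepts any TimedText subclass because
    # start, end, and text are inherited from the shared base class.
    return f"[{item.start:.2f}-{item.end:.2f}] {item.text}"

print(format_timed(ASRToken(0.0, 0.4, "hello")))
print(format_timed(Sentence(0.0, 1.2, "hello world.")))
print(format_timed(Transcript(0.0, 1.2, "hello world.")))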
src/whisper_streaming/backends.py CHANGED
@@ -6,7 +6,7 @@ import math
 import torch
 from typing import List
 import numpy as np
-from src.whisper_streaming.asr_token import ASRToken
+from src.whisper_streaming.timed_objects import ASRToken
 
 logger = logging.getLogger(__name__)
 
src/whisper_streaming/online_asr.py CHANGED
@@ -2,37 +2,10 @@ import sys
 import numpy as np
 import logging
 from typing import List, Tuple, Optional
-from src.whisper_streaming.asr_token import ASRToken
+from src.whisper_streaming.timed_objects import ASRToken, Sentence, Transcript
 
 logger = logging.getLogger(__name__)
 
-class Sentence:
-    """
-    A sentence assembled from tokens.
-    """
-    def __init__(self, start: float, end: float, text: str):
-        self.start = start
-        self.end = end
-        self.text = text
-
-    def __repr__(self):
-        return f"Sentence(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
-
-class Transcript:
-    """
-    A transcript that bundles a start time, an end time, and a concatenated text.
-    """
-    def __init__(self, start: Optional[float], end: Optional[float], text: str):
-        self.start = start
-        self.end = end
-        self.text = text
-
-    def __iter__(self):
-        return iter((self.start, self.end, self.text))
-
-    def __repr__(self):
-        return f"Transcript(start={self.start}, end={self.end}, text={self.text!r})"
-
 
 class HypothesisBuffer:
     """
@@ -111,10 +84,6 @@ class HypothesisBuffer:
         while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
             self.committed_in_buffer.pop(0)
 
-    def complete(self) -> List[ASRToken]:
-        """Return any remaining tokens (i.e. the current buffer)."""
-        return self.buffer
-
 
 class OnlineASRProcessor:
     """
@@ -211,7 +180,7 @@ class OnlineASRProcessor:
         self.committed.extend(committed_tokens)
         completed = self.concatenate_tokens(committed_tokens)
         logger.debug(f">>>> COMPLETE NOW: {completed.text}")
-        incomp = self.concatenate_tokens(self.transcript_buffer.complete())
+        incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
         logger.debug(f"INCOMPLETE: {incomp.text}")
 
         if committed_tokens and self.buffer_trimming_way == "sentence":
@@ -318,7 +287,7 @@ class OnlineASRProcessor:
         """
         Flush the remaining transcript when processing ends.
        """
-        remaining_tokens = self.transcript_buffer.complete()
+        remaining_tokens = self.transcript_buffer.buffer
         final_transcript = self.concatenate_tokens(remaining_tokens)
         logger.debug(f"Final non-committed transcript: {final_transcript}")
         self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
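concatenate_tokens itself is unchanged by this commit, but the widened import hints at its role: it folds a list of ASRToken objects into a single Transcript. A rough sketch written from the call sites above, not from the repository's actual implementation (the real method may join texts with a separator and handle offsets or empty input differently):

from typing import List
from src.whisper_streaming.timed_objects import ASRToken, Transcript

def concatenate_tokens_sketch(tokens: List[ASRToken]) -> Transcript:
    # Assumption: span from the first token's start to the last token's end,
    # with the token texts joined in order; empty input yields an empty Transcript.
    if not tokens:
        return Transcript(None, None, "")
    text = "".join(token.text for token in tokens)
    return Transcript(tokens[0].start, tokens[-1].end, text)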
src/whisper_streaming/{asr_token.py → timed_objects.py} RENAMED
@@ -1,15 +1,22 @@
-class ASRToken:
-    """
-    A token (word) from the ASR system with start/end times and text.
-    """
-    def __init__(self, start: float, end: float, text: str):
-        self.start = start
-        self.end = end
-        self.text = text
+from dataclasses import dataclass
+from typing import Optional
 
+@dataclass
+class TimedText:
+    start: Optional[float]
+    end: Optional[float]
+    text: str
+
+@dataclass
+class ASRToken(TimedText):
     def with_offset(self, offset: float) -> "ASRToken":
         """Return a new token with the time offset added."""
         return ASRToken(self.start + offset, self.end + offset, self.text)
 
-    def __repr__(self):
-        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
+@dataclass
+class Sentence(TimedText):
+    pass
+
+@dataclass
+class Transcript(TimedText):
+    pass
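Because @dataclass generates __init__, __repr__, and __eq__ for each subclass, the hand-written constructors and __repr__ methods deleted above are no longer needed. A quick check of the new ASRToken as defined in timed_objects.py:

from src.whisper_streaming.timed_objects import ASRToken

token = ASRToken(0.0, 0.48, "hello")
print(token)                                  # ASRToken(start=0.0, end=0.48, text='hello')
print(token.with_offset(2.0))                 # ASRToken(start=2.0, end=2.48, text='hello')
print(token == ASRToken(0.0, 0.48, "hello"))  # True: __eq__ is generated as well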
whisper_fastapi_online_server.py CHANGED
@@ -201,17 +201,17 @@ async def websocket_endpoint(websocket: WebSocket):
             )
             pcm_buffer = bytearray()
             online.insert_audio_chunk(pcm_array)
-            beg_trans, end_trans, trans = online.process_iter()
+            transcription = online.process_iter()
 
-            if trans:
+            if transcription:
                 chunk_history.append({
-                    "beg": beg_trans,
-                    "end": end_trans,
-                    "text": trans,
+                    "beg": transcription.start,
+                    "end": transcription.end,
+                    "text": transcription.text,
                     "speaker": "0"
                 })
 
-            full_transcription += trans
+            full_transcription += transcription.text
             if args.vac:
                 transcript = online.online.concatenate_tokens(online.online.transcript_buffer.buffer)
             else:
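The handler change follows directly from the class move: the old Transcript defined __iter__, so the result of process_iter could be unpacked as a (start, end, text) tuple; the new dataclass version has no __iter__, so callers read the attributes instead. A small sketch of the two calling styles, assuming process_iter returns a Transcript as the handler code above does:

# Before this commit, Transcript.__iter__ allowed tuple unpacking:
#     beg_trans, end_trans, trans = online.process_iter()
# That unpacking would now raise TypeError, since the dataclass Transcript
# is not iterable. Callers access the fields directly instead:
transcription = online.process_iter()
print(transcription.start, transcription.end, transcription.text)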