qfuxa committed
Commit 3ac7218 · 1 Parent(s): 6b5070a

adapt backend for the new classes

Files changed (1)
  1. src/whisper_streaming/backends.py +80 -158
src/whisper_streaming/backends.py CHANGED
@@ -1,45 +1,47 @@
 import sys
 import logging
-
 import io
 import soundfile as sf
 import math
 import torch
+from typing import List
+import numpy as np
+from src.whisper_streaming.asr_token import ASRToken

 logger = logging.getLogger(__name__)

 class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
-    # "" for faster-whisper because it emits the spaces when neeeded)
+    # "" for faster-whisper because it emits the spaces when needed)

-    def __init__(
-        self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr
-    ):
+    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
         self.logfile = logfile
-
         self.transcribe_kargs = {}
         if lan == "auto":
             self.original_language = None
         else:
             self.original_language = lan
-
         self.model = self.load_model(modelsize, cache_dir, model_dir)

-    def load_model(self, modelsize, cache_dir):
-        raise NotImplemented("must be implemented in the child class")
+    def with_offset(self, offset: float) -> ASRToken:
+        # This method is kept for compatibility (typically you will use ASRToken.with_offset)
+        return ASRToken(self.start + offset, self.end + offset, self.text)
+
+    def __repr__(self):
+        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
+
+    def load_model(self, modelsize, cache_dir, model_dir):
+        raise NotImplementedError("must be implemented in the child class")

     def transcribe(self, audio, init_prompt=""):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")

     def use_vad(self):
-        raise NotImplemented("must be implemented in the child class")
+        raise NotImplementedError("must be implemented in the child class")


 class WhisperTimestampedASR(ASRBase):
-    """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
-    On the other hand, the installation for GPU could be easier.
-    """
-
+    """Uses whisper_timestamped as the backend."""
     sep = " "

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
@@ -64,17 +66,19 @@ class WhisperTimestampedASR(ASRBase):
         )
         return result

-    def ts_words(self, r):
-        # return: transcribe result object to [(beg,end,"word1"), ...]
-        o = []
-        for s in r["segments"]:
-            for w in s["words"]:
-                t = (w["start"], w["end"], w["text"])
-                o.append(t)
-        return o
+    def ts_words(self, r) -> List[ASRToken]:
+        """
+        Converts the whisper_timestamped result to a list of ASRToken objects.
+        """
+        tokens = []
+        for segment in r["segments"]:
+            for word in segment["words"]:
+                token = ASRToken(word["start"], word["end"], word["text"])
+                tokens.append(token)
+        return tokens

-    def segments_end_ts(self, res):
-        return [s["end"] for s in res["segments"]]
+    def segments_end_ts(self, res) -> List[float]:
+        return [segment["end"] for segment in res["segments"]]

     def use_vad(self):
         self.transcribe_kargs["vad"] = True
@@ -84,24 +88,20 @@ class WhisperTimestampedASR(ASRBase):


 class FasterWhisperASR(ASRBase):
-    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
-
+    """Uses faster-whisper as the backend."""
     sep = ""

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
         from faster_whisper import WhisperModel

-        # logging.getLogger("faster_whisper").setLevel(logger.level)
         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
+                         f"modelsize and cache_dir parameters are not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
         else:
-            raise ValueError("modelsize or model_dir parameter must be set")
-
+            raise ValueError("Either modelsize or model_dir must be set")
         device = "cuda" if torch.cuda.is_available() else "cpu"
         compute_type = "float16" if device == "cuda" else "float32"

@@ -111,19 +111,9 @@ class FasterWhisperASR(ASRBase):
             compute_type=compute_type,
             download_root=cache_dir,
         )
-
-        # or run on GPU with INT8
-        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
-        # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
-
-        # or run on CPU with INT8
-        # tested: works, but slow, appx 10-times than cuda FP16
-        # model = WhisperModel(modelsize, device="cpu", compute_type="int8")  #, download_root="faster-disk-cache-dir/")
         return model

-    def transcribe(self, audio, init_prompt=""):
-
-        # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
+    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
         segments, info = self.model.transcribe(
             audio,
             language=self.original_language,
@@ -133,24 +123,20 @@ class FasterWhisperASR(ASRBase):
             condition_on_previous_text=True,
             **self.transcribe_kargs,
         )
-        # print(info)  # info contains language detection result
-
         return list(segments)

-    def ts_words(self, segments):
-        o = []
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
         for segment in segments:
+            if segment.no_speech_prob > 0.9:
+                continue
             for word in segment.words:
-                if segment.no_speech_prob > 0.9:
-                    continue
-                # not stripping the spaces -- should not be merged with them!
-                w = word.word
-                t = (word.start, word.end, w)
-                o.append(t)
-        return o
+                token = ASRToken(word.start, word.end, word.word)
+                tokens.append(token)
+        return tokens

-    def segments_end_ts(self, res):
-        return [s.end for s in res]
+    def segments_end_ts(self, segments) -> List[float]:
+        return [segment.end for segment in segments]

     def use_vad(self):
         self.transcribe_kargs["vad_filter"] = True
@@ -161,60 +147,29 @@

 class MLXWhisper(ASRBase):
     """
-    Uses MPX Whisper library as the backend, optimized for Apple Silicon.
-    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-    Significantly faster than faster-whisper (without CUDA) on Apple M1.
+    Uses MLX Whisper optimized for Apple Silicon.
     """
-
-    sep = ""  # In my experience in french it should also be no space.
+    sep = ""

     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-        """
-        Loads the MLX-compatible Whisper model.
-
-        Args:
-            modelsize (str, optional): The size or name of the Whisper model to load.
-                If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
-                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
-            cache_dir (str, optional): Path to the directory for caching models.
-                **Note**: This is not supported by MLX Whisper and will be ignored.
-            model_dir (str, optional): Direct path to a custom model directory.
-                If specified, it overrides the `modelsize` parameter.
-        """
         from mlx_whisper.transcribe import ModelHolder, transcribe
         import mlx.core as mx

         if model_dir is not None:
-            logger.debug(
-                f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used."
-            )
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = self.translate_model_name(modelsize)
-            logger.debug(
-                f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
-            )
+            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
+        else:
+            raise ValueError("Either modelsize or model_dir must be set")

         self.model_size_or_path = model_size_or_path
-
-        # In mlx_whisper.transcribe, dtype is defined as:
-        # dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
-        # Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
-        dtype = mx.float16
+        dtype = mx.float16
         ModelHolder.get_model(model_size_or_path, dtype)
         return transcribe

     def translate_model_name(self, model_name):
-        """
-        Translates a given model name to its corresponding MLX-compatible model path.
-
-        Args:
-            model_name (str): The name of the model to translate.
-
-        Returns:
-            str: The MLX-compatible model path.
-        """
-        # Dictionary mapping model names to MLX-compatible paths
         model_mapping = {
             "tiny.en": "mlx-community/whisper-tiny.en-mlx",
             "tiny": "mlx-community/whisper-tiny-mlx",
@@ -230,16 +185,11 @@ class MLXWhisper(ASRBase):
             "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
             "large": "mlx-community/whisper-large-mlx",
         }
-
-        # Retrieve the corresponding MLX model path
         mlx_model_path = model_mapping.get(model_name)
-
         if mlx_model_path:
             return mlx_model_path
         else:
-            raise ValueError(
-                f"Model name '{model_name}' is not recognized or not supported."
-            )
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

     def transcribe(self, audio, init_prompt=""):
         if self.transcribe_kargs:
@@ -254,18 +204,17 @@ class MLXWhisper(ASRBase):
         )
         return segments.get("segments", [])

-    def ts_words(self, segments):
-        """
-        Extract timestamped words from transcription segments and skips words with high no-speech probability.
-        """
-        return [
-            (word["start"], word["end"], word["word"])
-            for segment in segments
-            for word in segment.get("words", [])
-            if segment.get("no_speech_prob", 0) <= 0.9
-        ]
-
-    def segments_end_ts(self, res):
+    def ts_words(self, segments) -> List[ASRToken]:
+        tokens = []
+        for segment in segments:
+            if segment.get("no_speech_prob", 0) > 0.9:
+                continue
+            for word in segment.get("words", []):
+                token = ASRToken(word["start"], word["end"], word["word"])
+                tokens.append(token)
+        return tokens
+
+    def segments_end_ts(self, res) -> List[float]:
         return [s["end"] for s in res]

     def use_vad(self):
@@ -276,68 +225,50 @@ class MLXWhisper(ASRBase):


 class OpenaiApiASR(ASRBase):
-    """Uses OpenAI's Whisper API for audio transcription."""
-
+    """Uses OpenAI's Whisper API for transcription."""
     def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
         self.logfile = logfile
-
         self.modelname = "whisper-1"
-        self.original_language = (
-            None if lan == "auto" else lan
-        )  # ISO-639-1 language code
+        self.original_language = None if lan == "auto" else lan
         self.response_format = "verbose_json"
         self.temperature = temperature
-
         self.load_model()
-
         self.use_vad_opt = False
-
-        # reset the task in set_translate_task
         self.task = "transcribe"

     def load_model(self, *args, **kwargs):
         from openai import OpenAI
-
         self.client = OpenAI()
+        self.transcribed_seconds = 0

-        self.transcribed_seconds = (
-            0  # for logging how many seconds were processed by API, to know the cost
-        )
-
-    def ts_words(self, segments):
+    def ts_words(self, segments) -> List[ASRToken]:
+        """
+        Converts OpenAI API response words into ASRToken objects while
+        optionally skipping words that fall into no-speech segments.
+        """
         no_speech_segments = []
         if self.use_vad_opt:
             for segment in segments.segments:
-                # TODO: threshold can be set from outside
                 if segment["no_speech_prob"] > 0.8:
-                    no_speech_segments.append(
-                        (segment.get("start"), segment.get("end"))
-                    )
-
-        o = []
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
+        tokens = []
         for word in segments.words:
             start = word.start
             end = word.end
             if any(s[0] <= start <= s[1] for s in no_speech_segments):
-                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
                 continue
-            o.append((start, end, word.word))
-        return o
+            tokens.append(ASRToken(start, end, word.word))
+        return tokens

-    def segments_end_ts(self, res):
+    def segments_end_ts(self, res) -> List[float]:
         return [s.end for s in res.words]

     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
-        # Write the audio data to a buffer
         buffer = io.BytesIO()
         buffer.name = "temp.wav"
         sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
-        buffer.seek(0)  # Reset buffer's position to the beginning
-
-        self.transcribed_seconds += math.ceil(
-            len(audio_data) / 16000
-        )  # it rounds up to the whole seconds
-
+        buffer.seek(0)
+        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
         params = {
             "model": self.modelname,
             "file": buffer,
@@ -349,22 +280,13 @@ class OpenaiApiASR(ASRBase):
             params["language"] = self.original_language
         if prompt:
             params["prompt"] = prompt
-
-        if self.task == "translate":
-            proc = self.client.audio.translations
-        else:
-            proc = self.client.audio.transcriptions
-
-        # Process transcription/translation
+        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
         transcript = proc.create(**params)
-        logger.debug(
-            f"OpenAI API processed accumulated {self.transcribed_seconds} seconds"
-        )
-
+        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
         return transcript

     def use_vad(self):
         self.use_vad_opt = True

     def set_translate_task(self):
-        self.task = "translate"
+        self.task = "translate"
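
The adapted backends construct `ASRToken` objects imported from `src.whisper_streaming.asr_token`, a module that is not part of this diff. A minimal sketch of what that class presumably looks like, inferred from the constructor calls `ASRToken(start, end, text)` and the `with_offset`/`__repr__` helpers visible above (the dataclass layout is an assumption, not the actual file):

# Hypothetical sketch of src/whisper_streaming/asr_token.py -- not included in this commit.
from dataclasses import dataclass

@dataclass
class ASRToken:
    start: float  # word start time in seconds
    end: float    # word end time in seconds
    text: str     # word text; faster-whisper keeps its leading space

    def with_offset(self, offset: float) -> "ASRToken":
        # Shift the token by `offset` seconds, e.g. when the audio buffer is trimmed.
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

With that class in place, every backend's `ts_words` returns `List[ASRToken]` instead of `(beg, end, text)` tuples. A usage sketch of the adapted FasterWhisperASR backend (the model name and audio buffer below are illustrative, not taken from the diff):

# Illustrative usage, assuming a 16 kHz mono float32 buffer.
import numpy as np

asr = FasterWhisperASR(lan="en", modelsize="large-v3")
audio = np.zeros(16000, dtype=np.float32)   # one second of silence as a placeholder
segments = asr.transcribe(audio)            # faster-whisper Segment objects
tokens = asr.ts_words(segments)             # -> List[ASRToken]
text = asr.sep.join(token.text for token in tokens)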