add whisper mlx backend
Browse files- whisper_online.py +60 -1
whisper_online.py
CHANGED
|
@@ -156,6 +156,63 @@ class FasterWhisperASR(ASRBase):
|
|
| 156 |
def set_translate_task(self):
|
| 157 |
self.transcribe_kargs["task"] = "translate"
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
class OpenaiApiASR(ASRBase):
|
| 161 |
"""Uses OpenAI's Whisper API for audio transcription."""
|
|
@@ -660,7 +717,7 @@ def add_shared_args(parser):
|
|
| 660 |
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
|
| 661 |
parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
|
| 662 |
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
|
| 663 |
-
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
|
| 664 |
parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
|
| 665 |
parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
|
| 666 |
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
|
@@ -679,6 +736,8 @@ def asr_factory(args, logfile=sys.stderr):
|
|
| 679 |
else:
|
| 680 |
if backend == "faster-whisper":
|
| 681 |
asr_cls = FasterWhisperASR
|
|
|
|
|
|
|
| 682 |
else:
|
| 683 |
asr_cls = WhisperTimestampedASR
|
| 684 |
|
|
|
|
| 156 |
def set_translate_task(self):
|
| 157 |
self.transcribe_kargs["task"] = "translate"
|
| 158 |
|
| 159 |
class MLXWhisper(ASRBase):
    """
    Uses the MLX Whisper library as the backend, optimized for Apple Silicon.
    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
    Significantly faster than faster-whisper (without CUDA) on Apple M1.
    Model used by default: mlx-community/whisper-large-v3-mlx
    """

    # Word entries returned by ts_words() already include their own leading
    # whitespace, so joining with a single space matches the other backends.
    sep = " "

    def load_model(self, modelsize=None, model_dir=None):
        """Resolve the model path/repo and return the mlx transcribe callable.

        Precedence: an explicit ``model_dir`` overrides ``modelsize``; when
        neither is given, the default mlx-community/whisper-large-v3-mlx
        HuggingFace repo is used.

        Returns the ``mlx_whisper.transcribe`` function — mlx_whisper exposes
        transcription as a function rather than a model object, so ASRBase
        stores this callable as ``self.model``.
        """
        from mlx_whisper import transcribe

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
            model_size_or_path = modelsize
        else:
            # Neither model_dir nor modelsize given — fall back to the default repo.
            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
            model_size_or_path = "mlx-community/whisper-large-v3-mlx"

        self.model_size_or_path = model_size_or_path
        return transcribe

    def transcribe(self, audio, init_prompt=""):
        """Transcribe ``audio`` and return the list of segment dicts.

        ``self.model`` is the mlx_whisper.transcribe callable returned by
        load_model(); the model repo/path is passed per call via
        ``path_or_hf_repo``.
        """
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,  # required so ts_words() can extract per-word timing
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
            **self.transcribe_kargs
        )
        # transcribe() returns a dict; be defensive if "segments" is absent.
        return segments.get("segments", [])

    def ts_words(self, segments):
        """
        Extract timestamped words from transcription segments, skipping
        words that belong to segments with high no-speech probability.
        """
        return [
            (word["start"], word["end"], word["word"])
            for segment in segments
            for word in segment.get("words", [])
            if segment.get("no_speech_prob", 0) <= 0.9
        ]

    def segments_end_ts(self, res):
        """Return the end timestamp of every segment in ``res``."""
        return [s['end'] for s in res]

    def use_vad(self):
        # NOTE(review): mlx_whisper.transcribe does not document a
        # ``vad_filter`` parameter (it is a faster-whisper option) — confirm
        # this kwarg is accepted before relying on --vad with this backend.
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        """Switch the decoder to Whisper's translate (to-English) task."""
        self.transcribe_kargs["task"] = "translate"
| 216 |
|
| 217 |
class OpenaiApiASR(ASRBase):
|
| 218 |
"""Uses OpenAI's Whisper API for audio transcription."""
|
|
|
|
| 717 |
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
|
| 718 |
parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
|
| 719 |
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
|
| 720 |
+
parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
|
| 721 |
parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
|
| 722 |
parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
|
| 723 |
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
|
|
|
|
| 736 |
else:
|
| 737 |
if backend == "faster-whisper":
|
| 738 |
asr_cls = FasterWhisperASR
|
| 739 |
+
elif backend == "mlx-whisper":
|
| 740 |
+
asr_cls = MLXWhisper
|
| 741 |
else:
|
| 742 |
asr_cls = WhisperTimestampedASR
|
| 743 |
|