Warn when transcribe_kargs are used with MLX Whisper
Browse files- whisper_online.py +16 -7
whisper_online.py
CHANGED
|
@@ -201,7 +201,8 @@ class MLXWhisper(ASRBase):
|
|
| 201 |
model_dir (str, optional): Direct path to a custom model directory.
|
| 202 |
If specified, it overrides the `modelsize` parameter.
|
| 203 |
"""
|
| 204 |
-
from mlx_whisper import transcribe
|
|
|
|
| 205 |
|
| 206 |
if model_dir is not None:
|
| 207 |
logger.debug(
|
|
@@ -215,6 +216,12 @@ class MLXWhisper(ASRBase):
|
|
| 215 |
)
|
| 216 |
|
| 217 |
self.model_size_or_path = model_size_or_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
return transcribe
|
| 219 |
|
| 220 |
def translate_model_name(self, model_name):
|
|
@@ -255,6 +262,8 @@ class MLXWhisper(ASRBase):
|
|
| 255 |
)
|
| 256 |
|
| 257 |
def transcribe(self, audio, init_prompt=""):
|
|
|
|
|
|
|
| 258 |
segments = self.model(
|
| 259 |
audio,
|
| 260 |
language=self.original_language,
|
|
@@ -262,7 +271,6 @@ class MLXWhisper(ASRBase):
|
|
| 262 |
word_timestamps=True,
|
| 263 |
condition_on_previous_text=True,
|
| 264 |
path_or_hf_repo=self.model_size_or_path,
|
| 265 |
-
**self.transcribe_kargs,
|
| 266 |
)
|
| 267 |
return segments.get("segments", [])
|
| 268 |
|
|
@@ -844,7 +852,7 @@ def add_shared_args(parser):
|
|
| 844 |
parser.add_argument(
|
| 845 |
"--model",
|
| 846 |
type=str,
|
| 847 |
-
default="large-v2",
|
| 848 |
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
| 849 |
","
|
| 850 |
),
|
|
@@ -879,14 +887,14 @@ def add_shared_args(parser):
|
|
| 879 |
parser.add_argument(
|
| 880 |
"--backend",
|
| 881 |
type=str,
|
| 882 |
-
default="faster-whisper",
|
| 883 |
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
| 884 |
help="Load only this backend for Whisper processing.",
|
| 885 |
)
|
| 886 |
parser.add_argument(
|
| 887 |
"--vac",
|
| 888 |
action="store_true",
|
| 889 |
-
default=False,
|
| 890 |
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
| 891 |
)
|
| 892 |
parser.add_argument(
|
|
@@ -895,7 +903,7 @@ def add_shared_args(parser):
|
|
| 895 |
parser.add_argument(
|
| 896 |
"--vad",
|
| 897 |
action="store_true",
|
| 898 |
-
default=False,
|
| 899 |
help="Use VAD = voice activity detection, with the default parameters.",
|
| 900 |
)
|
| 901 |
parser.add_argument(
|
|
@@ -1006,8 +1014,9 @@ if __name__ == "__main__":
|
|
| 1006 |
|
| 1007 |
parser = argparse.ArgumentParser()
|
| 1008 |
parser.add_argument(
|
| 1009 |
-
"audio_path",
|
| 1010 |
type=str,
|
|
|
|
| 1011 |
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
|
| 1012 |
)
|
| 1013 |
add_shared_args(parser)
|
|
|
|
| 201 |
model_dir (str, optional): Direct path to a custom model directory.
|
| 202 |
If specified, it overrides the `modelsize` parameter.
|
| 203 |
"""
|
| 204 |
+
from mlx_whisper.transcribe import ModelHolder, transcribe
|
| 205 |
+
import mlx.core as mx
|
| 206 |
|
| 207 |
if model_dir is not None:
|
| 208 |
logger.debug(
|
|
|
|
| 216 |
)
|
| 217 |
|
| 218 |
self.model_size_or_path = model_size_or_path
|
| 219 |
+
|
| 220 |
+
# In mlx_whisper.transcribe, dtype is defined as:
|
| 221 |
+
# dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
|
| 222 |
+
# Since we do not use decode_options in self.transcribe, we will set dtype to mx.float16
|
| 223 |
+
dtype = mx.float16
|
| 224 |
+
ModelHolder.get_model(model_size_or_path, dtype)
|
| 225 |
return transcribe
|
| 226 |
|
| 227 |
def translate_model_name(self, model_name):
|
|
|
|
| 262 |
)
|
| 263 |
|
| 264 |
def transcribe(self, audio, init_prompt=""):
|
| 265 |
+
if self.transcribe_kargs:
|
| 266 |
+
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
|
| 267 |
segments = self.model(
|
| 268 |
audio,
|
| 269 |
language=self.original_language,
|
|
|
|
| 271 |
word_timestamps=True,
|
| 272 |
condition_on_previous_text=True,
|
| 273 |
path_or_hf_repo=self.model_size_or_path,
|
|
|
|
| 274 |
)
|
| 275 |
return segments.get("segments", [])
|
| 276 |
|
|
|
|
| 852 |
parser.add_argument(
|
| 853 |
"--model",
|
| 854 |
type=str,
|
| 855 |
+
default="tiny",
|
| 856 |
choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo".split(
|
| 857 |
","
|
| 858 |
),
|
|
|
|
| 887 |
parser.add_argument(
|
| 888 |
"--backend",
|
| 889 |
type=str,
|
| 890 |
+
default="mlx-whisper",
|
| 891 |
choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
|
| 892 |
help="Load only this backend for Whisper processing.",
|
| 893 |
)
|
| 894 |
parser.add_argument(
|
| 895 |
"--vac",
|
| 896 |
action="store_true",
|
| 897 |
+
default=True,
|
| 898 |
help="Use VAC = voice activity controller. Recommended. Requires torch.",
|
| 899 |
)
|
| 900 |
parser.add_argument(
|
|
|
|
| 903 |
parser.add_argument(
|
| 904 |
"--vad",
|
| 905 |
action="store_true",
|
| 906 |
+
default=True,
|
| 907 |
help="Use VAD = voice activity detection, with the default parameters.",
|
| 908 |
)
|
| 909 |
parser.add_argument(
|
|
|
|
| 1014 |
|
| 1015 |
parser = argparse.ArgumentParser()
|
| 1016 |
parser.add_argument(
|
| 1017 |
+
"--audio_path",
|
| 1018 |
type=str,
|
| 1019 |
+
default='samples_jfk.wav',
|
| 1020 |
help="Filename of 16kHz mono channel wav, on which live streaming is simulated.",
|
| 1021 |
)
|
| 1022 |
add_shared_args(parser)
|