add translate_model_name function
Browse files- whisper_online.py +51 -7
whisper_online.py
CHANGED
|
@@ -160,27 +160,71 @@ class MLXWhisper(ASRBase):
|
|
| 160 |
"""
|
| 161 |
Uses MLX Whisper library as the backend, optimized for Apple Silicon.
|
| 162 |
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
|
| 163 |
-
Significantly faster than faster-whisper (without CUDA) on Apple M1.
|
| 164 |
"""
|
| 165 |
|
| 166 |
sep = " "
|
| 167 |
|
| 168 |
-
def load_model(self, modelsize=None, model_dir=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
from mlx_whisper import transcribe
|
| 170 |
|
| 171 |
if model_dir is not None:
|
| 172 |
logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
|
| 173 |
model_size_or_path = model_dir
|
| 174 |
elif modelsize is not None:
|
| 175 |
-
|
| 176 |
-
model_size_or_path
|
| 177 |
-
elif modelsize == None:
|
| 178 |
-
logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
|
| 179 |
-
model_size_or_path = "mlx-community/whisper-large-v3-mlx"
|
| 180 |
|
| 181 |
self.model_size_or_path = model_size_or_path
|
| 182 |
return transcribe
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def transcribe(self, audio, init_prompt=""):
|
| 185 |
segments = self.model(
|
| 186 |
audio,
|
|
|
|
| 160 |
"""
|
| 161 |
Uses MLX Whisper library as the backend, optimized for Apple Silicon.
|
| 162 |
Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
|
| 163 |
+
Significantly faster than faster-whisper (without CUDA) on Apple M1.
|
| 164 |
"""
|
| 165 |
|
| 166 |
sep = " "
|
| 167 |
|
| 168 |
+
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
    """
    Resolve which MLX Whisper model to use and return the transcribe callable.

    The resolved model identifier is stored on ``self.model_size_or_path``;
    the function returned is ``mlx_whisper.transcribe``.

    Args:
        modelsize (str, optional): The size or name of the Whisper model.
            If provided, it is translated to an MLX-compatible model path
            using the `translate_model_name` method.
            Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
        cache_dir (str, optional): Path to the directory for caching models.
            **Note**: This is not supported by MLX Whisper and is ignored.
        model_dir (str, optional): Direct path to a custom model directory.
            If specified, it overrides the `modelsize` parameter.

    Returns:
        callable: the ``transcribe`` function imported from ``mlx_whisper``.

    Raises:
        ValueError: propagated from `translate_model_name` when `modelsize`
            is not a recognized model name.
    """
    # Imported lazily so the module can be used without mlx_whisper installed
    # as long as this backend is never selected.
    from mlx_whisper import transcribe

    if cache_dir is not None:
        # Make the silent no-op visible to whoever passed the argument.
        logger.debug(f"cache_dir {cache_dir} is not supported by MLX Whisper and will be ignored.")

    if model_dir is not None:
        logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
        model_size_or_path = model_dir
    elif modelsize is not None:
        model_size_or_path = self.translate_model_name(modelsize)
        logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
    else:
        # Neither argument given: fall back to a default model instead of
        # failing with UnboundLocalError on the assignment below.
        logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
        model_size_or_path = "mlx-community/whisper-large-v3-mlx"

    self.model_size_or_path = model_size_or_path
    return transcribe
|
| 192 |
|
| 193 |
+
def translate_model_name(self, model_name):
    """
    Map a Whisper model name onto its MLX-compatible model path.

    Args:
        model_name (str): Short model name, e.g. "tiny" or "large-v3-turbo".

    Returns:
        str: The matching "mlx-community/..." model path.

    Raises:
        ValueError: If `model_name` is not one of the known names.
    """
    # Known short names and the MLX model paths they stand for.
    mlx_paths = {
        "tiny.en": "mlx-community/whisper-tiny.en-mlx",
        "tiny": "mlx-community/whisper-tiny-mlx",
        "base.en": "mlx-community/whisper-base.en-mlx",
        "base": "mlx-community/whisper-base-mlx",
        "small.en": "mlx-community/whisper-small.en-mlx",
        "small": "mlx-community/whisper-small-mlx",
        "medium.en": "mlx-community/whisper-medium.en-mlx",
        "medium": "mlx-community/whisper-medium-mlx",
        "large-v1": "mlx-community/whisper-large-v1-mlx",
        "large-v2": "mlx-community/whisper-large-v2-mlx",
        "large-v3": "mlx-community/whisper-large-v3-mlx",
        "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
        "large": "mlx-community/whisper-large-mlx",
    }

    # Guard clause: reject anything outside the table up front.
    if model_name not in mlx_paths:
        raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

    return mlx_paths[model_name]
|
| 227 |
+
|
| 228 |
def transcribe(self, audio, init_prompt=""):
|
| 229 |
segments = self.model(
|
| 230 |
audio,
|