# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# This module is modified from [Whisper](https://github.com/openai/whisper.git).
# ## Citations
# ```bibtex
# @inproceedings{openai-whisper,
#   author    = {Alec Radford and
#                Jong Wook Kim and
#                Tao Xu and
#                Greg Brockman and
#                Christine McLeavey and
#                Ilya Sutskever},
#   title     = {Robust Speech Recognition via Large-Scale Weak Supervision},
#   booktitle = {{ICML}},
#   series    = {Proceedings of Machine Learning Research},
#   volume    = {202},
#   pages     = {28492--28518},
#   publisher = {{PMLR}},
#   year      = {2023}
# }
# ```
#

import hashlib
import io
import os
import urllib.request
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}
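

# Note: each URL above embeds the checkpoint's SHA256 digest as its
# second-to-last path component; _download() below relies on this layout
# to verify both cached and freshly downloaded files.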
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))
    with open(download_target, "rb") as f:
        model_bytes = f.read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target
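

# Minimal usage sketch for _download (hypothetical cache path): the checkpoint
# is always written to disk, and in_memory=True additionally returns the raw
# bytes instead of the file path.
#
#   ckpt_bytes = _download(
#       _MODELS["tiny"], os.path.expanduser("~/.cache/whisper"), in_memory=True
#   )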


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
    checkpoint_file: Optional[str] = None,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root : str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory : bool
        whether to preload the model weights into host memory
    checkpoint_file : str
        optional path to an existing checkpoint; when it points to a real file,
        it is loaded directly and `name` is not used for downloading

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    # Guard against checkpoint_file=None: os.path.exists(None) raises TypeError.
    if checkpoint_file is None or not os.path.exists(checkpoint_file):
        if name in _MODELS:
            checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            if in_memory:
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )
    else:
        if in_memory:
            with open(checkpoint_file, "rb") as f:
                checkpoint_file = f.read()
    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
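

if __name__ == "__main__":
    # Usage sketch, not part of the Amphion API: assumes a CPU-only setup and a
    # hypothetical local file "speech.wav". It chains the helpers imported at
    # the top of this module in the typical Whisper decoding pipeline.
    model = load_model("base", device="cpu")
    audio = pad_or_trim(load_audio("speech.wav"))  # 16 kHz mono, 30 s window
    mel = log_mel_spectrogram(audio).to(model.device)
    result = decode(model, mel, DecodingOptions(fp16=False))
    print(result.text)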