# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| This script serves three goals: | |
| (1) Demonstrate how to use NeMo Models outside of PytorchLightning | |
| (2) Shows example of batch ASR inference | |
| (3) Serves as CI test for pre-trained checkpoint | |

python speech_to_text_buffered_infer_ctc.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="<remove or path to folder of audio files>" \
    dataset_manifest="<remove or path to manifest>" \
    output_filename="<remove or specify output filename>" \
    total_buffer_in_secs=4.0 \
    chunk_len_in_secs=1.6 \
    model_stride=4 \
    batch_size=32
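
With these values, each 1.6 s chunk is decoded inside a 4.0 s buffer, i.e. with
(4.0 - 1.6) / 2 = 1.2 s of acoustic context on each side of the chunk.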

# NOTE:
    You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the
    predictions of the model, and the ground-truth text if it is present in the manifest.
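
Each line of the manifest is a JSON object describing one audio file, for example
(the exact fields depend on your dataset; `text` is optional ground truth):
    {"audio_filepath": "/path/to/audio.wav", "duration": 3.2, "text": "hello world"}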
| """ | |
| import contextlib | |
| import copy | |
| import glob | |
| import math | |
| import os | |
| from dataclasses import dataclass, is_dataclass | |
| from typing import Optional | |
| import torch | |
| from omegaconf import OmegaConf | |
| from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR | |
| from nemo.collections.asr.parts.utils.transcribe_utils import ( | |
| compute_output_filename, | |
| get_buffered_pred_feat, | |
| setup_model, | |
| write_transcription, | |
| ) | |
| from nemo.core.config import hydra_runner | |
| from nemo.utils import logging | |

can_gpu = torch.cuda.is_available()


@dataclass
class TranscriptionConfig:
    # Required configs
    model_path: Optional[str] = None  # Path to a .nemo file
    pretrained_name: Optional[str] = None  # Name of a pretrained model
    audio_dir: Optional[str] = None  # Path to a directory which contains audio files
    dataset_manifest: Optional[str] = None  # Path to the dataset's JSON manifest

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    num_workers: int = 0
    append_pred: bool = False  # If True, add the predictions as a new field to an existing manifest instead of overwriting it
    pred_name_postfix: Optional[str] = None  # Postfix for the prediction field name, if the default model name should not be used

    # Chunked configs
    chunk_len_in_secs: float = 1.6  # Chunk length in seconds
    total_buffer_in_secs: float = 4.0  # Length of buffer (chunk + left and right padding) in seconds
    model_stride: int = 8  # Model downsampling factor: 8 for Citrinet models, 4 for Conformer models

    # Set `cuda` to an int to select a specific CUDA device. If None, a CUDA
    # device is used when available, and inference falls back to CPU otherwise.
    # If `cuda` is a negative number, inference runs on CPU only.
    cuda: Optional[int] = None
    amp: bool = False
    audio_type: str = "wav"

    # Recompute the model transcription, even if the output file already exists with scores.
    overwrite_transcripts: bool = True
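
# Every field of TranscriptionConfig can be overridden from the command line via Hydra;
# the model name below is just an illustrative pretrained checkpoint name:
#   python speech_to_text_buffered_infer_ctc.py \
#       pretrained_name=stt_en_conformer_ctc_large \
#       dataset_manifest=/data/manifest.json model_stride=4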


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    filepaths = None
    manifest = cfg.dataset_manifest
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
        manifest = None  # ignore dataset_manifest if both audio_dir and dataset_manifest are present

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1  # a single CPU device (Lightning-style device count)
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')
    logging.info(f"Inference will be done on device : {device}")

    asr_model, model_name = setup_model(cfg, map_location)

    model_cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(model_cfg.preprocessor, False)
    # some changes for the streaming scenario: disable dithering so that features are
    # deterministic across chunks, and disable padding so that frame counts map cleanly
    # onto chunk boundaries
    model_cfg.preprocessor.dither = 0.0
    model_cfg.preprocessor.pad_to = 0

    if model_cfg.preprocessor.normalize != "per_feature":
        logging.error("Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently")

    # Disable config overwriting
    OmegaConf.set_struct(model_cfg.preprocessor, True)

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:
        # fallback: a no-op context manager when AMP is unavailable or disabled
        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    cfg = compute_output_filename(cfg, model_name)

    # if transcripts should not be overwritten and already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts` "
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )
        return cfg

    asr_model.eval()
    asr_model = asr_model.to(asr_model.device)

    feature_stride = model_cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * cfg.model_stride
    total_buffer = cfg.total_buffer_in_secs
    chunk_len = float(cfg.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")

    frame_asr = FrameBatchASR(
        asr_model=asr_model, frame_len=chunk_len, total_buffer=cfg.total_buffer_in_secs, batch_size=cfg.batch_size,
    )
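    # Roughly, FrameBatchASR slides a window of `total_buffer` seconds over the audio
    # in steps of `frame_len` seconds and decodes the windows in batches; when the
    # per-window results are merged, only the middle `tokens_per_chunk` tokens of each
    # window (located via `mid_delay`) are kept, so every region is decoded with
    # context on both sides. See streaming_utils for the exact merge logic.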

    hyps = get_buffered_pred_feat(
        frame_asr,
        chunk_len,
        tokens_per_chunk,
        mid_delay,
        model_cfg.preprocessor,
        model_stride_in_secs,
        asr_model.device,
        manifest,
        filepaths,
    )
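    # `hyps` is expected to hold one final transcription per input file, assembled
    # from the per-chunk decodes (the exact return type depends on the NeMo version).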
    output_filename = write_transcription(hyps, cfg, model_name, filepaths=filepaths, compute_langs=False)
    logging.info(f"Finished writing predictions to {output_filename}!")
    return cfg


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter