# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import pathlib
import string
import time
from multiprocessing import Pool, Value, Lock

import torch
import whisper
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

# Shared progress counter and its lock. Because these are created at module
# level, they are shared with the Pool workers only under the "fork" start
# method (the default on Linux); with "spawn", each worker gets its own copy.
processed_files_count = Value("i", 0)  # count of processed files
lock = Lock()  # lock for the count


def preprocess_text(text):
    """Preprocess text after ASR"""
    return text.lower().translate(str.maketrans("", "", string.punctuation))
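
# Illustrative example of the normalization above (not part of the original file):
#   preprocess_text("Hello, World!") -> "hello world"

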
def transcribe_audio(model, processor, audio_file, device):
    """Transcribe a single audio file"""
    audio = whisper.load_audio(audio_file)  # load from path
    audio = whisper.pad_or_trim(audio)  # pad/trim to the default 30 seconds
    inputs = whisper.log_mel_spectrogram(audio).to(
        device=device
    )  # convert to log-Mel spectrogram
    inputs = inputs.unsqueeze(0).to(
        dtype=model.dtype
    )  # add batch dimension and match the model dtype (works on CPU and GPU)
    outputs = model.generate(
        inputs=inputs, max_new_tokens=128
    )  # generate transcription
    transcription = processor.batch_decode(outputs, skip_special_tokens=True)[
        0
    ]  # decode token ids to text
    transcription_processed = preprocess_text(transcription)  # normalize
    return transcription_processed


def write_transcription(audio_file, transcription):
    """Write transcription to txt file"""
    txt_file = audio_file.with_suffix(".txt")
    with open(txt_file, "w") as file:
        file.write(transcription)


def init_whisper(model_id, device):
    """Initialize whisper model and processor"""
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    print(f"Loading model {model_id}")  # e.g. "distil-whisper/distil-large-v2"
    distil_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=False,
    )
    distil_model = distil_model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return distil_model, processor
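
# Minimal single-file usage sketch for the two helpers above (illustrative only;
# "sample.wav" and the device string are assumptions, not part of this script):
#
#   model, processor = init_whisper("distil-whisper/distil-large-v2", "cuda:0")
#   text = transcribe_audio(model, processor, pathlib.Path("sample.wav"), "cuda:0")
#   print(text)

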
def asr_wav_files(file_list, gpu_id, total_files, model_id):
    """Transcribe wav files in a list"""
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu"
    whisper_model, processor = init_whisper(model_id, device)
    print(f"Processing on {device} starts")
    start_time = time.time()
    for audio_file in file_list:
        try:
            transcription = transcribe_audio(
                whisper_model, processor, audio_file, device
            )
            write_transcription(audio_file, transcription)
            with lock:
                processed_files_count.value += 1
                if processed_files_count.value % 5 == 0:
                    current_time = time.time()
                    avg_time_per_file = (
                        current_time - start_time
                    ) / processed_files_count.value
                    remaining_files = total_files - processed_files_count.value
                    estimated_time_remaining = avg_time_per_file * remaining_files
                    remaining_time_formatted = time.strftime(
                        "%H:%M:%S", time.gmtime(estimated_time_remaining)
                    )
                    print(
                        f"Processed {processed_files_count.value}/{total_files} files, "
                        f"time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}, "
                        f"Estimated time remaining: {remaining_time_formatted}"
                    )
        except Exception as e:
            print(f"Error processing file {audio_file}: {e}")


def asr_main(input_dir, num_gpus, model_id):
    """Transcribe wav files in a directory"""
    num_processes = min(num_gpus, os.cpu_count())
    print(f"Using {num_processes} GPUs for transcription")
    wav_files = list(pathlib.Path(input_dir).rglob("*.wav"))
    total_files = len(wav_files)
    print(f"Found {total_files} wav files in {input_dir}")
    # Ceiling division so the trailing remainder files are not dropped.
    files_per_process = -(-total_files // num_processes)
    print(f"Processing up to {files_per_process} files per process")
    with Pool(num_processes) as p:
        p.starmap(
            asr_wav_files,
            [
                (
                    wav_files[i * files_per_process : (i + 1) * files_per_process],
                    i % num_gpus,
                    total_files,
                    model_id,
                )
                for i in range(num_processes)
            ],
        )
    print("Done!")


if __name__ == "__main__":
    input_dir = "/path/to/output/directory"
    num_gpus = 2
    model_id = "distil-whisper/distil-large-v2"
    asr_main(input_dir, num_gpus, model_id)