Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # Copyright 2025 Xiaomi Corp. (authors: Han Zhu | |
| # Wei Kang) | |
| # | |
| # See ../../../../LICENSE for clarification regarding multiple authors | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Calculate WER with Whisper-large-v3 or Paraformer models, | |
| following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval | |
| """ | |
| import argparse | |
| import os | |
| import string | |
| import numpy as np | |
| import scipy | |
| import soundfile as sf | |
| import torch | |
| import zhconv | |
| from funasr import AutoModel | |
| from jiwer import compute_measures | |
| from tqdm import tqdm | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
| from zhon.hanzi import punctuation | |
def get_parser():
    """Build the command-line parser for this WER evaluation script.

    Every option is a string. ``--test-list`` defaults to ``"test.tsv"``;
    all other options default to ``None`` when omitted.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # (flag, extra keyword arguments) in the order they are registered.
    option_specs = (
        ("--wav-path", {"help": "path of the speech directory"}),
        (
            "--decode-path",
            {
                "default": None,
                "help": "path of the output file of WER information",
            },
        ),
        (
            "--model-path",
            {
                "default": None,
                "help": "path of the local whisper and paraformer model, "
                "e.g., whisper: model/huggingface/whisper-large-v3/, "
                "paraformer: model/huggingface/paraformer-zh/",
            },
        ),
        (
            "--test-list",
            {
                "default": "test.tsv",
                "help": "path of the transcript tsv file, where the first column "
                "is the wav name and the last column is the transcript",
            },
        ),
        ("--lang", {"help": "decoded language, zh or en"}),
    )

    cli = argparse.ArgumentParser()
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)
    return cli
def load_en_model(model_path):
    """Load the Whisper processor and model used for English decoding.

    Args:
        model_path: Location passed to ``from_pretrained``. When ``None``,
            the identifier ``"openai/whisper-large-v3"`` is used instead.

    Returns:
        A ``(processor, model)`` tuple.
    """
    source = "openai/whisper-large-v3" if model_path is None else model_path
    # The processor is created first, then the model, both from the same source.
    return (
        WhisperProcessor.from_pretrained(source),
        WhisperForConditionalGeneration.from_pretrained(source),
    )
def load_zh_model(model_path):
    """Load the FunASR model used for Chinese decoding.

    Args:
        model_path: Value passed as ``model=`` to ``AutoModel``. When ``None``,
            the name ``"paraformer-zh"`` is used instead.

    Returns:
        The constructed ``AutoModel`` instance.
    """
    source = model_path if model_path is not None else "paraformer-zh"
    return AutoModel(model=source)
def process_one(hypo, truth, lang):
    """Normalize one hypothesis/reference pair and score it with jiwer.

    Both strings have punctuation removed (the apostrophe is kept) and
    whitespace runs collapsed. Chinese text is then split into single
    characters so the score is character-based; English text is lower-cased.

    Args:
        hypo: Recognized transcription.
        truth: Reference transcription.
        lang: ``"zh"`` or ``"en"``.

    Returns:
        A tuple ``(truth, hypo, wer, subs, dele, inse, word_num)`` holding the
        normalized strings, the jiwer ``wer`` value, the substitution,
        deletion and insertion counts, and the number of reference tokens.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"zh"`` nor ``"en"``.
    """
    # One translate() pass removes every Chinese and ASCII punctuation mark.
    # The apostrophe is excluded from the removal set, as in the original loop.
    removable = (punctuation + string.punctuation).replace("'", "")
    strip_table = str.maketrans("", "", removable)
    truth = truth.translate(strip_table)
    hypo = hypo.translate(strip_table)

    # Removing punctuation can leave runs of spaces ("a - b" -> "a  b").
    # Collapse them so tokenization and the word count below see no empty
    # tokens.
    truth = " ".join(truth.split())
    hypo = " ".join(hypo.split())

    if lang == "zh":
        # Character-level scoring: drop remaining spaces first so that a space
        # never becomes a token of its own.
        truth = " ".join(truth.replace(" ", ""))
        hypo = " ".join(hypo.replace(" ", ""))
    elif lang == "en":
        truth = truth.lower()
        hypo = hypo.lower()
    else:
        raise NotImplementedError(f"unsupported language: {lang!r}")

    measures = compute_measures(truth, hypo)
    # split() without an argument never yields empty strings, so this is the
    # real number of reference tokens.
    word_num = len(truth.split())
    return (
        truth,
        hypo,
        measures["wer"],
        measures["substitutions"],
        measures["deletions"],
        measures["insertions"],
        word_num,
    )
def _read_test_list(test_list, wav_path):
    """Return ``(wav_file, reference_text)`` pairs listed in a TSV file.

    The first tab-separated column is the wav name (without ``.wav``) and the
    last column is the reference transcript. Blank lines are skipped.

    Raises:
        FileNotFoundError: If a listed wav file does not exist under
            ``wav_path``.
    """
    pairs = []
    # Explicit UTF-8: transcripts may contain Chinese text, and the platform
    # default encoding is not guaranteed to handle it.
    with open(test_list, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            items = line.split("\t")
            wav_name, text_ref = items[0], items[-1]
            file_path = os.path.join(wav_path, wav_name + ".wav")
            # Raise explicitly rather than assert, so the check survives -O.
            if not os.path.exists(file_path):
                raise FileNotFoundError(file_path)
            pairs.append((file_path, text_ref))
    return pairs


def main(test_list, wav_path, model_path, decode_path, lang, device):
    """Transcribe every utterance of a test list and report WER.

    Two numbers are printed: the mean of per-utterance WERs (labelled
    "Seed-TTS WER") and the corpus-level WER, i.e. total errors divided by
    total reference tokens. When ``decode_path`` is set, one line per
    utterance plus both summary numbers are also written to that file.

    Args:
        test_list: Path of the TSV file listing wav names and transcripts.
        wav_path: Directory containing ``<wav name>.wav`` files.
        model_path: Local model location, or ``None`` for the loader default.
        decode_path: Output file for per-utterance results, or ``None``/empty
            to skip writing.
        lang: ``"en"`` (Whisper) or ``"zh"`` (Paraformer).
        device: Torch device the Whisper model and its inputs are moved to;
            not used on the ``"zh"`` path.

    Raises:
        NotImplementedError: If ``lang`` is neither ``"en"`` nor ``"zh"``.
        FileNotFoundError: If a listed wav file is missing.
    """
    # Imported here because a bare ``import scipy`` does not expose the
    # ``signal`` submodule on every SciPy release.
    import scipy.signal

    if lang == "en":
        processor, model = load_en_model(model_path)
        model.to(device)
        # The decoder prompt is the same for every utterance; build it once.
        forced_decoder_ids = processor.get_decoder_prompt_ids(
            language="english", task="transcribe"
        )
    elif lang == "zh":
        model = load_zh_model(model_path)
    else:
        # Fail before any decoding instead of hitting an undefined name later.
        raise NotImplementedError(f"unsupported language: {lang!r}")

    params = _read_test_list(test_list, wav_path)

    wers = []
    inses = []
    deles = []
    subses = []
    word_nums = 0

    fout = None
    if decode_path:
        decode_dir = os.path.dirname(decode_path)
        # dirname is "" for a bare file name; makedirs("") would raise.
        if decode_dir:
            os.makedirs(decode_dir, exist_ok=True)
        fout = open(decode_path, "w", encoding="utf-8")

    try:
        for file_path, text_ref in tqdm(params):
            if lang == "en":
                # NOTE(review): multi-channel audio is passed through as read;
                # confirm that inputs are mono.
                wav, sr = sf.read(file_path)
                if sr != 16000:
                    wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
                input_features = processor(
                    wav, sampling_rate=16000, return_tensors="pt"
                ).input_features
                input_features = input_features.to(device)
                predicted_ids = model.generate(
                    input_features, forced_decoder_ids=forced_decoder_ids
                )
                transcription = processor.batch_decode(
                    predicted_ids, skip_special_tokens=True
                )[0]
            else:  # lang == "zh", validated above
                res = model.generate(
                    input=file_path, batch_size_s=300, disable_pbar=True
                )
                transcription = res[0]["text"]
                # Convert to the "zh-cn" variant before comparing with the
                # reference.
                transcription = zhconv.convert(transcription, "zh-cn")

            truth, hypo, wer, subs, dele, inse, word_num = process_one(
                transcription, text_ref, lang
            )
            if fout is not None:
                fout.write(
                    f"{file_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n"
                )
            wers.append(float(wer))
            inses.append(float(inse))
            deles.append(float(dele))
            subses.append(float(subs))
            word_nums += word_num

        # Mean of per-utterance WERs versus corpus-level WER.
        wer_avg = round(np.mean(wers) * 100, 3)
        wer = round(
            (np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3
        )
        print(f"Seed-TTS WER: {wer_avg}%\n")
        print(f"WER: {wer}%\n")
        if fout is not None:
            fout.write(f"SeedTTS WER: {wer_avg}%\n")
            fout.write(f"WER: {wer}%\n")
            fout.flush()
    finally:
        # Close the result file even if decoding fails part-way through.
        if fout is not None:
            fout.close()
if __name__ == "__main__":
    # Parse the CLI options, use the first CUDA device when one is available
    # (CPU otherwise), and run the evaluation.
    cli_args = get_parser().parse_args()
    run_device = (
        torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    )
    main(
        test_list=cli_args.test_list,
        wav_path=cli_args.wav_path,
        model_path=cli_args.model_path,
        decode_path=cli_args.decode_path,
        lang=cli_args.lang,
        device=run_device,
    )