import json
import logging
import os
import shutil

import faster_whisper
import nltk
import regex as re
import torch
import torchaudio
import wget
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from omegaconf import OmegaConf
from speechbrain.inference.separation import SepformerSeparation

punct_model_langs = [
    "en",
    "fr",
    "de",
    "es",
    "it",
    "nl",
    "pt",
    "bg",
    "pl",
    "cs",
    "sk",
    "sl",
]

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}

whisper_langs = sorted(LANGUAGES.keys()) + sorted(
    [k.title() for k in TO_LANGUAGE_CODE.keys()]
)

langs_to_iso = {
    "af": "afr",
    "am": "amh",
    "ar": "ara",
    "as": "asm",
    "az": "aze",
    "ba": "bak",
    "be": "bel",
    "bg": "bul",
    "bn": "ben",
    "bo": "tib",
    "br": "bre",
    "bs": "bos",
    "ca": "cat",
    "cs": "cze",
    "cy": "wel",
    "da": "dan",
    "de": "ger",
    "el": "gre",
    "en": "eng",
    "es": "spa",
    "et": "est",
    "eu": "baq",
    "fa": "per",
    "fi": "fin",
    "fo": "fao",
    "fr": "fre",
    "gl": "glg",
    "gu": "guj",
    "ha": "hau",
    "haw": "haw",
    "he": "heb",
    "hi": "hin",
    "hr": "hrv",
    "ht": "hat",
    "hu": "hun",
    "hy": "arm",
    "id": "ind",
    "is": "ice",
    "it": "ita",
    "ja": "jpn",
    "jw": "jav",
    "ka": "geo",
    "kk": "kaz",
    "km": "khm",
    "kn": "kan",
    "ko": "kor",
    "la": "lat",
    "lb": "ltz",
    "ln": "lin",
    "lo": "lao",
    "lt": "lit",
    "lv": "lav",
    "mg": "mlg",
    "mi": "mao",
    "mk": "mac",
    "ml": "mal",
    "mn": "mon",
    "mr": "mar",
    "ms": "may",
    "mt": "mlt",
    "my": "bur",
    "ne": "nep",
    "nl": "dut",
    "nn": "nno",
    "no": "nor",
    "oc": "oci",
    "pa": "pan",
    "pl": "pol",
    "ps": "pus",
    "pt": "por",
    "ro": "rum",
    "ru": "rus",
    "sa": "san",
    "sd": "snd",
    "si": "sin",
    "sk": "slo",
    "sl": "slv",
    "sn": "sna",
    "so": "som",
    "sq": "alb",
    "sr": "srp",
    "su": "sun",
    "sv": "swe",
    "sw": "swa",
    "ta": "tam",
    "te": "tel",
    "tg": "tgk",
    "th": "tha",
    "tk": "tuk",
    "tl": "tgl",
    "tr": "tur",
    "tt": "tat",
    "uk": "ukr",
    "ur": "urd",
    "uz": "uzb",
    "vi": "vie",
    "yi": "yid",
    "yo": "yor",
    "yue": "yue",
    "zh": "chi",
}

def separate_speakers(audio_path, num_speakers, temp_dir, device):
    """
    Separate speakers using SpeechBrain SepFormer models.
    """
    original_audio_path = os.path.normpath(os.path.abspath(audio_path))
    if not os.path.exists(original_audio_path):
        raise FileNotFoundError(f"Audio file not found: {original_audio_path}")

    print(f"Original audio file path: {original_audio_path}")
    print(f"File exists: {os.path.exists(original_audio_path)}")
    print(f"File size: {os.path.getsize(original_audio_path) if os.path.exists(original_audio_path) else 'N/A'} bytes")

    # Copy the input to a short, simple path inside the temp directory so SpeechBrain
    # does not trip over spaces or special characters in the original path.
    simple_audio_name = "input_audio.wav"
    simple_audio_path = os.path.normpath(os.path.join(temp_dir, simple_audio_name))

    print(f"Creating temp directory: {temp_dir}")
    os.makedirs(temp_dir, exist_ok=True)
    print(f"Temp directory exists: {os.path.exists(temp_dir)}")

    shutil.copy2(original_audio_path, simple_audio_path)

    print(f"Copied audio to simpler path: {simple_audio_path}")
    print(f"Simple path exists: {os.path.exists(simple_audio_path)}")

    # Prepare alternative path formats; some SpeechBrain versions are picky on Windows.
    audio_path_for_speechbrain = simple_audio_path.replace("\\", "/")
    audio_path_relative = os.path.relpath(simple_audio_path)
    print(f"Path for SpeechBrain (forward slashes): {audio_path_for_speechbrain}")
    print(f"Path for SpeechBrain (relative): {audio_path_relative}")

    audio_path = simple_audio_path

    print(f"Separating {num_speakers} speakers from audio...")

    # Keep the original sample rate so the separated stems can also be saved at full quality.
    original_waveform, original_sample_rate = torchaudio.load(audio_path)
    print(f"Original audio sample rate: {original_sample_rate}Hz")

    if num_speakers == 2:
        model_name = "speechbrain/sepformer-libri2mix"
        fallback_model = "speechbrain/sepformer-wsj02mix"
    elif num_speakers == 3:
        model_name = "speechbrain/sepformer-wsj03mix"
        fallback_model = "speechbrain/sepformer-libri3mix"
    else:
        raise ValueError("Only 2 or 3 speakers are supported for separation")

    models_to_try = [model_name, fallback_model]

    for model_attempt, current_model in enumerate(models_to_try):
        try:
            print(f"Trying separation model: {current_model}")

            separator = SepformerSeparation.from_hparams(
                source=current_model,
                savedir=os.path.join(temp_dir, "sepformer_models"),
                run_opts={"device": device},
            )

            paths_to_try = [
                audio_path_for_speechbrain,
                audio_path_relative,
                simple_audio_path,
            ]

            est_sources = None
            for path_attempt, path_to_try in enumerate(paths_to_try):
                try:
                    print(f"Attempt {path_attempt + 1}: Calling SpeechBrain with path: {path_to_try}")
                    est_sources = separator.separate_file(path=path_to_try)
                    print(f"Success with path format {path_attempt + 1}")
                    break
                except Exception as path_error:
                    print(f"Path attempt {path_attempt + 1} failed: {path_error}")
                    if path_attempt == len(paths_to_try) - 1:
                        raise

            print(f"Separated sources shape: {est_sources.shape}")

            # The separated output is expected to be [batch, time, n_sources].
            if len(est_sources.shape) == 3:
                actual_num_sources = est_sources.shape[2]
            elif len(est_sources.shape) == 2:
                actual_num_sources = 1
                est_sources = est_sources.unsqueeze(2)
            else:
                raise ValueError(f"Unexpected tensor shape: {est_sources.shape}")

            print(f"Expected {num_speakers} speakers, got {actual_num_sources} separated sources")

            if actual_num_sources < num_speakers:
                logging.warning(f"Only {actual_num_sources} sources separated, but {num_speakers} were requested.")
                if actual_num_sources == 0:
                    raise ValueError("No sources were separated")
                num_speakers_to_process = actual_num_sources
            else:
                num_speakers_to_process = num_speakers

            separated_files = []

            # Each stem is saved twice: at the original sample rate for listening, and
            # at 16 kHz for downstream transcription and alignment.
            target_sample_rate = original_sample_rate
            processing_sample_rate = 16000

            for i in range(num_speakers_to_process):
                separated_path_original = os.path.join(temp_dir, f"separated_speaker_{i + 1}_original.wav")
                separated_path_processing = os.path.join(temp_dir, f"separated_speaker_{i + 1}.wav")

                if actual_num_sources == 1:
                    source_audio = est_sources[:, :, 0].cpu()
                else:
                    source_audio = est_sources[:, :, i].cpu()

                # Ensure a [channels, time] layout for torchaudio.save.
                if source_audio.dim() == 1:
                    source_audio = source_audio.unsqueeze(0)
                elif source_audio.dim() > 2:
                    source_audio = source_audio[0:1, :]

                # Sample rate produced by the SepFormer models used above.
                separation_sample_rate = 8000

                if target_sample_rate != separation_sample_rate:
                    resampler_to_original = torchaudio.transforms.Resample(
                        orig_freq=separation_sample_rate,
                        new_freq=target_sample_rate,
                    )
                    source_audio_original = resampler_to_original(source_audio)
                else:
                    source_audio_original = source_audio

                torchaudio.save(separated_path_original, source_audio_original, target_sample_rate)
                print(f"Saved original quality audio for speaker {i + 1}: {separated_path_original} ({target_sample_rate}Hz)")

                if processing_sample_rate != separation_sample_rate:
                    resampler_to_processing = torchaudio.transforms.Resample(
                        orig_freq=separation_sample_rate,
                        new_freq=processing_sample_rate,
                    )
                    source_audio_processing = resampler_to_processing(source_audio)
                else:
                    source_audio_processing = source_audio

                torchaudio.save(separated_path_processing, source_audio_processing, processing_sample_rate)
                print(f"Saved processing audio for speaker {i + 1}: {separated_path_processing} ({processing_sample_rate}Hz)")

                separated_files.append(separated_path_processing)

            if len(separated_files) == 0:
                raise ValueError("No audio sources could be separated")

            # If fewer stems were produced than requested, duplicate the last one so the
            # caller always receives num_speakers files.
            while len(separated_files) < num_speakers:
                last_file = separated_files[-1]
                new_speaker_id = len(separated_files) + 1
                duplicated_path_processing = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}.wav")
                duplicated_path_original = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}_original.wav")

                waveform, sample_rate = torchaudio.load(last_file)
                torchaudio.save(duplicated_path_processing, waveform, sample_rate)

                last_original_file = last_file.replace(".wav", "_original.wav")
                if os.path.exists(last_original_file):
                    waveform_orig, sample_rate_orig = torchaudio.load(last_original_file)
                    torchaudio.save(duplicated_path_original, waveform_orig, sample_rate_orig)

                separated_files.append(duplicated_path_processing)
                print(f"Duplicated speaker audio for speaker {new_speaker_id}")

            return separated_files

        except Exception as e:
            logging.error(f"Model {current_model} failed: {e}")
            if model_attempt < len(models_to_try) - 1:
                print("Trying next model...")
                continue
            else:
                error_msg = f"""
╔══════════════════════════════════════════════════════════════════════════════╗
║                           SPEAKER SEPARATION FAILED                            ║
╚══════════════════════════════════════════════════════════════════════════════╝

❌ All SpeechBrain separation models failed to load or process your audio file.

🔍 Possible causes:
   • Audio file path contains special characters or spaces
   • Audio file format is not supported by SpeechBrain models
   • File permissions issue
   • Audio file is corrupted or too short

💡 Solutions:
   1. Try copying your audio file to the current directory with a simple name:
      copy "{audio_path}" "./audio_simple.wav"

   2. Use forward slashes in the path:
      python diarize1.py -a "C:/path/to/your/audio.wav" --whisper-model large-v2

   3. Run WITHOUT speaker separation (standard diarization mode):
      python diarize1.py -a "{audio_path}" --whisper-model large-v2 --batch-size 4

   4. Check that your audio file is a valid WAV/MP3 file that can be played normally

Cannot continue with speaker separation mode. Please fix the issue or use standard diarization mode.
"""

                try:
                    print(error_msg)
                except UnicodeEncodeError:
                    print("Speaker separation failed for all models. Please check the logs for details.")
                logging.error("Speaker separation failed for all models. Terminating separation mode.")
                raise RuntimeError(
                    f"Speaker separation failed: All SpeechBrain models could not process the audio file '{audio_path}'. "
                    "Check file path, format, and permissions. Consider using standard diarization mode instead (remove --num-speakers argument)."
                )

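# Usage sketch (hypothetical file names, not executed here): assuming a 2-speaker WAV
# and a CUDA device, the call below would return the 16 kHz stems written to temp_dir,
# e.g. ["temp_outputs/separated_speaker_1.wav", "temp_outputs/separated_speaker_2.wav"].
#
#   separated = separate_speakers(
#       "meeting.wav", num_speakers=2, temp_dir="temp_outputs", device="cuda"
#   )
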
def process_separated_audio(audio_path, speaker_id, args, language, temp_dir):
    """
    Process a single separated audio file: transcription only, no VAD/diarization.
    """
    print(f"Processing separated audio for Speaker {speaker_id}")

    mtypes = {"cpu": "int8", "cuda": "float16"}

    # Optionally isolate vocals with Demucs before transcription.
    if args.stemming:
        vocal_target_dir = os.path.join(temp_dir, f"speaker_{speaker_id}_demucs")
        return_code = os.system(
            f'python -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "{vocal_target_dir}" --device "{args.device}"'
        )

        if return_code != 0:
            logging.warning(f"Demucs failed for speaker {speaker_id}, using original separated audio")
            vocal_target = audio_path
        else:
            vocal_target = os.path.join(
                vocal_target_dir,
                "htdemucs",
                os.path.splitext(os.path.basename(audio_path))[0],
                "vocals.wav",
            )
    else:
        vocal_target = audio_path

    # Transcribe with faster-whisper.
    whisper_model = faster_whisper.WhisperModel(
        args.model_name, device=args.device, compute_type=mtypes[args.device]
    )
    whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)
    audio_waveform = faster_whisper.decode_audio(vocal_target)
    suppress_tokens = (
        find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
        if args.suppress_numerals
        else [-1]
    )

    if args.batch_size > 0:
        transcript_segments, info = whisper_pipeline.transcribe(
            audio_waveform,
            language,
            suppress_tokens=suppress_tokens,
            batch_size=args.batch_size,
        )
    else:
        transcript_segments, info = whisper_model.transcribe(
            audio_waveform,
            language,
            suppress_tokens=suppress_tokens,
            vad_filter=True,
        )

    # transcribe() returns a generator; materialize it once so the segments can be
    # reused both for the full transcript and for the timestamp fallback below.
    segments_list = list(transcript_segments)
    full_transcript = "".join(segment.text for segment in segments_list)

    # Release Whisper before loading the alignment model.
    del whisper_model, whisper_pipeline
    torch.cuda.empty_cache()

    # Forced alignment for word-level timestamps.
    try:
        alignment_model, alignment_tokenizer = load_alignment_model(
            args.device,
            dtype=torch.float16 if args.device == "cuda" else torch.float32,
        )

        emissions, stride = generate_emissions(
            alignment_model,
            torch.from_numpy(audio_waveform)
            .to(alignment_model.dtype)
            .to(alignment_model.device),
            batch_size=args.batch_size,
        )

        del alignment_model
        torch.cuda.empty_cache()

        tokens_starred, text_starred = preprocess_text(
            full_transcript,
            romanize=True,
            language=langs_to_iso[info.language],
        )

        segments, scores, blank_token = get_alignments(
            emissions,
            tokens_starred,
            alignment_tokenizer,
        )

        spans = get_spans(tokens_starred, segments, blank_token)
        word_timestamps = postprocess_results(text_starred, spans, stride, scores)

        print(f"Forced alignment completed for speaker {speaker_id}")

    except Exception as e:
        print(f"Forced alignment failed for speaker {speaker_id}: {e}")
        print("Using Whisper segment timestamps instead")

        # Fallback: spread each segment's words evenly across the segment duration.
        word_timestamps = []
        for segment in segments_list:
            words = segment.text.strip().split()
            if words:
                segment_duration = segment.end - segment.start
                word_duration = segment_duration / len(words)

                for i, word in enumerate(words):
                    word_start = segment.start + (i * word_duration)
                    word_end = word_start + word_duration
                    word_timestamps.append(
                        {"text": word, "start": word_start, "end": word_end}
                    )

    # Restore punctuation if a punctuation model is available for this language.
    if info.language in punct_model_langs:
        try:
            from deepmultilingualpunctuation import PunctuationModel

            punct_model = PunctuationModel(model="kredor/punctuate-all")

            words_list = [wt["text"] for wt in word_timestamps]
            labeled_words = punct_model.predict(words_list, chunk_size=230)

            ending_puncts = ".?!"
            model_puncts = ".,;:!?"
            is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

            for word_dict, labeled_tuple in zip(word_timestamps, labeled_words):
                word = word_dict["text"]
                if (
                    word
                    and labeled_tuple[1] in ending_puncts
                    and (word[-1] not in model_puncts or is_acronym(word))
                ):
                    word += labeled_tuple[1]
                    if word.endswith(".."):
                        word = word.rstrip(".")
                    word_dict["text"] = word

            print(f"Punctuation restoration completed for speaker {speaker_id}")

        except Exception as e:
            print(f"Punctuation restoration failed for speaker {speaker_id}: {e}")

    # Group words into sentences for this single speaker.
    sentences = []
    current_sentence = {
        "speaker": f"Speaker {speaker_id}",
        "start_time": 0,
        "end_time": 0,
        "text": "",
    }

    sentence_endings = ".?!"

    for i, word_data in enumerate(word_timestamps):
        word = word_data["text"]
        start_ms = int(word_data["start"] * 1000)
        end_ms = int(word_data["end"] * 1000)

        if i == 0:
            current_sentence["start_time"] = start_ms

        current_sentence["end_time"] = end_ms
        current_sentence["text"] += word + " "

        if any(word.endswith(punct) for punct in sentence_endings):
            sentences.append(current_sentence.copy())
            current_sentence = {
                "speaker": f"Speaker {speaker_id}",
                "start_time": end_ms,
                "end_time": end_ms,
                "text": "",
            }

    if current_sentence["text"].strip():
        sentences.append(current_sentence)

    # If nothing could be segmented, emit a single sentence spanning the whole audio.
    if not sentences:
        audio_duration_ms = int(len(audio_waveform) / 16000 * 1000)
        sentences = [
            {
                "speaker": f"Speaker {speaker_id}",
                "start_time": 0,
                "end_time": audio_duration_ms,
                "text": full_transcript.strip(),
            }
        ]

    print(f"Created {len(sentences)} sentences for speaker {speaker_id}")

    return {
        "speaker_id": speaker_id,
        "sentences": sentences,
        "language": info.language,
    }

def create_config(output_dir):
    DOMAIN_TYPE = "telephonic"
    CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG_PATH):
        os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
        CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
        MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)

    config = OmegaConf.load(MODEL_CONFIG_PATH)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    # Single-entry manifest pointing at the mono audio to diarize.
    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"
    config.num_workers = 0
    config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
    config.diarizer.out_dir = output_dir

    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = False
    config.diarizer.clustering.parameters.oracle_num_speakers = False

    # Voice activity detection (VAD) parameters.
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05
    config.diarizer.msdd_model.model_path = "diar_msdd_telephonic"

    return config

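# Hedged usage sketch (the actual call site lives outside this section): the OmegaConf
# object returned above is typically handed to NeMo's NeuralDiarizer, which reads the
# manifest written into output_dir/data. `temp_path` and `device` are placeholder names.
#
#   msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(device)
#   msdd_model.diarize()
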
def get_word_ts_anchor(s, e, option="start"):
    if option == "end":
        return e
    elif option == "mid":
        return (s + e) / 2
    return s

def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    s, e, sp = spk_ts[0]
    wrd_pos, turn_idx = 0, 0
    wrd_spk_mapping = []
    for wrd_dict in wrd_ts:
        ws, we, wrd = (
            int(wrd_dict["start"] * 1000),
            int(wrd_dict["end"] * 1000),
            wrd_dict["text"],
        )
        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
        while wrd_pos > float(e):
            turn_idx += 1
            turn_idx = min(turn_idx, len(spk_ts) - 1)
            s, e, sp = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                e = get_word_ts_anchor(ws, we, option="end")
        wrd_spk_mapping.append(
            {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
        )
    return wrd_spk_mapping

sentence_ending_punctuations = ".?!"

def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1

def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )

def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list

def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    snts.append(snt)
    return snts

def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    previous_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{previous_speaker}: ")

    for sentence_dict in sentences_speaker_mapping:
        speaker = sentence_dict["speaker"]
        sentence = sentence_dict["text"]

        # If the speaker changed, start a new paragraph with the new speaker label.
        if speaker != previous_speaker:
            f.write(f"\n\n{speaker}: ")
            previous_speaker = speaker

        # In any case, append the current sentence.
        f.write(sentence + " ")

def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )

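# Illustrative check of the formatting above (not executed): 3_723_004 ms is
# 1 h 2 min 3.004 s, so with SRT-style arguments
#   format_timestamp(3_723_004, always_include_hours=True, decimal_marker=",")
# returns "01:02:03,004".
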
def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.
    """
    for i, segment in enumerate(transcript, start=1):
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )

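# Example of the SRT block emitted for one sentence dict (times in milliseconds):
#   {"start_time": 0, "end_time": 1500, "speaker": "Speaker 0", "text": "Hello there. "}
# produces:
#   1
#   00:00:00,000 --> 00:00:01,500
#   Speaker 0: Hello there.
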
def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = [
        -1,
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens

def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
    # if current word is the last word
    if current_word_index == len(word_timestamps) - 1:
        return word_timestamps[current_word_index]["start"]

    next_word_index = current_word_index + 1
    while current_word_index < len(word_timestamps) - 1:
        if word_timestamps[next_word_index].get("start") is None:
            # if the next word has no start timestamp,
            # merge it into the current word and mark it for removal
            word_timestamps[current_word_index]["word"] += (
                " " + word_timestamps[next_word_index]["word"]
            )

            word_timestamps[next_word_index]["word"] = None
            next_word_index += 1
            if next_word_index == len(word_timestamps):
                return final_timestamp

        else:
            return word_timestamps[next_word_index]["start"]

def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    # handle the first word separately
    if word_timestamps[0].get("start") is None:
        word_timestamps[0]["start"] = (
            initial_timestamp if initial_timestamp is not None else 0
        )
        word_timestamps[0]["end"] = _get_next_start_timestamp(
            word_timestamps, 0, final_timestamp
        )

    result = [
        word_timestamps[0],
    ]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if a word has no start/end, use the previous word's end as its start
        # and the next available start as its end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)

        if ws["word"] is not None:
            result.append(ws)
    return result

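# Illustrative example (assumed input shape, not part of the pipeline): given
#   [{"word": "hello", "start": 0.0, "end": 0.4},
#    {"word": "there", "start": None, "end": None},
#    {"word": "world", "start": 1.0, "end": 1.3}]
# filter_missing_timestamps fills the gap from its neighbours, yielding
#   [{"word": "hello", "start": 0.0, "end": 0.4},
#    {"word": "there", "start": 0.4, "end": 1.0},
#    {"word": "world", "start": 1.0, "end": 1.3}]
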
def cleanup(path: str):
    """path could either be relative or absolute."""
    if os.path.isfile(path) or os.path.islink(path):
        # remove file
        os.remove(path)
    elif os.path.isdir(path):
        # remove directory and all its contents
        shutil.rmtree(path)
    else:
        raise ValueError(f"Path {path} is not a file or dir.")

def process_language_arg(language: str, model_name: str):
    """
    Process the language argument to make sure it's valid
    and convert language names to language codes.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if model_name.endswith(".en") and language != "en":
        raise ValueError(
            f"{model_name} is an English-only model but chosen language is '{language}'"
        )

    return language
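
# Examples of the conversion performed above (not executed):
#   process_language_arg("German", "large-v2")    -> "de"
#   process_language_arg("Castilian", "large-v2") -> "es"
#   process_language_arg("fr", "small.en")        -> raises ValueError (English-only model)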