import json import os import shutil import logging import regex as re import nltk import wget import torch import torchaudio import faster_whisper from speechbrain.inference.separation import SepformerSeparation from omegaconf import OmegaConf # Import for forced alignment from ctc_forced_aligner import ( generate_emissions, get_alignments, get_spans, load_alignment_model, postprocess_results, preprocess_text, ) # Import for diarization from nemo.collections.asr.models.msdd_models import NeuralDiarizer punct_model_langs = [ "en", "fr", "de", "es", "it", "nl", "pt", "bg", "pl", "cs", "sk", "sl", ] LANGUAGES = { "en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese", } # language code lookup by name, with a few language aliases TO_LANGUAGE_CODE = { **{language: code for code, language in LANGUAGES.items()}, "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si", "castilian": "es", } whisper_langs = sorted(LANGUAGES.keys()) + sorted( [k.title() for k in TO_LANGUAGE_CODE.keys()] ) langs_to_iso = { "af": "afr", "am": "amh", "ar": "ara", "as": "asm", "az": "aze", "ba": "bak", "be": "bel", "bg": "bul", "bn": "ben", "bo": "tib", "br": "bre", "bs": "bos", "ca": "cat", "cs": "cze", "cy": "wel", "da": "dan", "de": "ger", "el": "gre", "en": "eng", "es": "spa", "et": "est", "eu": "baq", "fa": "per", "fi": "fin", "fo": "fao", "fr": "fre", "gl": "glg", "gu": "guj", "ha": "hau", "haw": "haw", "he": "heb", "hi": "hin", "hr": "hrv", "ht": "hat", "hu": "hun", "hy": "arm", "id": "ind", "is": "ice", "it": "ita", "ja": "jpn", "jw": "jav", "ka": "geo", "kk": "kaz", "km": "khm", "kn": "kan", "ko": "kor", "la": "lat", "lb": "ltz", "ln": "lin", "lo": "lao", "lt": "lit", "lv": "lav", "mg": "mlg", "mi": "mao", "mk": "mac", "ml": "mal", "mn": "mon", "mr": "mar", "ms": "may", "mt": "mlt", "my": "bur", "ne": "nep", "nl": "dut", "nn": "nno", "no": "nor", "oc": "oci", "pa": "pan", "pl": "pol", "ps": "pus", "pt": "por", "ro": "rum", "ru": "rus", "sa": "san", "sd": "snd", "si": "sin", "sk": "slo", "sl": "slv", "sn": "sna", "so": "som", "sq": "alb", "sr": "srp", "su": "sun", "sv": "swe", "sw": "swa", "ta": "tam", "te": "tel", "tg": "tgk", "th": "tha", "tk": "tuk", "tl": "tgl", "tr": "tur", "tt": "tat", "uk": "ukr", "ur": "urd", "uz": "uzb", "vi": "vie", "yi": "yid", "yo": "yor", "yue": "yue", "zh": "chi", } def separate_speakers(audio_path, num_speakers, temp_dir, device): """ Separate speakers using SpeechBrain SepFormer models """ # Validate and normalize the audio path original_audio_path = os.path.normpath(os.path.abspath(audio_path)) if not os.path.exists(original_audio_path): raise FileNotFoundError(f"Audio file not found: {original_audio_path}") print(f"Original audio file path: {original_audio_path}") print(f"File exists: {os.path.exists(original_audio_path)}") print(f"File size: {os.path.getsize(original_audio_path) if os.path.exists(original_audio_path) else 'N/A'} bytes") # Copy file to a simpler path to avoid Windows path issues with SpeechBrain # Use the temp_dir which should be in the working directory simple_audio_name = "input_audio.wav" simple_audio_path = os.path.join(temp_dir, simple_audio_name) simple_audio_path = os.path.normpath(simple_audio_path) # Ensure the temp directory exists print(f"Creating temp directory: {temp_dir}") os.makedirs(temp_dir, exist_ok=True) print(f"Temp directory exists: {os.path.exists(temp_dir)}") # Copy the file to avoid path corruption issues import shutil shutil.copy2(original_audio_path, simple_audio_path) print(f"Copied audio to simpler path: {simple_audio_path}") print(f"Simple path exists: {os.path.exists(simple_audio_path)}") # Try different path formats for SpeechBrain compatibility audio_path_for_speechbrain = simple_audio_path.replace('\\', '/') audio_path_relative = os.path.relpath(simple_audio_path) print(f"Path for SpeechBrain (forward slashes): {audio_path_for_speechbrain}") print(f"Path for SpeechBrain (relative): {audio_path_relative}") # Use the simpler path for processing audio_path = simple_audio_path print(f"Separating {num_speakers} speakers from audio...") # First, get the original audio sample rate original_waveform, original_sample_rate = torchaudio.load(audio_path) print(f"Original audio sample rate: {original_sample_rate}Hz") # Choose the appropriate model based on number of speakers if num_speakers == 2: model_name = "speechbrain/sepformer-libri2mix" fallback_model = "speechbrain/sepformer-wsj02mix" elif num_speakers == 3: model_name = "speechbrain/sepformer-wsj03mix" fallback_model = "speechbrain/sepformer-libri3mix" else: raise ValueError("Only 2 or 3 speakers are supported for separation") models_to_try = [model_name, fallback_model] if num_speakers in [2, 3] else [model_name] for model_attempt, current_model in enumerate(models_to_try): try: print(f"Trying separation model: {current_model}") # Load the separation model separator = SepformerSeparation.from_hparams( source=current_model, savedir=os.path.join(temp_dir, "sepformer_models"), run_opts={"device": device} ) # Separate the audio - try different path formats paths_to_try = [ audio_path_for_speechbrain, # Forward slashes audio_path_relative, # Relative path simple_audio_path # Original normalized path ] est_sources = None for path_attempt, path_to_try in enumerate(paths_to_try): try: print(f"Attempt {path_attempt + 1}: Calling SpeechBrain with path: {path_to_try}") est_sources = separator.separate_file(path=path_to_try) print(f"Success with path format {path_attempt + 1}") break except Exception as path_error: print(f"Path attempt {path_attempt + 1} failed: {path_error}") if path_attempt == len(paths_to_try) - 1: # This was the last attempt, re-raise the error raise print(f"Separated sources shape: {est_sources.shape}") # Check if we have the expected number of sources if len(est_sources.shape) == 3: actual_num_sources = est_sources.shape[2] elif len(est_sources.shape) == 2: actual_num_sources = 1 est_sources = est_sources.unsqueeze(2) else: raise ValueError(f"Unexpected tensor shape: {est_sources.shape}") print(f"Expected {num_speakers} speakers, got {actual_num_sources} separated sources") if actual_num_sources < num_speakers: logging.warning(f"Only {actual_num_sources} sources separated, but {num_speakers} were requested.") if actual_num_sources == 0: raise ValueError("No sources were separated") num_speakers_to_process = actual_num_sources else: num_speakers_to_process = num_speakers # Save separated audio files with proper sample rate handling separated_files = [] # Create resampler if needed target_sample_rate = original_sample_rate # Keep original sample rate for playback processing_sample_rate = 16000 # Use 16kHz for processing pipeline for i in range(num_speakers_to_process): # Paths for both versions separated_path_original = os.path.join(temp_dir, f"separated_speaker_{i+1}_original.wav") separated_path_processing = os.path.join(temp_dir, f"separated_speaker_{i+1}.wav") # Extract the source audio if actual_num_sources == 1: source_audio = est_sources[:, :, 0].cpu() else: source_audio = est_sources[:, :, i].cpu() # Ensure the audio is in the right format if source_audio.dim() == 1: source_audio = source_audio.unsqueeze(0) elif source_audio.dim() > 2: source_audio = source_audio[0:1, :] # The separated audio is at the original sample rate from SpeechBrain # SpeechBrain typically works at 8kHz for these models, but let's detect separation_sample_rate = 8000 # SpeechBrain Sepformer models typically use 8kHz # Save original quality version for listening if target_sample_rate != separation_sample_rate: resampler_to_original = torchaudio.transforms.Resample( orig_freq=separation_sample_rate, new_freq=target_sample_rate ) source_audio_original = resampler_to_original(source_audio) else: source_audio_original = source_audio torchaudio.save(separated_path_original, source_audio_original, target_sample_rate) print(f"Saved original quality audio for speaker {i+1}: {separated_path_original} ({target_sample_rate}Hz)") # Save processing version at 16kHz for transcription pipeline if processing_sample_rate != separation_sample_rate: resampler_to_processing = torchaudio.transforms.Resample( orig_freq=separation_sample_rate, new_freq=processing_sample_rate ) source_audio_processing = resampler_to_processing(source_audio) else: source_audio_processing = source_audio torchaudio.save(separated_path_processing, source_audio_processing, processing_sample_rate) print(f"Saved processing audio for speaker {i+1}: {separated_path_processing} ({processing_sample_rate}Hz)") # Use the processing version for the pipeline separated_files.append(separated_path_processing) if len(separated_files) == 0: raise ValueError("No audio sources could be separated") # If we have fewer speakers than requested, duplicate the last one while len(separated_files) < num_speakers: last_file = separated_files[-1] new_speaker_id = len(separated_files) + 1 duplicated_path_processing = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}.wav") duplicated_path_original = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}_original.wav") # Copy the last separated files waveform, sample_rate = torchaudio.load(last_file) torchaudio.save(duplicated_path_processing, waveform, sample_rate) # Also copy the original quality version last_original_file = last_file.replace(".wav", "_original.wav") if os.path.exists(last_original_file): waveform_orig, sample_rate_orig = torchaudio.load(last_original_file) torchaudio.save(duplicated_path_original, waveform_orig, sample_rate_orig) separated_files.append(duplicated_path_processing) print(f"Duplicated speaker audio for speaker {new_speaker_id}") return separated_files except Exception as e: logging.error(f"Model {current_model} failed: {e}") if model_attempt < len(models_to_try) - 1: print(f"Trying next model...") continue else: # All models failed - stop execution and provide helpful error message error_msg = f""" ╔══════════════════════════════════════════════════════════════════════════════╗ ║ SPEAKER SEPARATION FAILED ║ ╚══════════════════════════════════════════════════════════════════════════════╝ ❌ All SpeechBrain separation models failed to load or process your audio file. 🔍 Possible causes: • Audio file path contains special characters or spaces • Audio file format is not supported by SpeechBrain models • File permissions issue • Audio file is corrupted or too short 💡 Solutions: 1. Try copying your audio file to the current directory with a simple name: copy "{audio_path}" "./audio_simple.wav" 2. Use forward slashes in the path: python diarize1.py -a "C:/path/to/your/audio.wav" --whisper-model large-v2 3. Run WITHOUT speaker separation (standard diarization mode): python diarize1.py -a "{audio_path}" --whisper-model large-v2 --batch-size 4 4. Check that your audio file is a valid WAV/MP3 file that can be played normally Cannot continue with speaker separation mode. Please fix the issue or use standard diarization mode. """ # Use safe printing to avoid Unicode encoding errors try: print(error_msg) except UnicodeEncodeError: print("Speaker separation failed for all models. Please check the logs for details.") logging.error("Speaker separation failed for all models. Terminating separation mode.") raise RuntimeError( f"Speaker separation failed: All SpeechBrain models could not process the audio file '{audio_path}'. " f"Check file path, format, and permissions. Consider using standard diarization mode instead (remove --num-speakers argument)." ) def process_separated_audio(audio_path, speaker_id, args, language, temp_dir): """ Process a single separated audio file - TRANSCRIPTION ONLY, NO VAD/DIARIZATION """ print(f"Processing separated audio for Speaker {speaker_id}") mtypes = {"cpu": "int8", "cuda": "float16"} # Apply Demucs if enabled if args.stemming: vocal_target_dir = os.path.join(temp_dir, f"speaker_{speaker_id}_demucs") return_code = os.system( f'python -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "{vocal_target_dir}" --device "{args.device}"' ) if return_code != 0: logging.warning(f"Demucs failed for speaker {speaker_id}, using original separated audio") vocal_target = audio_path else: vocal_target = os.path.join( vocal_target_dir, "htdemucs", os.path.splitext(os.path.basename(audio_path))[0], "vocals.wav", ) else: vocal_target = audio_path # Transcribe the audio file whisper_model = faster_whisper.WhisperModel( args.model_name, device=args.device, compute_type=mtypes[args.device] ) whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) audio_waveform = faster_whisper.decode_audio(vocal_target) suppress_tokens = ( find_numeral_symbol_tokens(whisper_model.hf_tokenizer) if args.suppress_numerals else [-1] ) if args.batch_size > 0: transcript_segments, info = whisper_pipeline.transcribe( audio_waveform, language, suppress_tokens=suppress_tokens, batch_size=args.batch_size, ) else: transcript_segments, info = whisper_model.transcribe( audio_waveform, language, suppress_tokens=suppress_tokens, vad_filter=True, ) full_transcript = "".join(segment.text for segment in transcript_segments) # Convert transcript segments to list for easier processing segments_list = list(transcript_segments) # Clear GPU VRAM del whisper_model, whisper_pipeline torch.cuda.empty_cache() # Forced Alignment (optional, for word-level timestamps) try: alignment_model, alignment_tokenizer = load_alignment_model( args.device, dtype=torch.float16 if args.device == "cuda" else torch.float32, ) emissions, stride = generate_emissions( alignment_model, torch.from_numpy(audio_waveform) .to(alignment_model.dtype) .to(alignment_model.device), batch_size=args.batch_size, ) del alignment_model torch.cuda.empty_cache() tokens_starred, text_starred = preprocess_text( full_transcript, romanize=True, language=langs_to_iso[info.language], ) segments, scores, blank_token = get_alignments( emissions, tokens_starred, alignment_tokenizer, ) spans = get_spans(tokens_starred, segments, blank_token) word_timestamps = postprocess_results(text_starred, spans, stride, scores) print(f"Forced alignment completed for speaker {speaker_id}") except Exception as e: print(f"Forced alignment failed for speaker {speaker_id}: {e}") print("Using Whisper segment timestamps instead") # Fallback to Whisper's segment-level timestamps word_timestamps = [] for segment in segments_list: # Create a simple word-level timestamp from segment words = segment.text.strip().split() if words: segment_duration = segment.end - segment.start word_duration = segment_duration / len(words) for i, word in enumerate(words): word_start = segment.start + (i * word_duration) word_end = word_start + word_duration word_timestamps.append({ "text": word, "start": word_start, "end": word_end }) # Add punctuation if available if info.language in punct_model_langs: try: from deepmultilingualpunctuation import PunctuationModel punct_model = PunctuationModel(model="kredor/punctuate-all") words_list = [wt["text"] for wt in word_timestamps] labled_words = punct_model.predict(words_list, chunk_size=230) ending_puncts = ".?!" model_puncts = ".,;:!?" is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x) for word_dict, labeled_tuple in zip(word_timestamps, labled_words): word = word_dict["text"] if ( word and labeled_tuple[1] in ending_puncts and (word[-1] not in model_puncts or is_acronym(word)) ): word += labeled_tuple[1] if word.endswith(".."): word = word.rstrip(".") word_dict["text"] = word print(f"Punctuation restoration completed for speaker {speaker_id}") except Exception as e: print(f"Punctuation restoration failed for speaker {speaker_id}: {e}") # Create simple sentences from word timestamps # Since we skipped VAD/diarization, all words belong to this speaker sentences = [] current_sentence = { "speaker": f"Speaker {speaker_id}", "start_time": 0, "end_time": 0, "text": "" } sentence_endings = ".?!" for i, word_data in enumerate(word_timestamps): word = word_data["text"] start_ms = int(word_data["start"] * 1000) end_ms = int(word_data["end"] * 1000) if i == 0: current_sentence["start_time"] = start_ms current_sentence["end_time"] = end_ms current_sentence["text"] += word + " " # Check if this word ends a sentence if any(word.endswith(punct) for punct in sentence_endings): sentences.append(current_sentence.copy()) current_sentence = { "speaker": f"Speaker {speaker_id}", "start_time": end_ms, "end_time": end_ms, "text": "" } # Add the last sentence if it has content if current_sentence["text"].strip(): sentences.append(current_sentence) # If no sentences were created, create one from the full transcript if not sentences: audio_duration_ms = int(len(audio_waveform) / 16000 * 1000) sentences = [{ "speaker": f"Speaker {speaker_id}", "start_time": 0, "end_time": audio_duration_ms, "text": full_transcript.strip() }] print(f"Created {len(sentences)} sentences for speaker {speaker_id}") return { "speaker_id": speaker_id, "sentences": sentences, "language": info.language } def create_config(output_dir): DOMAIN_TYPE = "telephonic" CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs" CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml" MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME) if not os.path.exists(MODEL_CONFIG_PATH): os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True) CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}" MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH) config = OmegaConf.load(MODEL_CONFIG_PATH) data_dir = os.path.join(output_dir, "data") os.makedirs(data_dir, exist_ok=True) meta = { "audio_filepath": os.path.join(output_dir, "mono_file.wav"), "offset": 0, "duration": None, "label": "infer", "text": "-", "rttm_filepath": None, "uem_filepath": None, } with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp: json.dump(meta, fp) fp.write("\n") pretrained_vad = "vad_multilingual_marblenet" pretrained_speaker_model = "titanet_large" config.num_workers = 0 config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json") config.diarizer.out_dir = ( output_dir # Directory to store intermediate files and prediction outputs ) config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model config.diarizer.oracle_vad = ( False # compute VAD provided with model_path to vad config ) config.diarizer.clustering.parameters.oracle_num_speakers = False # Here, we use our in-house pretrained NeMo VAD model config.diarizer.vad.model_path = pretrained_vad config.diarizer.vad.parameters.onset = 0.8 config.diarizer.vad.parameters.offset = 0.6 config.diarizer.vad.parameters.pad_offset = -0.05 config.diarizer.msdd_model.model_path = ( "diar_msdd_telephonic" # Telephonic speaker diarization model ) return config def get_word_ts_anchor(s, e, option="start"): if option == "end": return e elif option == "mid": return (s + e) / 2 return s def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"): s, e, sp = spk_ts[0] wrd_pos, turn_idx = 0, 0 wrd_spk_mapping = [] for wrd_dict in wrd_ts: ws, we, wrd = ( int(wrd_dict["start"] * 1000), int(wrd_dict["end"] * 1000), wrd_dict["text"], ) wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option) while wrd_pos > float(e): turn_idx += 1 turn_idx = min(turn_idx, len(spk_ts) - 1) s, e, sp = spk_ts[turn_idx] if turn_idx == len(spk_ts) - 1: e = get_word_ts_anchor(ws, we, option="end") wrd_spk_mapping.append( {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp} ) return wrd_spk_mapping sentence_ending_punctuations = ".?!" def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words): is_word_sentence_end = ( lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations ) left_idx = word_idx while ( left_idx > 0 and word_idx - left_idx < max_words and speaker_list[left_idx - 1] == speaker_list[left_idx] and not is_word_sentence_end(left_idx - 1) ): left_idx -= 1 return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1 def get_last_word_idx_of_sentence(word_idx, word_list, max_words): is_word_sentence_end = ( lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations ) right_idx = word_idx while ( right_idx < len(word_list) - 1 and right_idx - word_idx < max_words and not is_word_sentence_end(right_idx) ): right_idx += 1 return ( right_idx if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx) else -1 ) def get_realigned_ws_mapping_with_punctuation( word_speaker_mapping, max_words_in_sentence=50 ): is_word_sentence_end = ( lambda x: x >= 0 and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations ) wsp_len = len(word_speaker_mapping) words_list, speaker_list = [], [] for k, line_dict in enumerate(word_speaker_mapping): word, speaker = line_dict["word"], line_dict["speaker"] words_list.append(word) speaker_list.append(speaker) k = 0 while k < len(word_speaker_mapping): line_dict = word_speaker_mapping[k] if ( k < wsp_len - 1 and speaker_list[k] != speaker_list[k + 1] and not is_word_sentence_end(k) ): left_idx = get_first_word_idx_of_sentence( k, words_list, speaker_list, max_words_in_sentence ) right_idx = ( get_last_word_idx_of_sentence( k, words_list, max_words_in_sentence - k + left_idx - 1 ) if left_idx > -1 else -1 ) if min(left_idx, right_idx) == -1: k += 1 continue spk_labels = speaker_list[left_idx : right_idx + 1] mod_speaker = max(set(spk_labels), key=spk_labels.count) if spk_labels.count(mod_speaker) < len(spk_labels) // 2: k += 1 continue speaker_list[left_idx : right_idx + 1] = [mod_speaker] * ( right_idx - left_idx + 1 ) k = right_idx k += 1 k, realigned_list = 0, [] while k < len(word_speaker_mapping): line_dict = word_speaker_mapping[k].copy() line_dict["speaker"] = speaker_list[k] realigned_list.append(line_dict) k += 1 return realigned_list def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts): sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak s, e, spk = spk_ts[0] prev_spk = spk snts = [] snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""} for wrd_dict in word_speaker_mapping: wrd, spk = wrd_dict["word"], wrd_dict["speaker"] s, e = wrd_dict["start_time"], wrd_dict["end_time"] if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd): snts.append(snt) snt = { "speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": "", } else: snt["end_time"] = e snt["text"] += wrd + " " prev_spk = spk snts.append(snt) return snts def get_speaker_aware_transcript(sentences_speaker_mapping, f): previous_speaker = sentences_speaker_mapping[0]["speaker"] f.write(f"{previous_speaker}: ") for sentence_dict in sentences_speaker_mapping: speaker = sentence_dict["speaker"] sentence = sentence_dict["text"] # If this speaker doesn't match the previous one, start a new paragraph if speaker != previous_speaker: f.write(f"\n\n{speaker}: ") previous_speaker = speaker # No matter what, write the current sentence f.write(sentence + " ") def format_timestamp( milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "." ): assert milliseconds >= 0, "non-negative timestamp expected" hours = milliseconds // 3_600_000 milliseconds -= hours * 3_600_000 minutes = milliseconds // 60_000 milliseconds -= minutes * 60_000 seconds = milliseconds // 1_000 milliseconds -= seconds * 1_000 hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" return ( f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" ) def write_srt(transcript, file): """ Write a transcript to a file in SRT format. """ for i, segment in enumerate(transcript, start=1): # write srt lines print( f"{i}\n" f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> " f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n" f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n", file=file, flush=True, ) def find_numeral_symbol_tokens(tokenizer): numeral_symbol_tokens = [ -1, ] for token, token_id in tokenizer.get_vocab().items(): has_numeral_symbol = any(c in "0123456789%$£" for c in token) if has_numeral_symbol: numeral_symbol_tokens.append(token_id) return numeral_symbol_tokens def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp): # if current word is the last word if current_word_index == len(word_timestamps) - 1: return word_timestamps[current_word_index]["start"] next_word_index = current_word_index + 1 while current_word_index < len(word_timestamps) - 1: if word_timestamps[next_word_index].get("start") is None: # if next word doesn't have a start timestamp # merge it with the current word and delete it word_timestamps[current_word_index]["word"] += ( " " + word_timestamps[next_word_index]["word"] ) word_timestamps[next_word_index]["word"] = None next_word_index += 1 if next_word_index == len(word_timestamps): return final_timestamp else: return word_timestamps[next_word_index]["start"] def filter_missing_timestamps( word_timestamps, initial_timestamp=0, final_timestamp=None ): # handle the first and last word if word_timestamps[0].get("start") is None: word_timestamps[0]["start"] = ( initial_timestamp if initial_timestamp is not None else 0 ) word_timestamps[0]["end"] = _get_next_start_timestamp( word_timestamps, 0, final_timestamp ) result = [ word_timestamps[0], ] for i, ws in enumerate(word_timestamps[1:], start=1): # if ws doesn't have a start and end # use the previous end as start and next start as end if ws.get("start") is None and ws.get("word") is not None: ws["start"] = word_timestamps[i - 1]["end"] ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp) if ws["word"] is not None: result.append(ws) return result def cleanup(path: str): """path could either be relative or absolute.""" # check if file or directory exists if os.path.isfile(path) or os.path.islink(path): # remove file os.remove(path) elif os.path.isdir(path): # remove directory and all its content shutil.rmtree(path) else: raise ValueError(f"Path {path} is not a file or dir.") def process_language_arg(language: str, model_name: str): """ Process the language argument to make sure it's valid and convert language names to language codes. """ if language is not None: language = language.lower() if language not in LANGUAGES: if language in TO_LANGUAGE_CODE: language = TO_LANGUAGE_CODE[language] else: raise ValueError(f"Unsupported language: {language}") if model_name.endswith(".en") and language != "en": raise ValueError( f"{model_name} is an English-only model but choosen language is '{language}'" ) return language