import json
import os
import shutil
import logging
import regex as re
import nltk
import wget
import torch
import torchaudio
import faster_whisper
from speechbrain.inference.separation import SepformerSeparation
from omegaconf import OmegaConf
# Import for forced alignment
from ctc_forced_aligner import (
generate_emissions,
get_alignments,
get_spans,
load_alignment_model,
postprocess_results,
preprocess_text,
)
# Import for diarization
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
punct_model_langs = [
"en",
"fr",
"de",
"es",
"it",
"nl",
"pt",
"bg",
"pl",
"cs",
"sk",
"sl",
]
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
}
whisper_langs = sorted(LANGUAGES.keys()) + sorted(
[k.title() for k in TO_LANGUAGE_CODE.keys()]
)
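# Example (sketch): full language names and aliases both resolve to Whisper codes,
# e.g. TO_LANGUAGE_CODE["english"] == "en" and TO_LANGUAGE_CODE["castilian"] == "es".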
langs_to_iso = {
"af": "afr",
"am": "amh",
"ar": "ara",
"as": "asm",
"az": "aze",
"ba": "bak",
"be": "bel",
"bg": "bul",
"bn": "ben",
"bo": "tib",
"br": "bre",
"bs": "bos",
"ca": "cat",
"cs": "cze",
"cy": "wel",
"da": "dan",
"de": "ger",
"el": "gre",
"en": "eng",
"es": "spa",
"et": "est",
"eu": "baq",
"fa": "per",
"fi": "fin",
"fo": "fao",
"fr": "fre",
"gl": "glg",
"gu": "guj",
"ha": "hau",
"haw": "haw",
"he": "heb",
"hi": "hin",
"hr": "hrv",
"ht": "hat",
"hu": "hun",
"hy": "arm",
"id": "ind",
"is": "ice",
"it": "ita",
"ja": "jpn",
"jw": "jav",
"ka": "geo",
"kk": "kaz",
"km": "khm",
"kn": "kan",
"ko": "kor",
"la": "lat",
"lb": "ltz",
"ln": "lin",
"lo": "lao",
"lt": "lit",
"lv": "lav",
"mg": "mlg",
"mi": "mao",
"mk": "mac",
"ml": "mal",
"mn": "mon",
"mr": "mar",
"ms": "may",
"mt": "mlt",
"my": "bur",
"ne": "nep",
"nl": "dut",
"nn": "nno",
"no": "nor",
"oc": "oci",
"pa": "pan",
"pl": "pol",
"ps": "pus",
"pt": "por",
"ro": "rum",
"ru": "rus",
"sa": "san",
"sd": "snd",
"si": "sin",
"sk": "slo",
"sl": "slv",
"sn": "sna",
"so": "som",
"sq": "alb",
"sr": "srp",
"su": "sun",
"sv": "swe",
"sw": "swa",
"ta": "tam",
"te": "tel",
"tg": "tgk",
"th": "tha",
"tk": "tuk",
"tl": "tgl",
"tr": "tur",
"tt": "tat",
"uk": "ukr",
"ur": "urd",
"uz": "uzb",
"vi": "vie",
"yi": "yid",
"yo": "yor",
"yue": "yue",
"zh": "chi",
}
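# langs_to_iso maps Whisper's two-letter codes to the three-letter codes passed to the
# forced aligner's preprocess_text() below (e.g. "de" -> "ger", "fa" -> "per").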
def separate_speakers(audio_path, num_speakers, temp_dir, device):
"""
Separate speakers using SpeechBrain SepFormer models
"""
# Validate and normalize the audio path
original_audio_path = os.path.normpath(os.path.abspath(audio_path))
if not os.path.exists(original_audio_path):
raise FileNotFoundError(f"Audio file not found: {original_audio_path}")
print(f"Original audio file path: {original_audio_path}")
print(f"File exists: {os.path.exists(original_audio_path)}")
print(f"File size: {os.path.getsize(original_audio_path) if os.path.exists(original_audio_path) else 'N/A'} bytes")
# Copy file to a simpler path to avoid Windows path issues with SpeechBrain
# Use the temp_dir which should be in the working directory
simple_audio_name = "input_audio.wav"
simple_audio_path = os.path.join(temp_dir, simple_audio_name)
simple_audio_path = os.path.normpath(simple_audio_path)
# Ensure the temp directory exists
print(f"Creating temp directory: {temp_dir}")
os.makedirs(temp_dir, exist_ok=True)
print(f"Temp directory exists: {os.path.exists(temp_dir)}")
    # Copy the file to avoid path corruption issues (shutil is already imported at module level)
    shutil.copy2(original_audio_path, simple_audio_path)
print(f"Copied audio to simpler path: {simple_audio_path}")
print(f"Simple path exists: {os.path.exists(simple_audio_path)}")
# Try different path formats for SpeechBrain compatibility
audio_path_for_speechbrain = simple_audio_path.replace('\\', '/')
audio_path_relative = os.path.relpath(simple_audio_path)
print(f"Path for SpeechBrain (forward slashes): {audio_path_for_speechbrain}")
print(f"Path for SpeechBrain (relative): {audio_path_relative}")
# Use the simpler path for processing
audio_path = simple_audio_path
print(f"Separating {num_speakers} speakers from audio...")
# First, get the original audio sample rate
original_waveform, original_sample_rate = torchaudio.load(audio_path)
print(f"Original audio sample rate: {original_sample_rate}Hz")
# Choose the appropriate model based on number of speakers
if num_speakers == 2:
model_name = "speechbrain/sepformer-libri2mix"
fallback_model = "speechbrain/sepformer-wsj02mix"
elif num_speakers == 3:
model_name = "speechbrain/sepformer-wsj03mix"
fallback_model = "speechbrain/sepformer-libri3mix"
else:
raise ValueError("Only 2 or 3 speakers are supported for separation")
    models_to_try = [model_name, fallback_model]
for model_attempt, current_model in enumerate(models_to_try):
try:
print(f"Trying separation model: {current_model}")
# Load the separation model
separator = SepformerSeparation.from_hparams(
source=current_model,
savedir=os.path.join(temp_dir, "sepformer_models"),
run_opts={"device": device}
)
# Separate the audio - try different path formats
paths_to_try = [
audio_path_for_speechbrain, # Forward slashes
audio_path_relative, # Relative path
simple_audio_path # Original normalized path
]
est_sources = None
for path_attempt, path_to_try in enumerate(paths_to_try):
try:
print(f"Attempt {path_attempt + 1}: Calling SpeechBrain with path: {path_to_try}")
est_sources = separator.separate_file(path=path_to_try)
print(f"Success with path format {path_attempt + 1}")
break
except Exception as path_error:
print(f"Path attempt {path_attempt + 1} failed: {path_error}")
if path_attempt == len(paths_to_try) - 1:
# This was the last attempt, re-raise the error
raise
print(f"Separated sources shape: {est_sources.shape}")
# Check if we have the expected number of sources
if len(est_sources.shape) == 3:
actual_num_sources = est_sources.shape[2]
elif len(est_sources.shape) == 2:
actual_num_sources = 1
est_sources = est_sources.unsqueeze(2)
else:
raise ValueError(f"Unexpected tensor shape: {est_sources.shape}")
print(f"Expected {num_speakers} speakers, got {actual_num_sources} separated sources")
if actual_num_sources < num_speakers:
logging.warning(f"Only {actual_num_sources} sources separated, but {num_speakers} were requested.")
if actual_num_sources == 0:
raise ValueError("No sources were separated")
num_speakers_to_process = actual_num_sources
else:
num_speakers_to_process = num_speakers
# Save separated audio files with proper sample rate handling
separated_files = []
# Create resampler if needed
target_sample_rate = original_sample_rate # Keep original sample rate for playback
processing_sample_rate = 16000 # Use 16kHz for processing pipeline
for i in range(num_speakers_to_process):
# Paths for both versions
separated_path_original = os.path.join(temp_dir, f"separated_speaker_{i+1}_original.wav")
separated_path_processing = os.path.join(temp_dir, f"separated_speaker_{i+1}.wav")
# Extract the source audio
if actual_num_sources == 1:
source_audio = est_sources[:, :, 0].cpu()
else:
source_audio = est_sources[:, :, i].cpu()
# Ensure the audio is in the right format
if source_audio.dim() == 1:
source_audio = source_audio.unsqueeze(0)
elif source_audio.dim() > 2:
source_audio = source_audio[0:1, :]
                # The SpeechBrain SepFormer checkpoints used here are 8 kHz models,
                # so treat the separated audio as 8 kHz before resampling below
                separation_sample_rate = 8000
# Save original quality version for listening
if target_sample_rate != separation_sample_rate:
resampler_to_original = torchaudio.transforms.Resample(
orig_freq=separation_sample_rate,
new_freq=target_sample_rate
)
source_audio_original = resampler_to_original(source_audio)
else:
source_audio_original = source_audio
torchaudio.save(separated_path_original, source_audio_original, target_sample_rate)
print(f"Saved original quality audio for speaker {i+1}: {separated_path_original} ({target_sample_rate}Hz)")
# Save processing version at 16kHz for transcription pipeline
if processing_sample_rate != separation_sample_rate:
resampler_to_processing = torchaudio.transforms.Resample(
orig_freq=separation_sample_rate,
new_freq=processing_sample_rate
)
source_audio_processing = resampler_to_processing(source_audio)
else:
source_audio_processing = source_audio
torchaudio.save(separated_path_processing, source_audio_processing, processing_sample_rate)
print(f"Saved processing audio for speaker {i+1}: {separated_path_processing} ({processing_sample_rate}Hz)")
# Use the processing version for the pipeline
separated_files.append(separated_path_processing)
if len(separated_files) == 0:
raise ValueError("No audio sources could be separated")
# If we have fewer speakers than requested, duplicate the last one
while len(separated_files) < num_speakers:
last_file = separated_files[-1]
new_speaker_id = len(separated_files) + 1
duplicated_path_processing = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}.wav")
duplicated_path_original = os.path.join(temp_dir, f"separated_speaker_{new_speaker_id}_original.wav")
# Copy the last separated files
waveform, sample_rate = torchaudio.load(last_file)
torchaudio.save(duplicated_path_processing, waveform, sample_rate)
# Also copy the original quality version
last_original_file = last_file.replace(".wav", "_original.wav")
if os.path.exists(last_original_file):
waveform_orig, sample_rate_orig = torchaudio.load(last_original_file)
torchaudio.save(duplicated_path_original, waveform_orig, sample_rate_orig)
separated_files.append(duplicated_path_processing)
print(f"Duplicated speaker audio for speaker {new_speaker_id}")
return separated_files
except Exception as e:
logging.error(f"Model {current_model} failed: {e}")
if model_attempt < len(models_to_try) - 1:
print(f"Trying next model...")
continue
else:
# All models failed - stop execution and provide helpful error message
error_msg = f"""
╔══════════════════════════════════════════════════════════════════════════════╗
║ SPEAKER SEPARATION FAILED ║
╚══════════════════════════════════════════════════════════════════════════════╝
❌ All SpeechBrain separation models failed to load or process your audio file.
🔍 Possible causes:
• Audio file path contains special characters or spaces
• Audio file format is not supported by SpeechBrain models
• File permissions issue
• Audio file is corrupted or too short
💡 Solutions:
1. Try copying your audio file to the current directory with a simple name:
copy "{audio_path}" "./audio_simple.wav"
2. Use forward slashes in the path:
python diarize1.py -a "C:/path/to/your/audio.wav" --whisper-model large-v2
3. Run WITHOUT speaker separation (standard diarization mode):
python diarize1.py -a "{audio_path}" --whisper-model large-v2 --batch-size 4
4. Check that your audio file is a valid WAV/MP3 file that can be played normally
Cannot continue with speaker separation mode. Please fix the issue or use standard diarization mode.
"""
# Use safe printing to avoid Unicode encoding errors
try:
print(error_msg)
except UnicodeEncodeError:
print("Speaker separation failed for all models. Please check the logs for details.")
logging.error("Speaker separation failed for all models. Terminating separation mode.")
raise RuntimeError(
f"Speaker separation failed: All SpeechBrain models could not process the audio file '{audio_path}'. "
f"Check file path, format, and permissions. Consider using standard diarization mode instead (remove --num-speakers argument)."
)
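# Example usage (hedged sketch): separating a hypothetical two-speaker recording into
# per-speaker WAVs. The function writes 16 kHz copies for the transcription pipeline
# plus *_original.wav copies resampled to the input rate for listening.
#
#   files = separate_speakers("meeting.wav", num_speakers=2,
#                             temp_dir="temp_outputs", device="cuda")
#   # files -> [".../separated_speaker_1.wav", ".../separated_speaker_2.wav"]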
def process_separated_audio(audio_path, speaker_id, args, language, temp_dir):
"""
Process a single separated audio file - TRANSCRIPTION ONLY, NO VAD/DIARIZATION
"""
print(f"Processing separated audio for Speaker {speaker_id}")
mtypes = {"cpu": "int8", "cuda": "float16"}
# Apply Demucs if enabled
if args.stemming:
vocal_target_dir = os.path.join(temp_dir, f"speaker_{speaker_id}_demucs")
return_code = os.system(
f'python -m demucs.separate -n htdemucs --two-stems=vocals "{audio_path}" -o "{vocal_target_dir}" --device "{args.device}"'
)
if return_code != 0:
logging.warning(f"Demucs failed for speaker {speaker_id}, using original separated audio")
vocal_target = audio_path
else:
vocal_target = os.path.join(
vocal_target_dir,
"htdemucs",
os.path.splitext(os.path.basename(audio_path))[0],
"vocals.wav",
)
else:
vocal_target = audio_path
# Transcribe the audio file
whisper_model = faster_whisper.WhisperModel(
args.model_name, device=args.device, compute_type=mtypes[args.device]
)
whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)
audio_waveform = faster_whisper.decode_audio(vocal_target)
suppress_tokens = (
find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
if args.suppress_numerals
else [-1]
)
if args.batch_size > 0:
transcript_segments, info = whisper_pipeline.transcribe(
audio_waveform,
language,
suppress_tokens=suppress_tokens,
batch_size=args.batch_size,
)
else:
transcript_segments, info = whisper_model.transcribe(
audio_waveform,
language,
suppress_tokens=suppress_tokens,
vad_filter=True,
)
    # faster-whisper returns a generator of segments; materialize it first so the
    # transcript text and the timestamp fallback below can both reuse it
    segments_list = list(transcript_segments)
    full_transcript = "".join(segment.text for segment in segments_list)
# Clear GPU VRAM
del whisper_model, whisper_pipeline
torch.cuda.empty_cache()
# Forced Alignment (optional, for word-level timestamps)
try:
alignment_model, alignment_tokenizer = load_alignment_model(
args.device,
dtype=torch.float16 if args.device == "cuda" else torch.float32,
)
emissions, stride = generate_emissions(
alignment_model,
torch.from_numpy(audio_waveform)
.to(alignment_model.dtype)
.to(alignment_model.device),
batch_size=args.batch_size,
)
del alignment_model
torch.cuda.empty_cache()
tokens_starred, text_starred = preprocess_text(
full_transcript,
romanize=True,
language=langs_to_iso[info.language],
)
segments, scores, blank_token = get_alignments(
emissions,
tokens_starred,
alignment_tokenizer,
)
spans = get_spans(tokens_starred, segments, blank_token)
word_timestamps = postprocess_results(text_starred, spans, stride, scores)
print(f"Forced alignment completed for speaker {speaker_id}")
except Exception as e:
print(f"Forced alignment failed for speaker {speaker_id}: {e}")
print("Using Whisper segment timestamps instead")
# Fallback to Whisper's segment-level timestamps
word_timestamps = []
for segment in segments_list:
# Create a simple word-level timestamp from segment
words = segment.text.strip().split()
if words:
segment_duration = segment.end - segment.start
word_duration = segment_duration / len(words)
for i, word in enumerate(words):
word_start = segment.start + (i * word_duration)
word_end = word_start + word_duration
word_timestamps.append({
"text": word,
"start": word_start,
"end": word_end
})
# Add punctuation if available
if info.language in punct_model_langs:
try:
from deepmultilingualpunctuation import PunctuationModel
punct_model = PunctuationModel(model="kredor/punctuate-all")
words_list = [wt["text"] for wt in word_timestamps]
            labeled_words = punct_model.predict(words_list, chunk_size=230)
            ending_puncts = ".?!"
            model_puncts = ".,;:!?"
            is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
            for word_dict, labeled_tuple in zip(word_timestamps, labeled_words):
word = word_dict["text"]
if (
word
and labeled_tuple[1] in ending_puncts
and (word[-1] not in model_puncts or is_acronym(word))
):
word += labeled_tuple[1]
if word.endswith(".."):
word = word.rstrip(".")
word_dict["text"] = word
print(f"Punctuation restoration completed for speaker {speaker_id}")
except Exception as e:
print(f"Punctuation restoration failed for speaker {speaker_id}: {e}")
# Create simple sentences from word timestamps
# Since we skipped VAD/diarization, all words belong to this speaker
sentences = []
current_sentence = {
"speaker": f"Speaker {speaker_id}",
"start_time": 0,
"end_time": 0,
"text": ""
}
sentence_endings = ".?!"
for i, word_data in enumerate(word_timestamps):
word = word_data["text"]
start_ms = int(word_data["start"] * 1000)
end_ms = int(word_data["end"] * 1000)
if i == 0:
current_sentence["start_time"] = start_ms
current_sentence["end_time"] = end_ms
current_sentence["text"] += word + " "
# Check if this word ends a sentence
if any(word.endswith(punct) for punct in sentence_endings):
sentences.append(current_sentence.copy())
current_sentence = {
"speaker": f"Speaker {speaker_id}",
"start_time": end_ms,
"end_time": end_ms,
"text": ""
}
# Add the last sentence if it has content
if current_sentence["text"].strip():
sentences.append(current_sentence)
# If no sentences were created, create one from the full transcript
if not sentences:
audio_duration_ms = int(len(audio_waveform) / 16000 * 1000)
sentences = [{
"speaker": f"Speaker {speaker_id}",
"start_time": 0,
"end_time": audio_duration_ms,
"text": full_transcript.strip()
}]
print(f"Created {len(sentences)} sentences for speaker {speaker_id}")
return {
"speaker_id": speaker_id,
"sentences": sentences,
"language": info.language
}
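# Example usage (hedged sketch), assuming `args` carries the CLI options referenced above
# (model_name, device, batch_size, stemming, suppress_numerals):
#
#   result = process_separated_audio("temp_outputs/separated_speaker_1.wav", speaker_id=1,
#                                    args=args, language="en", temp_dir="temp_outputs")
#   # result -> {"speaker_id": 1, "sentences": [...], "language": "en"}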
def create_config(output_dir):
DOMAIN_TYPE = "telephonic"
CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
if not os.path.exists(MODEL_CONFIG_PATH):
os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)
config = OmegaConf.load(MODEL_CONFIG_PATH)
data_dir = os.path.join(output_dir, "data")
os.makedirs(data_dir, exist_ok=True)
meta = {
"audio_filepath": os.path.join(output_dir, "mono_file.wav"),
"offset": 0,
"duration": None,
"label": "infer",
"text": "-",
"rttm_filepath": None,
"uem_filepath": None,
}
with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
json.dump(meta, fp)
fp.write("\n")
pretrained_vad = "vad_multilingual_marblenet"
pretrained_speaker_model = "titanet_large"
config.num_workers = 0
config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
config.diarizer.out_dir = (
output_dir # Directory to store intermediate files and prediction outputs
)
config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = (
        False  # no oracle VAD; compute VAD with the model set in vad.model_path below
    )
config.diarizer.clustering.parameters.oracle_num_speakers = False
# Here, we use our in-house pretrained NeMo VAD model
config.diarizer.vad.model_path = pretrained_vad
config.diarizer.vad.parameters.onset = 0.8
config.diarizer.vad.parameters.offset = 0.6
config.diarizer.vad.parameters.pad_offset = -0.05
config.diarizer.msdd_model.model_path = (
"diar_msdd_telephonic" # Telephonic speaker diarization model
)
return config
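# Example usage (sketch): in the diarization step the returned OmegaConf config is
# typically handed to NeMo's NeuralDiarizer (imported above), which reads the manifest
# written here and stores RTTM predictions under `output_dir`:
#
#   cfg = create_config("temp_outputs")
#   msdd_model = NeuralDiarizer(cfg=cfg).to("cuda")
#   msdd_model.diarize()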
def get_word_ts_anchor(s, e, option="start"):
if option == "end":
return e
elif option == "mid":
return (s + e) / 2
return s
def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
s, e, sp = spk_ts[0]
wrd_pos, turn_idx = 0, 0
wrd_spk_mapping = []
for wrd_dict in wrd_ts:
ws, we, wrd = (
int(wrd_dict["start"] * 1000),
int(wrd_dict["end"] * 1000),
wrd_dict["text"],
)
wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
while wrd_pos > float(e):
turn_idx += 1
turn_idx = min(turn_idx, len(spk_ts) - 1)
s, e, sp = spk_ts[turn_idx]
if turn_idx == len(spk_ts) - 1:
e = get_word_ts_anchor(ws, we, option="end")
wrd_spk_mapping.append(
{"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
)
return wrd_spk_mapping
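# Example (sketch): word timestamps in seconds are matched against speaker turns given
# in milliseconds as [start_ms, end_ms, speaker]:
#
#   wrd_ts = [{"text": "hello", "start": 0.0, "end": 0.4}]
#   spk_ts = [[0, 2000, 0]]
#   get_words_speaker_mapping(wrd_ts, spk_ts)
#   # -> [{"word": "hello", "start_time": 0, "end_time": 400, "speaker": 0}]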
sentence_ending_punctuations = ".?!"
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
is_word_sentence_end = (
lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
)
left_idx = word_idx
while (
left_idx > 0
and word_idx - left_idx < max_words
and speaker_list[left_idx - 1] == speaker_list[left_idx]
and not is_word_sentence_end(left_idx - 1)
):
left_idx -= 1
return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
is_word_sentence_end = (
lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
)
right_idx = word_idx
while (
right_idx < len(word_list) - 1
and right_idx - word_idx < max_words
and not is_word_sentence_end(right_idx)
):
right_idx += 1
return (
right_idx
if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
else -1
)
def get_realigned_ws_mapping_with_punctuation(
word_speaker_mapping, max_words_in_sentence=50
):
is_word_sentence_end = (
lambda x: x >= 0
and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
)
wsp_len = len(word_speaker_mapping)
words_list, speaker_list = [], []
for k, line_dict in enumerate(word_speaker_mapping):
word, speaker = line_dict["word"], line_dict["speaker"]
words_list.append(word)
speaker_list.append(speaker)
k = 0
while k < len(word_speaker_mapping):
line_dict = word_speaker_mapping[k]
if (
k < wsp_len - 1
and speaker_list[k] != speaker_list[k + 1]
and not is_word_sentence_end(k)
):
left_idx = get_first_word_idx_of_sentence(
k, words_list, speaker_list, max_words_in_sentence
)
right_idx = (
get_last_word_idx_of_sentence(
k, words_list, max_words_in_sentence - k + left_idx - 1
)
if left_idx > -1
else -1
)
if min(left_idx, right_idx) == -1:
k += 1
continue
spk_labels = speaker_list[left_idx : right_idx + 1]
mod_speaker = max(set(spk_labels), key=spk_labels.count)
if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
k += 1
continue
speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
right_idx - left_idx + 1
)
k = right_idx
k += 1
k, realigned_list = 0, []
while k < len(word_speaker_mapping):
line_dict = word_speaker_mapping[k].copy()
line_dict["speaker"] = speaker_list[k]
realigned_list.append(line_dict)
k += 1
return realigned_list
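# The realignment above looks for speaker changes that fall mid-sentence and, when a
# clear majority speaker exists between the surrounding sentence boundaries, relabels
# that stretch of words with the majority speaker so a sentence is not split between
# speakers.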
def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
s, e, spk = spk_ts[0]
prev_spk = spk
snts = []
snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}
for wrd_dict in word_speaker_mapping:
wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
s, e = wrd_dict["start_time"], wrd_dict["end_time"]
if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
snts.append(snt)
snt = {
"speaker": f"Speaker {spk}",
"start_time": s,
"end_time": e,
"text": "",
}
else:
snt["end_time"] = e
snt["text"] += wrd + " "
prev_spk = spk
snts.append(snt)
return snts
def get_speaker_aware_transcript(sentences_speaker_mapping, f):
previous_speaker = sentences_speaker_mapping[0]["speaker"]
f.write(f"{previous_speaker}: ")
for sentence_dict in sentences_speaker_mapping:
speaker = sentence_dict["speaker"]
sentence = sentence_dict["text"]
# If this speaker doesn't match the previous one, start a new paragraph
if speaker != previous_speaker:
f.write(f"\n\n{speaker}: ")
previous_speaker = speaker
# No matter what, write the current sentence
f.write(sentence + " ")
def format_timestamp(
milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"
    milliseconds = int(milliseconds)  # integer arithmetic so the :02d/:03d formats below work
    hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
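# Worked example: format_timestamp(3_723_045, always_include_hours=True, decimal_marker=",")
# returns "01:02:03,045", and format_timestamp(61_000) returns "01:01.000".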
def write_srt(transcript, file):
"""
Write a transcript to a file in SRT format.
"""
for i, segment in enumerate(transcript, start=1):
# write srt lines
print(
f"{i}\n"
f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
file=file,
flush=True,
)
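# Example usage (sketch): `ssm` would be the output of get_sentences_speaker_mapping:
#
#   with open("audio.srt", "w", encoding="utf-8-sig") as srt:
#       write_srt(ssm, srt)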
def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens = [
-1,
]
for token, token_id in tokenizer.get_vocab().items():
has_numeral_symbol = any(c in "0123456789%$£" for c in token)
if has_numeral_symbol:
numeral_symbol_tokens.append(token_id)
return numeral_symbol_tokens
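# These token ids feed Whisper's `suppress_tokens` option (see process_separated_audio
# above): with numeral and currency-symbol tokens suppressed, numbers tend to be
# transcribed as words, which generally aligns better in the forced-alignment step.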
def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
# if current word is the last word
if current_word_index == len(word_timestamps) - 1:
return word_timestamps[current_word_index]["start"]
next_word_index = current_word_index + 1
while current_word_index < len(word_timestamps) - 1:
if word_timestamps[next_word_index].get("start") is None:
# if next word doesn't have a start timestamp
# merge it with the current word and delete it
word_timestamps[current_word_index]["word"] += (
" " + word_timestamps[next_word_index]["word"]
)
word_timestamps[next_word_index]["word"] = None
next_word_index += 1
if next_word_index == len(word_timestamps):
return final_timestamp
else:
return word_timestamps[next_word_index]["start"]
def filter_missing_timestamps(
word_timestamps, initial_timestamp=0, final_timestamp=None
):
# handle the first and last word
if word_timestamps[0].get("start") is None:
word_timestamps[0]["start"] = (
initial_timestamp if initial_timestamp is not None else 0
)
word_timestamps[0]["end"] = _get_next_start_timestamp(
word_timestamps, 0, final_timestamp
)
result = [
word_timestamps[0],
]
for i, ws in enumerate(word_timestamps[1:], start=1):
# if ws doesn't have a start and end
# use the previous end as start and next start as end
if ws.get("start") is None and ws.get("word") is not None:
ws["start"] = word_timestamps[i - 1]["end"]
ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)
if ws["word"] is not None:
result.append(ws)
return result
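# Example (sketch): words missing timestamps are filled in from their neighbours
# (the first word falls back to `initial_timestamp` for its start):
#
#   words = [{"word": "hi", "start": None, "end": None},
#            {"word": "there", "start": 0.5, "end": 0.9}]
#   filter_missing_timestamps(words, initial_timestamp=0, final_timestamp=1.0)
#   # -> first word becomes {"word": "hi", "start": 0, "end": 0.5}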
def cleanup(path: str):
"""path could either be relative or absolute."""
# check if file or directory exists
if os.path.isfile(path) or os.path.islink(path):
# remove file
os.remove(path)
elif os.path.isdir(path):
# remove directory and all its content
shutil.rmtree(path)
else:
raise ValueError(f"Path {path} is not a file or dir.")
def process_language_arg(language: str, model_name: str):
"""
Process the language argument to make sure it's valid
and convert language names to language codes.
"""
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
    if model_name.endswith(".en") and language != "en":
        raise ValueError(
            f"{model_name} is an English-only model but the chosen language is '{language}'"
        )
return language
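# Example usage (sketch):
#
#   process_language_arg("English", "medium")   # -> "en"
#   process_language_arg(None, "large-v2")      # -> None (Whisper auto-detects)
#   process_language_arg("fr", "small.en")      # raises ValueError (English-only model)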