# NOI_3_ZIP / utils.py
# Last commit cb5ba27 by hynt: "Update utils.py" (file size 6.94 kB)
import hashlib
import re
import tempfile

import librosa
import matplotlib.pylab as plt
import numpy as np
import torch
from pydub import AudioSegment, silence
from transformers import pipeline
# ASR transcriptions cached by the MD5 hex digest of the processed reference
# WAV file; filled in by preprocess_ref_audio_text when no ref_text is given.
_ref_audio_cache = dict()
# Shared transformers ASR pipeline; created on first use by transcribe()
# via initialize_asr_pipeline().
asr_pipe = None
def chunk_text(text, max_chars=135):
    """Split *text* into sentence-like chunks for synthesis.

    Lengths are measured in words (``str.split()``), not characters:

    1. Split on ``". "`` into sentences, dropping blank pieces.
    2. Merge every sentence with fewer than 4 words into a neighbour using
       ``", "``. The first sentence merges forward into the next one; any later
       sentence merges backward into the previous one. A lone short sentence
       with no neighbour is kept as-is.
    3. Re-split each sentence on ``", "``, greedily accumulating parts and
       emitting a chunk as soon as its total exceeds 20 words (the part that
       crossed the limit stays in that chunk).
    4. If the last chunk has fewer than 4 words, append it to the previous one.

    Args:
        text: Input text.
        max_chars: Currently unused; kept so existing callers that pass it
            keep working.

    Returns:
        List of chunk strings. Empty list when *text* has no non-blank content.
    """
    # Step 1: split into sentences on ". "
    sentences = [s.strip() for s in text.split('. ') if s.strip()]

    # Step 2: merge sentences shorter than 4 words with an adjacent sentence.
    i = 0
    while i < len(sentences):
        if len(sentences[i].split()) >= 4:
            i += 1
        elif i == 0 and len(sentences) > 1:
            # Merge forward into the next sentence; index 0 is re-checked.
            sentences[1] = sentences[0] + ', ' + sentences[1]
            del sentences[0]
        elif i > 0:
            # Merge backward into the previous sentence and re-check it.
            sentences[i - 1] = sentences[i - 1] + ', ' + sentences[i]
            del sentences[i]
            i -= 1
        else:
            # Only one (short) sentence is left and there is nothing to merge
            # it with; advance, otherwise the loop never terminates.
            i += 1

    # Step 3: split overly long sentences at ", " boundaries (~20 words each).
    final_sentences = []
    for sentence in sentences:
        buffer = []
        word_count = 0
        for part in (p.strip() for p in sentence.split(', ')):
            buffer.append(part)
            word_count += len(part.split())
            if word_count > 20:
                final_sentences.append(', '.join(buffer))
                buffer = []
                word_count = 0
        if buffer:
            final_sentences.append(', '.join(buffer))

    # Step 4: fold a very short trailing chunk into the previous one. The
    # length check comes first so empty input does not index an empty list.
    if len(final_sentences) >= 2 and len(final_sentences[-1].split()) < 4:
        final_sentences[-2] = final_sentences[-2] + ', ' + final_sentences[-1]
        final_sentences.pop()
    return final_sentences
def initialize_asr_pipeline(device="cuda", dtype=None):
    """Create the shared ASR pipeline (vinai/PhoWhisper-medium) on *device*.

    Args:
        device: Device string forwarded to ``transformers.pipeline``
            (e.g. "cuda", "cuda:1", "cpu").
        dtype: torch dtype for the model. When None, ``torch.float16`` is used
            if *device* is a CUDA device whose properties report ``major >= 6``
            and whose name does not end with "[ZLUDA]"; otherwise
            ``torch.float32``.

    Side effects:
        Assigns the module-level ``asr_pipe``.
    """
    global asr_pipe
    if dtype is None:
        use_fp16 = (
            "cuda" in device
            and torch.cuda.get_device_properties(device).major >= 6
            # Inspect the same device we are configuring, not whichever CUDA
            # device happens to be current (matters for e.g. "cuda:1").
            and not torch.cuda.get_device_name(device).endswith("[ZLUDA]")
        )
        dtype = torch.float16 if use_fp16 else torch.float32
    asr_pipe = pipeline(
        "automatic-speech-recognition",
        model="vinai/PhoWhisper-medium",
        torch_dtype=dtype,
        device=device,
    )
def transcribe(ref_audio, language=None):
    """Run the shared ASR pipeline on *ref_audio* and return the stripped text.

    If the pipeline has not been created yet it is initialised on "cuda".

    Args:
        ref_audio: Audio input passed straight to the pipeline
            (preprocess_ref_audio_text passes a WAV file path).
        language: Optional language hint; added to ``generate_kwargs`` only
            when truthy.
    """
    if asr_pipe is None:
        initialize_asr_pipeline(device="cuda")

    generate_kwargs = {"task": "transcribe"}
    if language:
        generate_kwargs["language"] = language

    output = asr_pipe(
        ref_audio,
        chunk_length_s=30,
        batch_size=128,
        generate_kwargs=generate_kwargs,
        return_timestamps=False,
    )
    return output["text"].strip()
def caculate_spec(audio):
    """Return the magnitude spectrogram of *audio* converted to dB.

    (The misspelled name is kept for existing callers.)

    Uses ``librosa.stft`` with n_fft=512, hop_length=256, win_length=512, then
    ``librosa.amplitude_to_db`` relative to the spectrogram maximum
    (``ref=np.max``).
    """
    magnitude = np.abs(librosa.stft(audio, n_fft=512, hop_length=256, win_length=512))
    return librosa.amplitude_to_db(magnitude, ref=np.max)
def save_spectrogram(audio, path):
    """Plot the dB spectrogram of *audio* (from caculate_spec) and save it to *path*.

    The 12x4-inch figure is always closed, even when plotting or saving raises,
    so repeated calls cannot accumulate open matplotlib figures.
    """
    spectrogram = caculate_spec(audio)
    fig = plt.figure(figsize=(12, 4))
    try:
        plt.imshow(spectrogram, origin="lower", aspect="auto")
        plt.colorbar()
        plt.savefig(path)
    finally:
        plt.close(fig)
def remove_silence_edges(audio, silence_threshold=-42):
    """Trim leading and trailing silence from a pydub ``AudioSegment``.

    Args:
        audio: The ``AudioSegment`` to trim.
        silence_threshold: Loudness threshold in dBFS. Leading silence is
            located with ``silence.detect_leading_silence``; trailing silence is
            every final 1 ms slice whose ``dBFS`` is not above this value.

    Returns:
        The trimmed ``AudioSegment`` (empty if everything is silent).
    """
    # Remove silence from the start.
    start_ms = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    audio = audio[start_ms:]

    # Remove silence from the end. Count silent milliseconds as an integer:
    # repeatedly subtracting 0.001 s from a float accumulates rounding error,
    # and the final int() truncation could then drop one extra millisecond.
    trailing_ms = 0
    for ms in reversed(audio):
        if ms.dBFS > silence_threshold:
            break
        trailing_ms += 1
    return audio[: len(audio) - trailing_ms]
def _clip_on_silence(aseg, min_silence_len, silence_thresh, clip_msg, show_info):
    """Join silence-split chunks of *aseg* until the next one would push past 15 s.

    Clipping only kicks in once more than 6 s have been collected; *clip_msg*
    is reported through *show_info* when a cut happens.
    """
    non_silent_segs = silence.split_on_silence(
        aseg, min_silence_len=min_silence_len, silence_thresh=silence_thresh, keep_silence=1000, seek_step=10
    )
    non_silent_wave = AudioSegment.silent(duration=0)
    for non_silent_seg in non_silent_segs:
        if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
            show_info(clip_msg)
            break
        non_silent_wave += non_silent_seg
    return non_silent_wave


def _clip_ref_audio(aseg, show_info):
    """Shorten *aseg* to at most 15 s, preferring cuts at silences."""
    # 1. try to find long silence for clipping
    clipped = _clip_on_silence(aseg, 1000, -50, "Audio is over 15s, clipping short. (1)", show_info)
    # 2. try to find short silence for clipping if 1. failed
    if len(clipped) > 15000:
        clipped = _clip_on_silence(aseg, 100, -40, "Audio is over 15s, clipping short. (2)", show_info)
    # 3. if no proper silence found for clipping, hard-cut at 15 s
    if len(clipped) > 15000:
        clipped = clipped[:15000]
        show_info("Audio is over 15s, clipping short. (3)")
    return clipped


def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device="cuda"):
    """Prepare a reference audio clip and its transcript for synthesis.

    The audio is loaded with pydub and optionally clipped to about 15 s. Edge
    silence is trimmed, 50 ms of silence is appended, and the result is
    written to a temporary WAV file that is NOT deleted (its path is
    returned). When *ref_text* is blank, the audio is transcribed with the ASR
    pipeline. Transcriptions are cached by the MD5 of the processed WAV;
    user-supplied text is never cached. The text is then made to end with
    ". " (or "。").

    Args:
        ref_audio_orig: Anything ``AudioSegment.from_file`` accepts.
        ref_text: Reference transcript; blank/whitespace means "transcribe".
        clip_short: Whether to clip audio longer than 15 s.
        show_info: Callable used for progress messages.
        device: Device used to create the ASR pipeline if it does not exist yet.

    Returns:
        ``(path_to_processed_wav, ref_text)``.
    """
    show_info("Converting audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(ref_audio_orig)
        if clip_short:
            aseg = _clip_ref_audio(aseg, show_info)
        aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Compute a hash of the reference audio file
    with open(ref_audio, "rb") as audio_file:
        audio_hash = hashlib.md5(audio_file.read()).hexdigest()

    if not ref_text.strip():
        if audio_hash in _ref_audio_cache:
            # Use cached asr transcription
            show_info("Using cached reference text...")
            ref_text = _ref_audio_cache[audio_hash]
        else:
            show_info("No reference text provided, transcribing reference audio...")
            if asr_pipe is None:
                # Honour the caller's device; transcribe() would otherwise
                # create the pipeline on "cuda" unconditionally.
                initialize_asr_pipeline(device=device)
            ref_text = transcribe(ref_audio)
            # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
            _ref_audio_cache[audio_hash] = ref_text
    else:
        show_info("Using custom reference text...")

    # Ensure ref_text ends with a proper sentence-ending punctuation
    if not ref_text.endswith((". ", "。")):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    print("\nref_text ", ref_text)
    return ref_audio, ref_text