"""
Train DTLN model using Hugging Face datasets
Uses real speech and noise datasets for production-quality training
"""
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
import argparse
from dtln_ethos_u55 import DTLN_Ethos_U55
import os
from datasets import load_dataset, Audio
class HuggingFaceAudioDataGenerator(tf.keras.utils.Sequence):
"""
Data generator using Hugging Face datasets
Loads clean speech and noise from HF Hub
"""
def __init__(
self,
clean_dataset_name="librispeech_asr",
noise_dataset_name="dns-challenge/dns-challenge-4",
clean_split="train.clean.100",
noise_split="train",
batch_size=16,
samples_per_epoch=1000,
frame_len=512,
frame_shift=128,
sampling_rate=16000,
snr_range=(0, 20),
shuffle=True,
cache_dir=None
):
"""
Args:
clean_dataset_name: HF dataset for clean speech (default: LibriSpeech)
noise_dataset_name: HF dataset for noise (default: DNS Challenge)
clean_split: Split to use from clean dataset
noise_split: Split to use from noise dataset
batch_size: Batch size for training
samples_per_epoch: Number of samples per epoch
frame_len: Frame length in samples
frame_shift: Frame shift in samples
sampling_rate: Target sampling rate
snr_range: Range of SNR for mixing (min, max) in dB
shuffle: Whether to shuffle data each epoch
cache_dir: Directory to cache datasets
"""
print(f"\n{'='*60}")
print("Initializing Hugging Face Dataset Generator")
print(f"{'='*60}")
self.batch_size = batch_size
self.samples_per_epoch = samples_per_epoch
self.frame_len = frame_len
self.frame_shift = frame_shift
self.sampling_rate = sampling_rate
self.snr_range = snr_range
self.shuffle = shuffle
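        # Note: `shuffle` is stored for API parity but currently has no
        # effect; the streaming datasets below are consumed sequentially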
# Segment length for training (1 second)
self.segment_len = sampling_rate
# Load datasets
print(f"\n1. Loading clean speech dataset: {clean_dataset_name}")
print(f" Split: {clean_split}")
try:
self.clean_dataset = load_dataset(
clean_dataset_name,
split=clean_split,
streaming=True, # Stream for large datasets
cache_dir=cache_dir
)
# Cast audio to correct sampling rate
self.clean_dataset = self.clean_dataset.cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print(f" ✓ Clean speech dataset loaded")
except Exception as e:
print(f" ⚠ Error loading clean dataset: {e}")
print(f" Using fallback: common_voice")
self.clean_dataset = load_dataset(
"mozilla-foundation/common_voice_11_0",
"en",
split="train",
streaming=True,
cache_dir=cache_dir
)
self.clean_dataset = self.clean_dataset.cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print(f"\n2. Loading noise dataset: {noise_dataset_name}")
print(f" Split: {noise_split}")
try:
self.noise_dataset = load_dataset(
noise_dataset_name,
split=noise_split,
streaming=True,
cache_dir=cache_dir
)
self.noise_dataset = self.noise_dataset.cast_column(
"audio",
Audio(sampling_rate=sampling_rate)
)
print(f" ✓ Noise dataset loaded")
except Exception as e:
print(f" ⚠ Error loading noise dataset: {e}")
print(f" Using synthetic noise instead")
self.noise_dataset = None
# Create iterators
self.clean_iter = iter(self.clean_dataset)
if self.noise_dataset:
self.noise_iter = iter(self.noise_dataset)
self.on_epoch_end()
print(f"\n{'='*60}")
print(f"Dataset Generator Ready")
print(f" Batch size: {batch_size}")
print(f" Samples per epoch: {samples_per_epoch}")
print(f" Batches per epoch: {len(self)}")
print(f"{'='*60}\n")
def __len__(self):
"""Return number of batches per epoch"""
return self.samples_per_epoch // self.batch_size
def __getitem__(self, index):
"""Generate one batch of data"""
batch_clean = []
batch_noisy = []
for _ in range(self.batch_size):
try:
# Get clean audio
clean_audio = self._load_next_clean_audio()
# Get noise
if self.noise_dataset:
noise_audio = self._load_next_noise_audio()
else:
# Generate synthetic noise
noise_audio = np.random.randn(self.segment_len).astype(np.float32) * 0.1
                # Mix clean and noise at a random SNR; _mix_audio returns both
                # signals with a shared gain so the clean target stays on the
                # same scale as the noisy input
                noisy_audio, clean_audio = self._mix_audio(clean_audio, noise_audio)
                batch_clean.append(clean_audio)
                batch_noisy.append(noisy_audio)
except Exception as e:
                # On error, fall back to a low-level white-noise pair
                # (identity target: input and target are the same signal)
print(f" Warning: Error loading sample: {e}")
noise = np.random.randn(self.segment_len).astype(np.float32) * 0.01
batch_clean.append(noise)
batch_noisy.append(noise)
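        # Keras expects (inputs, targets): the noisy mixtures are the inputs,
        # the clean speech is the target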
return np.array(batch_noisy), np.array(batch_clean)
    def on_epoch_end(self):
        """Epoch-end hook; the streaming iterators persist across epochs, so there is nothing to reset"""
        pass
def _load_next_clean_audio(self):
"""Load next clean audio sample"""
try:
sample = next(self.clean_iter)
audio = sample['audio']['array']
except StopIteration:
# Restart iterator
self.clean_iter = iter(self.clean_dataset)
sample = next(self.clean_iter)
audio = sample['audio']['array']
return self._preprocess_audio(audio)
def _load_next_noise_audio(self):
"""Load next noise sample"""
try:
sample = next(self.noise_iter)
if 'audio' in sample:
audio = sample['audio']['array']
elif 'noise' in sample:
audio = sample['noise']['array']
else:
# Fallback to white noise
audio = np.random.randn(self.segment_len).astype(np.float32) * 0.1
        except StopIteration:
            # Restart the streaming iterator, handling both possible audio keys
            self.noise_iter = iter(self.noise_dataset)
            sample = next(self.noise_iter)
            audio = sample['audio']['array'] if 'audio' in sample else sample['noise']['array']
return self._preprocess_audio(audio)
def _preprocess_audio(self, audio):
"""Preprocess audio to target length and format"""
# Convert to float32
audio = audio.astype(np.float32)
# Trim or pad to segment length
if len(audio) > self.segment_len:
start = np.random.randint(0, len(audio) - self.segment_len)
audio = audio[start:start + self.segment_len]
else:
audio = np.pad(audio, (0, self.segment_len - len(audio)))
        # Peak-normalize so clean and noise are on comparable scales before
        # SNR-based mixing in _mix_audio()
max_val = np.max(np.abs(audio))
if max_val > 1e-8:
audio = audio / max_val
return audio
    def _mix_audio(self, clean, noise):
        """Mix clean audio with noise at a random SNR; returns (noisy, clean) with a shared gain"""
        snr = np.random.uniform(*self.snr_range)
        # Calculate signal powers
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)
        # Noise scaling factor, derived from
        # SNR_dB = 10*log10(clean_power / (scale^2 * noise_power))
        # => scale = sqrt(clean_power / (10^(SNR_dB/10) * noise_power))
        if noise_power > 1e-8:
            snr_linear = 10 ** (snr / 10)
            noise_scale = np.sqrt(clean_power / (snr_linear * noise_power))
        else:
            noise_scale = 0.1
        # Mix
        noisy = clean + noise_scale * noise
        # Rescale both signals by the same factor to prevent clipping while
        # keeping the clean target on the same scale as the noisy input
        max_val = np.max(np.abs(noisy))
        if max_val > 1e-8:
            gain = 0.95 / max_val
            noisy = noisy * gain
            clean = clean * gain
        return noisy.astype(np.float32), clean.astype(np.float32)
def create_loss_function():
"""
Create custom loss function combining time and frequency domain losses
"""
def combined_loss(y_true, y_pred):
# Time domain MSE
time_loss = tf.reduce_mean(tf.square(y_true - y_pred))
# Frequency domain loss (STFT-based)
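        # Assumption: y_true and y_pred are time-domain waveforms shaped
        # (batch, samples); tf.signal.stft transforms the last axis, so a
        # trailing channel dimension (batch, samples, 1) would need squeezing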
stft_true = tf.signal.stft(
y_true,
frame_length=512,
frame_step=128
)
stft_pred = tf.signal.stft(
y_pred,
frame_length=512,
frame_step=128
)
mag_true = tf.abs(stft_true)
mag_pred = tf.abs(stft_pred)
freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))
# Combined loss (weighted)
return 0.7 * time_loss + 0.3 * freq_loss
return combined_loss
def train_model(
clean_dataset="librispeech_asr",
noise_dataset="dns-challenge/dns-challenge-4",
clean_split="train.clean.100",
noise_split="train",
output_dir='./models',
epochs=50,
batch_size=16,
samples_per_epoch=1000,
lstm_units=128,
learning_rate=0.001,
use_qat=True,
cache_dir=None
):
"""
Main training function using HF datasets
Args:
clean_dataset: HF dataset name for clean speech
noise_dataset: HF dataset name for noise
clean_split: Split for clean dataset
noise_split: Split for noise dataset
output_dir: Directory to save models
epochs: Number of training epochs
batch_size: Training batch size
samples_per_epoch: Samples per epoch
lstm_units: Number of LSTM units
learning_rate: Learning rate for Adam optimizer
use_qat: Whether to use quantization-aware training
cache_dir: Directory to cache datasets
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
print("="*60)
print("Training DTLN with Hugging Face Datasets")
print("="*60)
# Create model
print("\n1. Building DTLN model...")
dtln = DTLN_Ethos_U55(
frame_len=512,
frame_shift=128,
lstm_units=lstm_units,
sampling_rate=16000
)
model = dtln.build_model()
print(" ✓ Model built")
print(f" Parameters: {model.count_params():,}")
# Apply QAT if requested
if use_qat:
print("\n2. Applying Quantization-Aware Training...")
quantize_model = tfmot.quantization.keras.quantize_model
model = quantize_model(model)
print(" ✓ QAT applied (INT8 optimized)")
# Compile model
print("\n3. Compiling model...")
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=create_loss_function(),
metrics=['mae']
)
print(" ✓ Model compiled")
# Create data generator
print("\n4. Creating Hugging Face data generator...")
train_generator = HuggingFaceAudioDataGenerator(
clean_dataset_name=clean_dataset,
noise_dataset_name=noise_dataset,
clean_split=clean_split,
noise_split=noise_split,
batch_size=batch_size,
samples_per_epoch=samples_per_epoch,
frame_len=512,
frame_shift=128,
sampling_rate=16000,
snr_range=(0, 20),
shuffle=True,
cache_dir=cache_dir
)
# Callbacks
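    # No validation generator is passed to fit(), so all callbacks monitor
    # the training loss ('loss') rather than 'val_loss'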
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(output_dir, 'best_model.h5'),
monitor='loss',
save_best_only=True,
verbose=1
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor='loss',
factor=0.5,
patience=5,
min_lr=1e-6,
verbose=1
),
tf.keras.callbacks.EarlyStopping(
monitor='loss',
patience=10,
restore_best_weights=True,
verbose=1
),
tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(output_dir, 'logs'),
histogram_freq=1
)
]
# Train
print("\n5. Starting training...")
print("="*60)
history = model.fit(
train_generator,
epochs=epochs,
callbacks=callbacks,
verbose=1
)
# Save final model
final_model_path = os.path.join(output_dir, 'dtln_final.h5')
model.save(final_model_path)
print(f"\n{'='*60}")
print("✓ Training Complete!")
print(f"{'='*60}")
print(f"Final loss: {history.history['loss'][-1]:.4f}")
print(f"Best loss: {min(history.history['loss']):.4f}")
print(f"Model saved to: {final_model_path}")
print(f"{'='*60}\n")
return model, history
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Train DTLN model using Hugging Face datasets'
)
# Dataset arguments
parser.add_argument(
'--clean-dataset',
type=str,
default='librispeech_asr',
help='HF dataset for clean speech (default: librispeech_asr)'
)
parser.add_argument(
'--noise-dataset',
type=str,
default='dns-challenge/dns-challenge-4',
help='HF dataset for noise (default: dns-challenge)'
)
parser.add_argument(
'--clean-split',
type=str,
default='train.clean.100',
help='Split for clean dataset'
)
parser.add_argument(
'--noise-split',
type=str,
default='train',
help='Split for noise dataset'
)
# Training arguments
parser.add_argument(
'--output-dir',
type=str,
default='./models',
help='Output directory for models'
)
parser.add_argument(
'--epochs',
type=int,
default=50,
help='Number of training epochs'
)
parser.add_argument(
'--batch-size',
type=int,
default=16,
help='Training batch size'
)
parser.add_argument(
'--samples-per-epoch',
type=int,
default=1000,
help='Number of samples per epoch'
)
parser.add_argument(
'--lstm-units',
type=int,
default=128,
help='Number of LSTM units'
)
parser.add_argument(
'--learning-rate',
type=float,
default=0.001,
help='Learning rate'
)
parser.add_argument(
'--no-qat',
action='store_true',
help='Disable quantization-aware training'
)
parser.add_argument(
'--cache-dir',
type=str,
default=None,
help='Directory to cache HF datasets'
)
args = parser.parse_args()
# Train model
model, history = train_model(
clean_dataset=args.clean_dataset,
noise_dataset=args.noise_dataset,
clean_split=args.clean_split,
noise_split=args.noise_split,
output_dir=args.output_dir,
epochs=args.epochs,
batch_size=args.batch_size,
samples_per_epoch=args.samples_per_epoch,
lstm_units=args.lstm_units,
learning_rate=args.learning_rate,
use_qat=not args.no_qat,
cache_dir=args.cache_dir
)
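    # --- Optional: INT8 TFLite export (a minimal sketch, not run by default) ---
    # The QAT path above targets INT8 deployment (e.g., Ethos-U55). This is a
    # hedged sketch using tf.lite.TFLiteConverter: the random calibration data
    # is a placeholder (real noisy batches give better quantization scales),
    # and the 16000-sample input length assumes the generator's 1-second
    # segments match the model's input shape. Flip EXPORT_INT8 to True to try it.
    EXPORT_INT8 = False
    if EXPORT_INT8:
        def representative_data_gen():
            # Placeholder calibration batches; substitute real noisy audio
            for _ in range(10):
                yield [np.random.randn(1, 16000).astype(np.float32)]
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_data_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        with open(os.path.join(args.output_dir, 'dtln_int8.tflite'), 'wb') as f:
            f.write(converter.convert())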