# voice-denoising / train_dtln.py
# Uploaded by grgsaliba with huggingface_hub (commit e5d5706, 12.5 kB).
"""
Training script for DTLN model with Quantization-Aware Training (QAT)
Optimized for deployment on Alif E7 Ethos-U55 NPU
"""
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import argparse
from dtln_ethos_u55 import DTLN_Ethos_U55
import os
class AudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator for training audio denoising models.

    For every batch item, a random segment (``sampling_rate`` samples, i.e.
    one second) of a clean speech file is mixed with a randomly chosen noise
    file at an SNR drawn uniformly from ``snr_range``. Each batch is a
    ``(noisy, clean)`` pair of float32 arrays of shape
    ``(batch_size, sampling_rate)``.
    """

    def __init__(
        self,
        clean_audio_dir,
        noise_audio_dir,
        batch_size=16,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    ):
        """
        Args:
            clean_audio_dir: Directory containing clean speech files
                (searched recursively for ``*.wav``).
            noise_audio_dir: Directory containing noise files
                (searched recursively for ``*.wav``).
            batch_size: Batch size for training.
            frame_len: Frame length in samples (stored only; the generator
                itself does not read it).
            frame_shift: Frame shift in samples (stored only; the generator
                itself does not read it).
            sampling_rate: Target sampling rate; files are resampled to it.
            snr_range: Range of SNR for mixing ``(min, max)`` in dB.
            shuffle: Whether to shuffle the clean-file order each epoch.

        Raises:
            ValueError: If ``batch_size`` is not positive, if either
                directory contains no ``.wav`` files, or if there are fewer
                clean files than ``batch_size`` (an epoch would have zero
                batches).
        """
        # Keras 3 aliases Sequence to PyDataset, which expects its initialiser
        # to run; on older tf.keras this resolves to object.__init__ (no-op).
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Sorted so the index -> file mapping is independent of filesystem order.
        self.clean_files = sorted(Path(clean_audio_dir).glob('**/*.wav'))
        self.noise_files = sorted(Path(noise_audio_dir).glob('**/*.wav'))
        if not self.clean_files:
            raise ValueError(
                f"No .wav files found under clean_audio_dir={clean_audio_dir!r}"
            )
        if not self.noise_files:
            raise ValueError(
                f"No .wav files found under noise_audio_dir={noise_audio_dir!r}"
            )
        if len(self.clean_files) < batch_size:
            raise ValueError(
                f"Only {len(self.clean_files)} clean files found, fewer than "
                f"batch_size={batch_size}; an epoch would contain no batches"
            )
        self.batch_size = batch_size
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle
        # Segment length for training (1 second)
        self.segment_len = sampling_rate
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a partial tail batch is dropped)."""
        return len(self.clean_files) // self.batch_size

    def __getitem__(self, index):
        """
        Generate one batch of data.

        Returns:
            (noisy, clean): float32 arrays of shape (batch_size, segment_len).
            The clean target is scaled by the same gain as the noisy mixture,
            so the speech component has the same level in input and target.
        """
        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]
        batch_clean = []
        batch_noisy = []
        for idx in batch_indices:
            clean_audio = self._load_audio(self.clean_files[idx])
            noise_audio = self._load_random_noise()
            # Mix at a random SNR; the target gets the mixture's normalisation gain.
            noisy_audio, clean_target = self._mix_audio(
                clean_audio, noise_audio, return_scaled_clean=True
            )
            batch_clean.append(clean_target)
            batch_noisy.append(noisy_audio)
        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the clean-file index order."""
        self.indices = np.arange(len(self.clean_files))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_audio(self, file_path):
        """
        Load a file as a mono, peak-normalised float32 segment.

        The result has exactly ``segment_len`` samples: longer files are
        randomly cropped, shorter ones are zero-padded at the end.
        """
        audio, sr = sf.read(file_path)
        # Down-mix BEFORE resampling: soundfile returns (frames, channels),
        # while librosa.resample operates on the last axis, so resampling a
        # multi-channel array first would resample across the channel axis.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        if sr != self.sampling_rate:
            audio = librosa.resample(
                audio,
                orig_sr=sr,
                target_sr=self.sampling_rate
            )
        excess = len(audio) - self.segment_len
        if excess > 0:
            # Upper bound is exclusive, so +1 makes the last valid offset reachable.
            start = np.random.randint(0, excess + 1)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, -excess))
        # Peak-normalise; the epsilon avoids division by zero on silent input.
        audio = audio / (np.max(np.abs(audio)) + 1e-8)
        return audio.astype(np.float32)

    def _load_random_noise(self):
        """Load a uniformly chosen noise file via ``_load_audio``."""
        noise_file = self.noise_files[np.random.randint(len(self.noise_files))]
        return self._load_audio(noise_file)

    def _mix_audio(self, clean, noise, return_scaled_clean=False):
        """
        Mix ``clean`` with ``noise`` at an SNR drawn from ``snr_range``.

        The noise is scaled so that clean power / scaled-noise power is
        approximately the drawn SNR (a small epsilon guards silent noise).
        The mixture is then peak-normalised to 0.95 to prevent clipping.

        Args:
            clean: Clean signal.
            noise: Noise signal of the same length.
            return_scaled_clean: If True, also return ``clean`` scaled by the
                same normalisation gain as the mixture.

        Returns:
            The float32 noisy mixture, or ``(noisy, scaled_clean)`` when
            ``return_scaled_clean`` is True.
        """
        snr = np.random.uniform(*self.snr_range)
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)
        snr_linear = 10 ** (snr / 10)
        noise_scale = np.sqrt(clean_power / (snr_linear * noise_power + 1e-8))
        noisy = clean + noise_scale * noise
        # Normalize to prevent clipping
        peak = np.max(np.abs(noisy)) + 1e-8
        noisy = (noisy / peak * 0.95).astype(np.float32)
        if return_scaled_clean:
            return noisy, (clean / peak * 0.95).astype(np.float32)
        return noisy
def apply_quantization_aware_training(model):
    """
    Wrap ``model`` for quantization-aware training aimed at 8-bit deployment.

    The whole model is passed to tfmot's ``quantize_model`` with its default
    quantization configuration; no per-layer annotation is done here.

    Args:
        model: Keras model to quantize.

    Returns:
        The quantization-aware model produced by
        ``tfmot.quantization.keras.quantize_model``.
    """
    return tfmot.quantization.keras.quantize_model(model)
def create_loss_function(frame_length=512, frame_step=128,
                         time_weight=0.7, freq_weight=0.3):
    """
    Create a loss combining time-domain MSE and STFT-magnitude MSE.

    The defaults reproduce the original fixed configuration (512-sample STFT
    frames, hop of 128, weights 0.7 / 0.3). Keep ``frame_length`` and
    ``frame_step`` in line with the model's framing if that changes.

    Args:
        frame_length: STFT window length in samples.
        frame_step: STFT hop size in samples.
        time_weight: Weight of the time-domain MSE term.
        freq_weight: Weight of the STFT-magnitude MSE term.

    Returns:
        A ``loss(y_true, y_pred)`` callable suitable for ``model.compile``.
    """
    def combined_loss(y_true, y_pred):
        # Targets may arrive with a different float dtype than the model
        # output; align them so the subtraction and STFT ops agree.
        y_true = tf.cast(y_true, y_pred.dtype)
        # Time domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))
        # Frequency domain loss on STFT magnitudes (framed along the last axis)
        mag_true = tf.abs(tf.signal.stft(
            y_true, frame_length=frame_length, frame_step=frame_step
        ))
        mag_pred = tf.abs(tf.signal.stft(
            y_pred, frame_length=frame_length, frame_step=frame_step
        ))
        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))
        # Combined loss (weighted)
        return time_weight * time_loss + freq_weight * freq_loss
    return combined_loss
def train_model(
clean_dir,
noise_dir,
output_dir='./models',
epochs=50,
batch_size=16,
lstm_units=128,
learning_rate=0.001,
use_qat=True
):
"""
Main training function
Args:
clean_dir: Directory with clean speech
noise_dir: Directory with noise files
output_dir: Directory to save models
epochs: Number of training epochs
batch_size: Training batch size
lstm_units: Number of LSTM units
learning_rate: Learning rate for Adam optimizer
use_qat: Whether to use quantization-aware training
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
print("="*60)
print("Training DTLN for Alif E7 Ethos-U55")
print("="*60)
# Create model
print("\n1. Building model...")
dtln = DTLN_Ethos_U55(
frame_len=512,
frame_shift=128,
lstm_units=lstm_units,
sampling_rate=16000
)
model = dtln.build_model()
model.summary()
# Apply QAT if requested
if use_qat:
print("\n2. Applying Quantization-Aware Training...")
model = apply_quantization_aware_training(model)
print(" ✓ QAT applied")
# Compile model
print("\n3. Compiling model...")
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
loss=create_loss_function(),
metrics=['mae']
)
print(" ✓ Model compiled")
# Create data generators
print("\n4. Creating data generators...")
train_generator = AudioDataGenerator(
clean_audio_dir=clean_dir,
noise_audio_dir=noise_dir,
batch_size=batch_size,
frame_len=512,
frame_shift=128,
sampling_rate=16000,
snr_range=(0, 20),
shuffle=True
)
print(f" ✓ Training samples: {len(train_generator) * batch_size}")
# Callbacks
callbacks = [
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(output_dir, 'best_model.h5'),
monitor='loss',
save_best_only=True,
verbose=1
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor='loss',
factor=0.5,
patience=5,
min_lr=1e-6,
verbose=1
),
tf.keras.callbacks.EarlyStopping(
monitor='loss',
patience=10,
restore_best_weights=True,
verbose=1
),
tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(output_dir, 'logs'),
histogram_freq=1
)
]
# Train
print("\n5. Starting training...")
print("="*60)
history = model.fit(
train_generator,
epochs=epochs,
callbacks=callbacks,
verbose=1
)
# Save final model
final_model_path = os.path.join(
output_dir,
'dtln_ethos_u55_final.h5'
)
model.save(final_model_path)
print(f"\n✓ Training complete! Model saved to {final_model_path}")
return model, history
def train_with_pretrained_dtln(
pretrained_weights_path,
clean_dir,
noise_dir,
output_dir='./models',
epochs=20,
batch_size=16
):
"""
Fine-tune from pre-trained DTLN weights
Args:
pretrained_weights_path: Path to pretrained DTLN weights
clean_dir: Directory with clean speech
noise_dir: Directory with noise files
output_dir: Output directory
epochs: Number of fine-tuning epochs
batch_size: Training batch size
"""
print("Fine-tuning from pretrained DTLN weights...")
# Build model
dtln = DTLN_Ethos_U55(lstm_units=128)
model = dtln.build_model()
# Load pretrained weights (if architecture matches)
try:
model.load_weights(pretrained_weights_path, by_name=True)
print("✓ Pretrained weights loaded")
except:
print("⚠ Could not load pretrained weights, training from scratch")
# Continue training
return train_model(
clean_dir=clean_dir,
noise_dir=noise_dir,
output_dir=output_dir,
epochs=epochs,
batch_size=batch_size,
use_qat=True
)
if __name__ == "__main__":
    # Command-line entry point: parse options, train (from scratch or from
    # pretrained weights), then print a short loss summary.
    parser = argparse.ArgumentParser(
        description='Train DTLN model for Alif E7 Ethos-U55'
    )
    parser.add_argument(
        '--clean-dir',
        type=str,
        required=True,
        help='Directory containing clean speech files'
    )
    parser.add_argument(
        '--noise-dir',
        type=str,
        required=True,
        help='Directory containing noise files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--lstm-units',
        type=int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--pretrained',
        type=str,
        default=None,
        help='Path to pretrained weights for fine-tuning'
    )
    args = parser.parse_args()

    # Reject non-positive sizes up front: epochs == 0 leaves the Keras
    # History empty, and batch_size == 0 makes the generator's
    # `len(files) // batch_size` divide by zero.
    if args.epochs <= 0:
        parser.error('--epochs must be a positive integer')
    if args.batch_size <= 0:
        parser.error('--batch-size must be a positive integer')

    if args.pretrained:
        model, history = train_with_pretrained_dtln(
            pretrained_weights_path=args.pretrained,
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size
        )
    else:
        model, history = train_model(
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lstm_units=args.lstm_units,
            learning_rate=args.learning_rate,
            use_qat=not args.no_qat
        )

    losses = history.history.get('loss', [])
    print("\n" + "="*60)
    print("Training Summary:")
    # Guard against an empty history so the summary can never crash.
    if losses:
        print(f" Final loss: {losses[-1]:.4f}")
        print(f" Best loss: {min(losses):.4f}")
    print(f" Model saved to: {args.output_dir}")
    print("="*60)