"""
Train DTLN model using Hugging Face datasets
Uses real speech and noise datasets for production-quality training
"""

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import argparse
from dtln_ethos_u55 import DTLN_Ethos_U55
import os
from datasets import load_dataset, Audio
from tqdm import tqdm


class HuggingFaceAudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator using Hugging Face datasets
    Loads clean speech and noise from HF Hub

    Each batch is a ``(noisy, clean)`` tuple of float32 arrays, one row per
    example, each row ``segment_len`` (= ``sampling_rate``, i.e. one second)
    samples long -- assuming the decoded HF audio arrays are 1-D (mono).
    The noisy input and the clean target are scaled by the same gain, so the
    network only has to remove noise, not change the signal level.
    """

    def __init__(
        self,
        clean_dataset_name="librispeech_asr",
        noise_dataset_name="dns-challenge/dns-challenge-4",
        clean_split="train.clean.100",
        noise_split="train",
        batch_size=16,
        samples_per_epoch=1000,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True,
        cache_dir=None
    ):
        """
        Args:
            clean_dataset_name: HF dataset for clean speech (default: LibriSpeech)
            noise_dataset_name: HF dataset for noise (default: DNS Challenge)
            clean_split: Split to use from clean dataset
            noise_split: Split to use from noise dataset
            batch_size: Batch size for training
            samples_per_epoch: Number of samples per epoch
            frame_len: Frame length in samples (stored; not used by the generator)
            frame_shift: Frame shift in samples (stored; not used by the generator)
            sampling_rate: Target sampling rate; also the segment length in samples
            snr_range: Range of SNR for mixing (min, max) in dB
            shuffle: Stored for API compatibility. NOTE(review): the streaming
                datasets are currently read in order; nothing is shuffled.
            cache_dir: Directory to cache datasets

        If the clean dataset fails to load, Common Voice 11 (en) is used
        instead. If the noise dataset fails to load, synthetic white noise is
        generated per sample.
        """
        super().__init__()

        print(f"\n{'='*60}")
        print("Initializing Hugging Face Dataset Generator")
        print(f"{'='*60}")

        self.batch_size = batch_size
        self.samples_per_epoch = samples_per_epoch
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle

        # Segment length for training (1 second)
        self.segment_len = sampling_rate

        # Load datasets
        print(f"\n1. Loading clean speech dataset: {clean_dataset_name}")
        print(f" Split: {clean_split}")

        try:
            self.clean_dataset = load_dataset(
                clean_dataset_name,
                split=clean_split,
                streaming=True,  # Stream for large datasets
                cache_dir=cache_dir
            )
            # Cast audio to correct sampling rate
            self.clean_dataset = self.clean_dataset.cast_column(
                "audio", Audio(sampling_rate=sampling_rate)
            )
            print(" ✓ Clean speech dataset loaded")
        except Exception as e:
            print(f" ⚠ Error loading clean dataset: {e}")
            print(" Using fallback: common_voice")
            self.clean_dataset = load_dataset(
                "mozilla-foundation/common_voice_11_0",
                "en",
                split="train",
                streaming=True,
                cache_dir=cache_dir
            )
            self.clean_dataset = self.clean_dataset.cast_column(
                "audio", Audio(sampling_rate=sampling_rate)
            )

        print(f"\n2. Loading noise dataset: {noise_dataset_name}")
        print(f" Split: {noise_split}")

        try:
            self.noise_dataset = load_dataset(
                noise_dataset_name,
                split=noise_split,
                streaming=True,
                cache_dir=cache_dir
            )
            self.noise_dataset = self.noise_dataset.cast_column(
                "audio", Audio(sampling_rate=sampling_rate)
            )
            print(" ✓ Noise dataset loaded")
        except Exception as e:
            print(f" ⚠ Error loading noise dataset: {e}")
            print(" Using synthetic noise instead")
            self.noise_dataset = None

        # Create iterators; they persist across epochs and are restarted
        # only when exhausted (see _load_next_*).
        self.clean_iter = iter(self.clean_dataset)
        if self.noise_dataset is not None:
            self.noise_iter = iter(self.noise_dataset)

        self.on_epoch_end()

        print(f"\n{'='*60}")
        print("Dataset Generator Ready")
        print(f" Batch size: {batch_size}")
        print(f" Samples per epoch: {samples_per_epoch}")
        print(f" Batches per epoch: {len(self)}")
        print(f"{'='*60}\n")

    def __len__(self):
        """Return number of batches per epoch (at least 1)."""
        # Floor division yields 0 when samples_per_epoch < batch_size, which
        # would give Keras an empty epoch and an empty loss history.
        return max(1, self.samples_per_epoch // self.batch_size)

    def __getitem__(self, index):
        """Generate one ``(noisy, clean)`` batch; ``index`` is ignored (streaming)."""
        batch_clean = []
        batch_noisy = []

        for _ in range(self.batch_size):
            try:
                # Get clean audio
                clean_audio = self._load_next_clean_audio()

                # Get noise
                if self.noise_dataset is not None:
                    noise_audio = self._load_next_noise_audio()
                else:
                    # Generate synthetic noise
                    noise_audio = np.random.randn(self.segment_len).astype(np.float32) * 0.1

                # Mix at a random SNR; the target gets the same gain as the input.
                noisy_audio, clean_audio = self._mix_pair(clean_audio, noise_audio)
            except Exception as e:
                # Best-effort fallback: a low-level noise input whose target is
                # silence. (Using the same noise as both input and target would
                # teach the model to pass noise through unchanged.)
                print(f" Warning: Error loading sample: {e}")
                noisy_audio = np.random.randn(self.segment_len).astype(np.float32) * 0.01
                clean_audio = np.zeros(self.segment_len, dtype=np.float32)

            batch_clean.append(clean_audio)
            batch_noisy.append(noisy_audio)

        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Epoch hook; intentionally a no-op (streaming iterators are not reset)."""
        pass

    def _load_next_clean_audio(self):
        """Load the next clean sample, restarting the stream when exhausted."""
        try:
            sample = next(self.clean_iter)
        except StopIteration:
            # Restart iterator
            self.clean_iter = iter(self.clean_dataset)
            sample = next(self.clean_iter)
        return self._preprocess_audio(sample['audio']['array'])

    def _load_next_noise_audio(self):
        """Load the next noise sample, restarting the stream when exhausted."""
        try:
            sample = next(self.noise_iter)
        except StopIteration:
            # Restart iterator
            self.noise_iter = iter(self.noise_dataset)
            sample = next(self.noise_iter)
        # Same column handling on both the normal and the restart path.
        return self._preprocess_audio(self._extract_noise_array(sample))

    def _extract_noise_array(self, sample):
        """Return the waveform from an 'audio' or 'noise' column, else white noise."""
        if 'audio' in sample:
            return sample['audio']['array']
        if 'noise' in sample:
            return sample['noise']['array']
        # Fallback to white noise
        return np.random.randn(self.segment_len).astype(np.float32) * 0.1

    def _preprocess_audio(self, audio):
        """Convert to float32, random-crop or zero-pad to segment_len, peak-normalize."""
        audio = np.asarray(audio, dtype=np.float32)

        # Trim or pad to segment length
        if len(audio) > self.segment_len:
            # randint's upper bound is exclusive; +1 makes the last valid
            # crop offset reachable.
            start = np.random.randint(0, len(audio) - self.segment_len + 1)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, self.segment_len - len(audio)))

        # Normalize to unit peak (skipped for near-silent segments)
        max_val = np.max(np.abs(audio))
        if max_val > 1e-8:
            audio = audio / max_val

        return audio

    def _mix_pair(self, clean, noise):
        """
        Mix clean audio with noise at a random SNR drawn from ``snr_range``.

        Returns:
            (noisy, clean_scaled): both float32. The mixture is normalized to a
            0.95 peak and the clean signal is scaled by the same gain, so the
            clean component of ``noisy`` equals ``clean_scaled``.
        """
        snr = np.random.uniform(*self.snr_range)

        # Calculate signal and noise power
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)

        # Scale noise so that clean_power / scaled_noise_power == 10^(snr/10)
        if noise_power > 1e-8:
            snr_linear = 10 ** (snr / 10)
            noise_scale = np.sqrt(clean_power / (snr_linear * noise_power))
        else:
            noise_scale = 0.1

        # Mix
        noisy = clean + noise_scale * noise

        # Normalize to prevent clipping; apply the identical gain to the target.
        max_val = np.max(np.abs(noisy))
        if max_val > 1e-8:
            noisy = noisy / max_val * 0.95
            clean = clean / max_val * 0.95

        return noisy.astype(np.float32), np.asarray(clean).astype(np.float32)

    def _mix_audio(self, clean, noise):
        """Mix clean audio with noise at random SNR; return only the noisy mixture."""
        noisy, _ = self._mix_pair(clean, noise)
        return noisy


def create_loss_function(frame_length=512, frame_step=128, time_weight=0.7):
    """
    Create custom loss function combining time and frequency domain losses

    Args:
        frame_length: STFT window length for the frequency-domain term
        frame_step: STFT hop size for the frequency-domain term
        time_weight: Weight of the time-domain MSE; the STFT-magnitude MSE
            gets ``1 - time_weight``

    Returns:
        A Keras-compatible ``loss(y_true, y_pred)`` callable. NOTE(review):
        ``tf.signal.stft`` transforms the last axis, so this presumes the model
        output is a batch of time-domain waveforms -- confirm against
        DTLN_Ethos_U55.build_model().
    """
    freq_weight = 1.0 - time_weight

    def combined_loss(y_true, y_pred):
        # Time domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))

        # Frequency domain loss (STFT-based)
        stft_true = tf.signal.stft(
            y_true,
            frame_length=frame_length,
            frame_step=frame_step
        )
        stft_pred = tf.signal.stft(
            y_pred,
            frame_length=frame_length,
            frame_step=frame_step
        )

        mag_true = tf.abs(stft_true)
        mag_pred = tf.abs(stft_pred)

        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))

        # Combined loss (weighted)
        return time_weight * time_loss + freq_weight * freq_loss

    return combined_loss


def train_model(
    clean_dataset="librispeech_asr",
    noise_dataset="dns-challenge/dns-challenge-4",
    clean_split="train.clean.100",
    noise_split="train",
    output_dir='./models',
    epochs=50,
    batch_size=16,
    samples_per_epoch=1000,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True,
    cache_dir=None
):
    """
    Main training function using HF datasets

    Args:
        clean_dataset: HF dataset name for clean speech
        noise_dataset: HF dataset name for noise
        clean_split: Split for clean dataset
        noise_split: Split for noise dataset
        output_dir: Directory to save models (created if missing)
        epochs: Number of training epochs
        batch_size: Training batch size
        samples_per_epoch: Samples per epoch
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for Adam optimizer
        use_qat: Whether to use quantization-aware training
        cache_dir: Directory to cache datasets

    Returns:
        (model, history): the trained Keras model and the ``fit`` History.
        Checkpoints go to ``best_model.h5`` and the final model to
        ``dtln_final.h5`` inside ``output_dir``; TensorBoard logs to ``logs/``.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("Training DTLN with Hugging Face Datasets")
    print("="*60)

    # Create model
    print("\n1. Building DTLN model...")
    dtln = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=lstm_units,
        sampling_rate=16000
    )
    model = dtln.build_model()
    print(" ✓ Model built")
    print(f" Parameters: {model.count_params():,}")

    # Apply QAT if requested
    if use_qat:
        print("\n2. Applying Quantization-Aware Training...")
        quantize_model = tfmot.quantization.keras.quantize_model
        model = quantize_model(model)
        print(" ✓ QAT applied (INT8 optimized)")

    # Compile model
    print("\n3. Compiling model...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=create_loss_function(),
        metrics=['mae']
    )
    print(" ✓ Model compiled")

    # Create data generator
    print("\n4. Creating Hugging Face data generator...")
    train_generator = HuggingFaceAudioDataGenerator(
        clean_dataset_name=clean_dataset,
        noise_dataset_name=noise_dataset,
        clean_split=clean_split,
        noise_split=noise_split,
        batch_size=batch_size,
        samples_per_epoch=samples_per_epoch,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True,
        cache_dir=cache_dir
    )

    # Callbacks (no validation set, so everything monitors training loss)
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_model.h5'),
            monitor='loss',
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(output_dir, 'logs'),
            histogram_freq=1
        )
    ]

    # Train
    print("\n5. Starting training...")
    print("="*60)

    history = model.fit(
        train_generator,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Save final model
    final_model_path = os.path.join(output_dir, 'dtln_final.h5')
    model.save(final_model_path)

    print(f"\n{'='*60}")
    print("✓ Training Complete!")
    print(f"{'='*60}")
    print(f"Final loss: {history.history['loss'][-1]:.4f}")
    print(f"Best loss: {min(history.history['loss']):.4f}")
    print(f"Model saved to: {final_model_path}")
    print(f"{'='*60}\n")

    return model, history


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train DTLN model using Hugging Face datasets'
    )

    # Dataset arguments
    parser.add_argument(
        '--clean-dataset',
        type=str,
        default='librispeech_asr',
        help='HF dataset for clean speech (default: librispeech_asr)'
    )
    parser.add_argument(
        '--noise-dataset',
        type=str,
        default='dns-challenge/dns-challenge-4',
        help='HF dataset for noise (default: dns-challenge)'
    )
    parser.add_argument(
        '--clean-split',
        type=str,
        default='train.clean.100',
        help='Split for clean dataset'
    )
    parser.add_argument(
        '--noise-split',
        type=str,
        default='train',
        help='Split for noise dataset'
    )

    # Training arguments
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--samples-per-epoch',
        type=int,
        default=1000,
        help='Number of samples per epoch'
    )
    parser.add_argument(
        '--lstm-units',
        type=int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--cache-dir',
        type=str,
        default=None,
        help='Directory to cache HF datasets'
    )

    args = parser.parse_args()

    # Train model
    model, history = train_model(
        clean_dataset=args.clean_dataset,
        noise_dataset=args.noise_dataset,
        clean_split=args.clean_split,
        noise_split=args.noise_split,
        output_dir=args.output_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        samples_per_epoch=args.samples_per_epoch,
        lstm_units=args.lstm_units,
        learning_rate=args.learning_rate,
        use_qat=not args.no_qat,
        cache_dir=args.cache_dir
    )