"""
Training script for DTLN model with Quantization-Aware Training (QAT).

Optimized for deployment on Alif E7 Ethos-U55 NPU.
"""

import argparse
import os
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from dtln_ethos_u55 import DTLN_Ethos_U55


class AudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator for training audio denoising models.

    Each item is a batch of (noisy, clean) fixed-length audio segments. The
    noisy segment is produced on the fly by mixing a clean speech file with a
    randomly chosen noise file at a random SNR.
    """

    def __init__(
        self,
        clean_audio_dir,
        noise_audio_dir,
        batch_size=16,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    ):
        """
        Args:
            clean_audio_dir: Directory containing clean speech files
            noise_audio_dir: Directory containing noise files
            batch_size: Batch size for training
            frame_len: Frame length in samples
            frame_shift: Frame shift in samples
            sampling_rate: Target sampling rate
            snr_range: Range of SNR for mixing (min, max) in dB
            shuffle: Whether to shuffle data each epoch

        Raises:
            ValueError: If batch_size is not positive, if either directory
                contains no ``.wav`` files, or if there are fewer clean files
                than one batch (which would yield zero batches per epoch).
        """
        super().__init__()

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Sorted so the file order does not depend on filesystem enumeration.
        self.clean_files = sorted(Path(clean_audio_dir).glob('**/*.wav'))
        self.noise_files = sorted(Path(noise_audio_dir).glob('**/*.wav'))

        if not self.clean_files:
            raise ValueError(f"No .wav files found in {clean_audio_dir}")
        if not self.noise_files:
            raise ValueError(f"No .wav files found in {noise_audio_dir}")
        if len(self.clean_files) < batch_size:
            raise ValueError(
                f"Found {len(self.clean_files)} clean files, which is fewer "
                f"than one batch of {batch_size}"
            )

        self.batch_size = batch_size
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle

        # Segment length for training (1 second)
        self.segment_len = sampling_rate

        self.on_epoch_end()

    def __len__(self):
        """Return number of batches per epoch (a trailing partial batch is dropped)."""
        return len(self.clean_files) // self.batch_size

    def __getitem__(self, index):
        """
        Generate one batch of data.

        Args:
            index: Batch index within the current epoch

        Returns:
            Tuple ``(noisy, clean)`` of float32 arrays, one row per segment.
        """
        # Select files for this batch
        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        batch_clean = []
        batch_noisy = []

        for idx in batch_indices:
            clean_audio = self._load_audio(self.clean_files[idx])
            noise_audio = self._load_random_noise()

            # Mix clean and noise at random SNR
            noisy_audio = self._mix_audio(clean_audio, noise_audio)

            # NOTE(review): the noisy mix is rescaled to a 0.95 peak inside
            # _mix_audio while the clean target keeps its own peak
            # normalization, so input and target gains differ per example.
            batch_clean.append(clean_audio)
            batch_noisy.append(noisy_audio)

        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Rebuild (and optionally shuffle) the file index order after each epoch."""
        self.indices = np.arange(len(self.clean_files))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_audio(self, file_path):
        """
        Load an audio file and return a normalized fixed-length mono segment.

        Args:
            file_path: Path of the audio file to read

        Returns:
            float32 array of ``self.segment_len`` samples, peak-normalized.
        """
        audio, sr = sf.read(file_path)

        # Convert to mono BEFORE resampling: multi-channel data has channels
        # on axis 1, and resampling acts on the last axis by default, so
        # resampling first would operate across channels instead of time.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        # Resample if needed
        if sr != self.sampling_rate:
            audio = librosa.resample(
                audio,
                orig_sr=sr,
                target_sr=self.sampling_rate
            )

        # Trim or pad to segment length
        num_samples = len(audio)
        if num_samples > self.segment_len:
            # randint's upper bound is exclusive; +1 keeps the last valid
            # start offset reachable.
            start = np.random.randint(0, num_samples - self.segment_len + 1)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, self.segment_len - num_samples))

        # Normalize (epsilon guards against division by zero on silence)
        audio = audio / (np.max(np.abs(audio)) + 1e-8)

        return audio.astype(np.float32)

    def _load_random_noise(self):
        """Load a segment from a uniformly chosen noise file."""
        noise_file = self.noise_files[np.random.randint(len(self.noise_files))]
        return self._load_audio(noise_file)

    def _mix_audio(self, clean, noise):
        """
        Mix clean audio with noise at a random SNR drawn from ``snr_range``.

        Args:
            clean: Clean speech segment
            noise: Noise segment of the same length

        Returns:
            float32 noisy mixture, rescaled so its peak is about 0.95.
        """
        snr = np.random.uniform(*self.snr_range)

        # Calculate signal powers
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)

        # Calculate noise scaling factor
        snr_linear = 10 ** (snr / 10)
        noise_scale = np.sqrt(clean_power / (snr_linear * noise_power + 1e-8))

        # Mix
        noisy = clean + noise_scale * noise

        # Normalize to prevent clipping
        noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.95

        return noisy.astype(np.float32)


def apply_quantization_aware_training(model):
    """
    Apply quantization-aware training for 8-bit deployment.

    Wraps the whole model with ``tfmot.quantization.keras.quantize_model``
    using the library's default quantization config.

    Args:
        model: Keras model to quantize

    Returns:
        Quantization-aware model
    """
    return tfmot.quantization.keras.quantize_model(model)


def create_loss_function(
    frame_len=512,
    frame_shift=128,
    time_weight=0.7,
    freq_weight=0.3
):
    """
    Create custom loss function combining time and frequency domain losses.

    Args:
        frame_len: STFT frame length in samples
        frame_shift: STFT frame step in samples
        time_weight: Weight of the time-domain MSE term
        freq_weight: Weight of the STFT-magnitude MSE term

    Returns:
        A ``loss(y_true, y_pred)`` callable suitable for ``model.compile``.
    """
    def combined_loss(y_true, y_pred):
        # Time domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))

        # Frequency domain loss (STFT-based)
        stft_true = tf.signal.stft(
            y_true,
            frame_length=frame_len,
            frame_step=frame_shift
        )
        stft_pred = tf.signal.stft(
            y_pred,
            frame_length=frame_len,
            frame_step=frame_shift
        )

        mag_true = tf.abs(stft_true)
        mag_pred = tf.abs(stft_pred)

        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))

        # Combined loss (weighted)
        return time_weight * time_loss + freq_weight * freq_loss

    return combined_loss


def _load_initial_weights(model, weights_path):
    """
    Best-effort load of pretrained weights into ``model`` (matched by name).

    Args:
        model: Keras model to load weights into
        weights_path: Path to the pretrained weights file

    Returns:
        True if the weights were loaded, False if loading failed and training
        should proceed from the model's current initialization.
    """
    try:
        model.load_weights(weights_path, by_name=True)
    except (OSError, ValueError, NotImplementedError, tf.errors.OpError) as err:
        print("⚠ Could not load pretrained weights, training from scratch")
        print(f"  Reason: {err}")
        return False
    print("✓ Pretrained weights loaded")
    return True


def train_model(
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=50,
    batch_size=16,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True,
    initial_weights_path=None
):
    """
    Main training function.

    Args:
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Directory to save models
        epochs: Number of training epochs
        batch_size: Training batch size
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for Adam optimizer
        use_qat: Whether to use quantization-aware training
        initial_weights_path: Optional path to pretrained weights. When given,
            they are loaded by layer name into the freshly built model before
            QAT is applied; a failed load falls back to training from scratch.

    Returns:
        Tuple ``(model, history)`` of the trained model and the object
        returned by ``model.fit``.
    """
    # One set of frame settings shared by the model, the loss and the data.
    frame_len = 512
    frame_shift = 128
    sampling_rate = 16000

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("Training DTLN for Alif E7 Ethos-U55")
    print("="*60)

    # Create model
    print("\n1. Building model...")
    dtln = DTLN_Ethos_U55(
        frame_len=frame_len,
        frame_shift=frame_shift,
        lstm_units=lstm_units,
        sampling_rate=sampling_rate
    )

    model = dtln.build_model()
    model.summary()

    # Pretrained weights must go into the plain model: they are matched by
    # layer name, so they are loaded before the model is wrapped for QAT.
    if initial_weights_path is not None:
        _load_initial_weights(model, initial_weights_path)

    # Apply QAT if requested
    if use_qat:
        print("\n2. Applying Quantization-Aware Training...")
        model = apply_quantization_aware_training(model)
        print(" ✓ QAT applied")

    # Compile model
    print("\n3. Compiling model...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=create_loss_function(frame_len=frame_len, frame_shift=frame_shift),
        metrics=['mae']
    )
    print(" ✓ Model compiled")

    # Create data generators
    print("\n4. Creating data generators...")
    train_generator = AudioDataGenerator(
        clean_audio_dir=clean_dir,
        noise_audio_dir=noise_dir,
        batch_size=batch_size,
        frame_len=frame_len,
        frame_shift=frame_shift,
        sampling_rate=sampling_rate,
        snr_range=(0, 20),
        shuffle=True
    )

    print(f" ✓ Training samples: {len(train_generator) * batch_size}")

    # Callbacks (all monitor training loss; no validation data is used here)
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_model.h5'),
            monitor='loss',
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(output_dir, 'logs'),
            histogram_freq=1
        )
    ]

    # Train
    print("\n5. Starting training...")
    print("="*60)

    history = model.fit(
        train_generator,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Save final model
    final_model_path = os.path.join(
        output_dir,
        'dtln_ethos_u55_final.h5'
    )
    model.save(final_model_path)
    print(f"\n✓ Training complete! Model saved to {final_model_path}")

    return model, history


def train_with_pretrained_dtln(
    pretrained_weights_path,
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=20,
    batch_size=16,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True
):
    """
    Fine-tune from pre-trained DTLN weights.

    Delegates to ``train_model`` and passes the weights path through, so the
    weights are loaded into the model that is actually trained.

    Args:
        pretrained_weights_path: Path to pretrained DTLN weights
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Output directory
        epochs: Number of fine-tuning epochs
        batch_size: Training batch size
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for Adam optimizer
        use_qat: Whether to use quantization-aware training

    Returns:
        Tuple ``(model, history)`` as returned by ``train_model``.
    """
    print("Fine-tuning from pretrained DTLN weights...")

    return train_model(
        clean_dir=clean_dir,
        noise_dir=noise_dir,
        output_dir=output_dir,
        epochs=epochs,
        batch_size=batch_size,
        lstm_units=lstm_units,
        learning_rate=learning_rate,
        use_qat=use_qat,
        initial_weights_path=pretrained_weights_path
    )


def _parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(
        description='Train DTLN model for Alif E7 Ethos-U55'
    )

    parser.add_argument(
        '--clean-dir',
        type=str,
        required=True,
        help='Directory containing clean speech files'
    )
    parser.add_argument(
        '--noise-dir',
        type=str,
        required=True,
        help='Directory containing noise files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--lstm-units',
        type=int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--pretrained',
        type=str,
        default=None,
        help='Path to pretrained weights for fine-tuning'
    )

    return parser.parse_args(argv)


def main(argv=None):
    """Command-line entry point: parse arguments, train, print a summary."""
    args = _parse_args(argv)

    # Train model
    if args.pretrained:
        model, history = train_with_pretrained_dtln(
            pretrained_weights_path=args.pretrained,
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lstm_units=args.lstm_units,
            learning_rate=args.learning_rate,
            use_qat=not args.no_qat
        )
    else:
        model, history = train_model(
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lstm_units=args.lstm_units,
            learning_rate=args.learning_rate,
            use_qat=not args.no_qat
        )

    print("\n" + "="*60)
    print("Training Summary:")
    print(f" Final loss: {history.history['loss'][-1]:.4f}")
    print(f" Best loss: {min(history.history['loss']):.4f}")
    print(f" Model saved to: {args.output_dir}")
    print("="*60)

    return model, history


if __name__ == "__main__":
    main()