| """ | |
| Training script for DTLN model with Quantization-Aware Training (QAT) | |
| Optimized for deployment on Alif E7 Ethos-U55 NPU | |
| """ | |
| import tensorflow as tf | |
| import tensorflow_model_optimization as tfmot | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| from pathlib import Path | |
| import argparse | |
| from dtln_ethos_u55 import DTLN_Ethos_U55 | |
| import os | |
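
# Typical invocation (script name and paths are illustrative; the flags match
# the argparse definitions at the bottom of this file):
#
#   python train.py --clean-dir ./data/clean_speech --noise-dir ./data/noise \
#       --epochs 50 --batch-size 16 --lstm-units 128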

class AudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator for training audio denoising models.
    Loads clean speech and noise files and mixes them on the fly.
    """

    def __init__(
        self,
        clean_audio_dir,
        noise_audio_dir,
        batch_size=16,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    ):
        """
        Args:
            clean_audio_dir: Directory containing clean speech files
            noise_audio_dir: Directory containing noise files
            batch_size: Batch size for training
            frame_len: Frame length in samples
            frame_shift: Frame shift in samples
            sampling_rate: Target sampling rate in Hz
            snr_range: (min, max) SNR range for mixing, in dB
            shuffle: Whether to shuffle data each epoch
        """
        self.clean_files = list(Path(clean_audio_dir).glob('**/*.wav'))
        self.noise_files = list(Path(noise_audio_dir).glob('**/*.wav'))
        self.batch_size = batch_size
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle
        # Segment length for training (1 second)
        self.segment_len = sampling_rate
        # Build (and optionally shuffle) the index array once up front
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch."""
        # Integer division: files beyond the last full batch are dropped
        return len(self.clean_files) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data."""
        # Select the clean files for this batch
        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]
        batch_clean = []
        batch_noisy = []
        for idx in batch_indices:
            clean_audio = self._load_audio(self.clean_files[idx])
            noise_audio = self._load_random_noise()
            # Mix clean speech and noise at a random SNR
            noisy_audio = self._mix_audio(clean_audio, noise_audio)
            batch_clean.append(clean_audio)
            batch_noisy.append(noisy_audio)
        # Noisy mixture is the model input; clean speech is the target
        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Update indices after each epoch."""
        self.indices = np.arange(len(self.clean_files))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_audio(self, file_path):
        """Load and preprocess an audio file."""
        audio, sr = sf.read(file_path)
        # Convert to mono first (soundfile returns shape (frames, channels))
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        # Resample if needed
        if sr != self.sampling_rate:
            audio = librosa.resample(
                audio,
                orig_sr=sr,
                target_sr=self.sampling_rate
            )
        # Random crop or zero-pad to the segment length
        if len(audio) > self.segment_len:
            start = np.random.randint(0, len(audio) - self.segment_len)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, self.segment_len - len(audio)))
        # Peak-normalize; the epsilon guards against all-zero input
        audio = audio / (np.max(np.abs(audio)) + 1e-8)
        return audio.astype(np.float32)

    def _load_random_noise(self):
        """Load a random noise file."""
        noise_file = np.random.choice(self.noise_files)
        return self._load_audio(noise_file)

    def _mix_audio(self, clean, noise):
        """Mix clean audio with noise at a random SNR."""
        snr = np.random.uniform(*self.snr_range)
        # Signal powers (mean square)
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)
        # SNR(dB) = 10*log10(P_clean / (scale**2 * P_noise)), so the scale
        # that hits the target SNR is sqrt(P_clean / (10**(SNR/10) * P_noise))
        snr_linear = 10 ** (snr / 10)
        noise_scale = np.sqrt(clean_power / (snr_linear * noise_power + 1e-8))
        # Mix
        noisy = clean + noise_scale * noise
        # Re-normalize with 5% headroom to prevent clipping
        noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.95
        return noisy.astype(np.float32)
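
# Minimal usage sketch for the generator (paths are illustrative):
#
#   gen = AudioDataGenerator('./data/clean_speech', './data/noise', batch_size=4)
#   noisy, clean = gen[0]
#
# Both arrays are float32 with shape (4, 16000): one second of 16 kHz audio
# per example, with the noisy mixture as input and the clean speech as target.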

def apply_quantization_aware_training(model):
    """
    Apply quantization-aware training for 8-bit deployment.

    Args:
        model: Keras model to quantize

    Returns:
        Quantization-aware model
    """
    # Quantize the entire model with the default 8-bit config. Note that,
    # depending on the tensorflow_model_optimization version, recurrent layers
    # such as LSTM may require a custom QuantizeConfig; the default path covers
    # the built-in supported layer types.
    quantize_model = tfmot.quantization.keras.quantize_model
    q_aware_model = quantize_model(model)
    return q_aware_model
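
# Deployment note (a sketch, not part of this script): for the Ethos-U55 the
# QAT model is typically converted to a fully integer TFLite model and then
# compiled with Arm's Vela. A standard full-integer conversion looks like:
#
#   converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
#   converter.optimizations = [tf.lite.Optimize.DEFAULT]
#   converter.representative_dataset = rep_data_gen  # assumed calibration helper
#   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
#   converter.inference_input_type = tf.int8
#   converter.inference_output_type = tf.int8
#   tflite_model = converter.convert()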

def create_loss_function():
    """
    Create a custom loss combining time- and frequency-domain terms.
    """
    def combined_loss(y_true, y_pred):
        # Time-domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))
        # Frequency-domain loss on STFT magnitudes
        stft_true = tf.signal.stft(
            y_true,
            frame_length=512,
            frame_step=128
        )
        stft_pred = tf.signal.stft(
            y_pred,
            frame_length=512,
            frame_step=128
        )
        mag_true = tf.abs(stft_true)
        mag_pred = tf.abs(stft_pred)
        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))
        # Weighted combination
        return 0.7 * time_loss + 0.3 * freq_loss

    return combined_loss
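
# Quick sanity check of the loss on random tensors (illustrative only):
#
#   loss_fn = create_loss_function()
#   y_true = tf.random.normal((2, 16000))
#   y_pred = tf.random.normal((2, 16000))
#   print(float(loss_fn(y_true, y_pred)))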

def train_model(
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=50,
    batch_size=16,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True
):
    """
    Main training function.

    Args:
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Directory to save models
        epochs: Number of training epochs
        batch_size: Training batch size
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for the Adam optimizer
        use_qat: Whether to use quantization-aware training
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print("Training DTLN for Alif E7 Ethos-U55")
    print("=" * 60)

    # Create model
    print("\n1. Building model...")
    dtln = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=lstm_units,
        sampling_rate=16000
    )
    model = dtln.build_model()
    model.summary()

    # Apply QAT if requested
    if use_qat:
        print("\n2. Applying Quantization-Aware Training...")
        model = apply_quantization_aware_training(model)
        print("   ✓ QAT applied")

    # Compile model
    print("\n3. Compiling model...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=create_loss_function(),
        metrics=['mae']
    )
    print("   ✓ Model compiled")

    # Create data generator
    print("\n4. Creating data generators...")
    train_generator = AudioDataGenerator(
        clean_audio_dir=clean_dir,
        noise_audio_dir=noise_dir,
        batch_size=batch_size,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    )
    print(f"   ✓ Training samples: {len(train_generator) * batch_size}")

    # Callbacks: checkpointing, LR decay, early stopping, TensorBoard.
    # There is no validation split, so all callbacks monitor the training loss.
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_model.h5'),
            monitor='loss',
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(output_dir, 'logs'),
            histogram_freq=1
        )
    ]

    # Train
    print("\n5. Starting training...")
    print("=" * 60)
    history = model.fit(
        train_generator,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Save final model
    final_model_path = os.path.join(output_dir, 'dtln_ethos_u55_final.h5')
    model.save(final_model_path)
    print(f"\n✓ Training complete! Model saved to {final_model_path}")

    return model, history
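
# Note: with use_qat=True the saved .h5 contains TF-MOT quantization wrapper
# layers, so reloading it later needs the custom quantize objects (sketch):
#
#   with tfmot.quantization.keras.quantize_scope():
#       model = tf.keras.models.load_model('models/best_model.h5', compile=False)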

def train_with_pretrained_dtln(
    pretrained_weights_path,
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=20,
    batch_size=16
):
    """
    Fine-tune from pre-trained DTLN weights.

    Args:
        pretrained_weights_path: Path to pretrained DTLN weights
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Output directory
        epochs: Number of fine-tuning epochs
        batch_size: Training batch size
    """
    print("Fine-tuning from pretrained DTLN weights...")

    # Build model
    dtln = DTLN_Ethos_U55(lstm_units=128)
    model = dtln.build_model()

    # Load pretrained weights where the layer names match
    try:
        model.load_weights(pretrained_weights_path, by_name=True)
        print("✓ Pretrained weights loaded")
    except (OSError, ValueError) as e:
        print(f"⚠ Could not load pretrained weights ({e}), training from scratch")

    # Continue training with the standard pipeline
    return train_model(
        clean_dir=clean_dir,
        noise_dir=noise_dir,
        output_dir=output_dir,
        epochs=epochs,
        batch_size=batch_size,
        use_qat=True
    )
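
# Caveat: load_weights(..., by_name=True) silently skips layers whose names are
# absent from the checkpoint, so a partial (or even empty) load will not raise;
# spot-checking a layer's weights after loading is a cheap sanity check.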

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train DTLN model for Alif E7 Ethos-U55'
    )
    parser.add_argument(
        '--clean-dir',
        type=str,
        required=True,
        help='Directory containing clean speech files'
    )
    parser.add_argument(
        '--noise-dir',
        type=str,
        required=True,
        help='Directory containing noise files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--lstm-units',
        type=int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--pretrained',
        type=str,
        default=None,
        help='Path to pretrained weights for fine-tuning'
    )
    args = parser.parse_args()

    # Fine-tune from pretrained weights if given, otherwise train from scratch
    if args.pretrained:
        model, history = train_with_pretrained_dtln(
            pretrained_weights_path=args.pretrained,
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size
        )
    else:
        model, history = train_model(
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lstm_units=args.lstm_units,
            learning_rate=args.learning_rate,
            use_qat=not args.no_qat
        )

    print("\n" + "=" * 60)
    print("Training Summary:")
    print(f"  Final loss: {history.history['loss'][-1]:.4f}")
    print(f"  Best loss: {min(history.history['loss']):.4f}")
    print(f"  Model saved to: {args.output_dir}")
    print("=" * 60)