Spaces:
Sleeping
Sleeping
| """ | |
| Train DTLN model using Hugging Face datasets | |
| Uses real speech and noise datasets for production-quality training | |
| """ | |
| import tensorflow as tf | |
| import tensorflow_model_optimization as tfmot | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| from pathlib import Path | |
| import argparse | |
| from dtln_ethos_u55 import DTLN_Ethos_U55 | |
| import os | |
| from datasets import load_dataset, Audio | |
| from tqdm import tqdm | |
class HuggingFaceAudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator using Hugging Face datasets.

    Streams clean speech and noise from the HF Hub and yields batches of
    ``(noisy, clean)`` fixed-length segments: the noisy array is the model
    input and the clean array is the training target.
    """

    def __init__(
        self,
        clean_dataset_name="librispeech_asr",
        noise_dataset_name="dns-challenge/dns-challenge-4",
        clean_split="train.clean.100",
        noise_split="train",
        batch_size=16,
        samples_per_epoch=1000,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True,
        cache_dir=None
    ):
        """
        Open both datasets in streaming mode and prepare their iterators.

        If the clean dataset cannot be loaded, Common Voice (English) is
        tried instead; a failure of that fallback propagates to the caller.
        If the noise dataset cannot be loaded, synthetic white noise is used.

        Args:
            clean_dataset_name: HF dataset for clean speech (default: LibriSpeech)
            noise_dataset_name: HF dataset for noise (default: DNS Challenge)
            clean_split: Split to use from clean dataset
            noise_split: Split to use from noise dataset
            batch_size: Batch size for training (must be positive)
            samples_per_epoch: Number of samples per epoch
            frame_len: Frame length in samples (stored, not used by this class)
            frame_shift: Frame shift in samples (stored, not used by this class)
            sampling_rate: Target sampling rate; also the segment length,
                since every segment is one second long
            snr_range: Range of SNR for mixing (min, max) in dB
            shuffle: Stored only; the streaming iterators are not shuffled here
            cache_dir: Directory to cache datasets

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Let the Keras base class set up its own state; newer Keras
        # versions warn when a subclass skips this call.
        super().__init__()

        if batch_size <= 0:
            # Otherwise __len__ fails later with an opaque ZeroDivisionError.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        print(f"\n{'='*60}")
        print("Initializing Hugging Face Dataset Generator")
        print(f"{'='*60}")

        self.batch_size = batch_size
        self.samples_per_epoch = samples_per_epoch
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle

        # Segment length for training (1 second)
        self.segment_len = sampling_rate

        # Load datasets
        print(f"\n1. Loading clean speech dataset: {clean_dataset_name}")
        print(f" Split: {clean_split}")
        try:
            self.clean_dataset = self._load_streaming_dataset(
                clean_dataset_name, clean_split, cache_dir
            )
            print(f" ✓ Clean speech dataset loaded")
        except Exception as e:
            # Best effort: any loading problem switches to the fallback corpus.
            print(f" ⚠ Error loading clean dataset: {e}")
            print(f" Using fallback: common_voice")
            self.clean_dataset = self._load_streaming_dataset(
                "mozilla-foundation/common_voice_11_0",
                "train",
                cache_dir,
                config="en",
            )

        print(f"\n2. Loading noise dataset: {noise_dataset_name}")
        print(f" Split: {noise_split}")
        try:
            self.noise_dataset = self._load_streaming_dataset(
                noise_dataset_name, noise_split, cache_dir
            )
            print(f" ✓ Noise dataset loaded")
        except Exception as e:
            print(f" ⚠ Error loading noise dataset: {e}")
            print(f" Using synthetic noise instead")
            self.noise_dataset = None

        # Create iterators
        self.clean_iter = iter(self.clean_dataset)
        self.noise_iter = (
            iter(self.noise_dataset) if self.noise_dataset is not None else None
        )

        self.on_epoch_end()

        print(f"\n{'='*60}")
        print(f"Dataset Generator Ready")
        print(f" Batch size: {batch_size}")
        print(f" Samples per epoch: {samples_per_epoch}")
        print(f" Batches per epoch: {len(self)}")
        print(f"{'='*60}\n")

    def _load_streaming_dataset(self, name, split, cache_dir, config=None):
        """Open ``name``/``split`` in streaming mode and cast its ``audio``
        column to ``self.sampling_rate``."""
        positional = (name,) if config is None else (name, config)
        dataset = load_dataset(
            *positional,
            split=split,
            streaming=True,  # Stream for large datasets
            cache_dir=cache_dir
        )
        # Cast audio to correct sampling rate
        return dataset.cast_column(
            "audio",
            Audio(sampling_rate=self.sampling_rate)
        )

    def _white_noise(self, scale):
        """Return ``segment_len`` float32 Gaussian samples multiplied by ``scale``."""
        return np.random.randn(self.segment_len).astype(np.float32) * scale

    def __len__(self):
        """Return number of batches per epoch (partial batches are dropped)."""
        return self.samples_per_epoch // self.batch_size

    def __getitem__(self, index):
        """
        Generate one batch of data.

        ``index`` is ignored: samples are drawn from the streaming iterators
        in whatever order they yield.

        Returns:
            Tuple ``(noisy, clean)`` of arrays, each stacking ``batch_size``
            segments of ``segment_len`` samples.
        """
        batch_clean = []
        batch_noisy = []

        for _ in range(self.batch_size):
            try:
                # Get clean audio
                clean_audio = self._load_next_clean_audio()

                # Get noise
                if self.noise_dataset is not None:
                    noise_audio = self._load_next_noise_audio()
                else:
                    # Generate synthetic noise
                    noise_audio = self._white_noise(0.1)

                # Mix clean and noise at random SNR
                noisy_audio = self._mix_audio(clean_audio, noise_audio)
            except Exception as e:
                # Best effort: keep the batch full by substituting low-level
                # white noise as both input and target for the failed sample.
                print(f" Warning: Error loading sample: {e}")
                clean_audio = noisy_audio = self._white_noise(0.01)

            batch_clean.append(clean_audio)
            batch_noisy.append(noisy_audio)

        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Epoch-end hook; intentionally a no-op.

        The streaming iterators simply continue across epochs and are only
        restarted when they are exhausted.
        """
        pass

    def _load_next_clean_audio(self):
        """Load next clean audio sample, restarting the stream when exhausted."""
        try:
            sample = next(self.clean_iter)
        except StopIteration:
            # Restart iterator
            self.clean_iter = iter(self.clean_dataset)
            sample = next(self.clean_iter)
        return self._preprocess_audio(sample['audio']['array'])

    def _load_next_noise_audio(self):
        """Load next noise sample, restarting the stream when exhausted."""
        try:
            sample = next(self.noise_iter)
        except StopIteration:
            # Restart iterator
            self.noise_iter = iter(self.noise_dataset)
            sample = next(self.noise_iter)
        # The same column lookup is used on both paths, so a restarted
        # stream accepts the same sample layouts as a fresh one.
        return self._preprocess_audio(self._extract_noise_array(sample))

    def _extract_noise_array(self, sample):
        """Return the waveform stored under ``'audio'`` or ``'noise'``;
        fall back to white noise when neither key is present."""
        for key in ('audio', 'noise'):
            if key in sample:
                return sample[key]['array']
        # Fallback to white noise
        return self._white_noise(0.1)

    def _preprocess_audio(self, audio):
        """
        Preprocess audio to target length and format.

        Converts to float32, takes a random ``segment_len`` crop (or
        zero-pads at the end when the input is shorter), then scales the
        segment to unit peak unless it is (near-)silent.
        """
        # Convert to float32 (also accepts plain sequences, not just ndarrays)
        audio = np.asarray(audio, dtype=np.float32)

        # Trim or pad to segment length
        surplus = len(audio) - self.segment_len
        if surplus > 0:
            # randint's upper bound is exclusive; +1 makes the final
            # window (start == surplus) selectable as well.
            start = np.random.randint(0, surplus + 1)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, -surplus))

        # Normalize
        peak = np.max(np.abs(audio))
        if peak > 1e-8:
            audio = audio / peak
        return audio

    def _mix_audio(self, clean, noise):
        """
        Mix clean audio with noise at random SNR.

        The SNR (dB) is drawn uniformly from ``self.snr_range``; the noise is
        scaled so the clean/noise power ratio matches it, and the mixture is
        rescaled to a 0.95 peak.

        NOTE(review): only the mixture is rescaled here; the clean target
        used by ``__getitem__`` keeps its own scale, so input and target can
        differ by a gain factor - confirm this is intended.
        """
        snr = np.random.uniform(*self.snr_range)

        # Calculate signal powers
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)

        # Calculate noise scaling factor
        if noise_power > 1e-8:
            snr_linear = 10 ** (snr / 10)
            noise_scale = np.sqrt(clean_power / (snr_linear * noise_power))
        else:
            # (Near-)silent noise: avoid dividing by ~0 and use a fixed gain.
            noise_scale = 0.1

        # Mix
        noisy = clean + noise_scale * noise

        # Normalize to prevent clipping
        peak = np.max(np.abs(noisy))
        if peak > 1e-8:
            noisy = noisy / peak * 0.95
        return noisy.astype(np.float32)
def create_loss_function(
    frame_length=512,
    frame_step=128,
    time_weight=0.7,
    freq_weight=0.3
):
    """
    Create custom loss function combining time and frequency domain losses.

    Args:
        frame_length: STFT window length in samples.
        frame_step: STFT hop size in samples.
        time_weight: Weight of the time-domain MSE term.
        freq_weight: Weight of the STFT-magnitude MSE term.

    Returns:
        A ``combined_loss(y_true, y_pred)`` callable suitable for
        ``model.compile(loss=...)``.

    Raises:
        ValueError: If ``frame_length`` or ``frame_step`` is not positive.
    """
    if frame_length <= 0 or frame_step <= 0:
        raise ValueError(
            "frame_length and frame_step must be positive, got "
            f"{frame_length} and {frame_step}"
        )

    def combined_loss(y_true, y_pred):
        """Weighted sum of waveform MSE and STFT-magnitude MSE."""
        # Align dtypes so the subtraction below also works when the loss is
        # called directly with arrays of different precision.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        # Time domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))

        # Frequency domain loss (STFT-based)
        mag_true = tf.abs(
            tf.signal.stft(
                y_true,
                frame_length=frame_length,
                frame_step=frame_step
            )
        )
        mag_pred = tf.abs(
            tf.signal.stft(
                y_pred,
                frame_length=frame_length,
                frame_step=frame_step
            )
        )
        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))

        # Combined loss (weighted)
        return time_weight * time_loss + freq_weight * freq_loss

    return combined_loss
def train_model(
    clean_dataset="librispeech_asr",
    noise_dataset="dns-challenge/dns-challenge-4",
    clean_split="train.clean.100",
    noise_split="train",
    output_dir='./models',
    epochs=50,
    batch_size=16,
    samples_per_epoch=1000,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True,
    cache_dir=None
):
    """
    Main training function using HF datasets.

    Builds the DTLN model, optionally wraps it for quantization-aware
    training, fits it on the Hugging Face data generator and saves the best
    and the final model under ``output_dir``.

    Args:
        clean_dataset: HF dataset name for clean speech
        noise_dataset: HF dataset name for noise
        clean_split: Split for clean dataset
        noise_split: Split for noise dataset
        output_dir: Directory to save models
        epochs: Number of training epochs
        batch_size: Training batch size
        samples_per_epoch: Samples per epoch
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for Adam optimizer
        use_qat: Whether to use quantization-aware training
        cache_dir: Directory to cache datasets

    Returns:
        Tuple ``(model, history)``: the trained model and the history object
        returned by ``model.fit``.

    Raises:
        ValueError: If ``batch_size`` is not positive or
            ``samples_per_epoch`` is smaller than ``batch_size``.
    """
    # Fail fast, before building the model or touching any dataset: the
    # generator yields samples_per_epoch // batch_size batches per epoch, so
    # these inputs would otherwise produce an epoch with no batches.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if samples_per_epoch < batch_size:
        raise ValueError(
            f"samples_per_epoch ({samples_per_epoch}) must be at least "
            f"batch_size ({batch_size})"
        )

    # Signal parameters shared by the model and the data generator.
    frame_len = 512
    frame_shift = 128
    sampling_rate = 16000

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("Training DTLN with Hugging Face Datasets")
    print("="*60)

    # Create model
    print("\n1. Building DTLN model...")
    dtln = DTLN_Ethos_U55(
        frame_len=frame_len,
        frame_shift=frame_shift,
        lstm_units=lstm_units,
        sampling_rate=sampling_rate
    )
    model = dtln.build_model()
    print(" ✓ Model built")
    print(f" Parameters: {model.count_params():,}")

    # Apply QAT if requested
    if use_qat:
        print("\n2. Applying Quantization-Aware Training...")
        # NOTE(review): confirm that quantize_model supports every layer
        # type produced by build_model().
        quantize_model = tfmot.quantization.keras.quantize_model
        model = quantize_model(model)
        print(" ✓ QAT applied (INT8 optimized)")

    # Compile model
    print("\n3. Compiling model...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=create_loss_function(),
        metrics=['mae']
    )
    print(" ✓ Model compiled")

    # Create data generator
    print("\n4. Creating Hugging Face data generator...")
    train_generator = HuggingFaceAudioDataGenerator(
        clean_dataset_name=clean_dataset,
        noise_dataset_name=noise_dataset,
        clean_split=clean_split,
        noise_split=noise_split,
        batch_size=batch_size,
        samples_per_epoch=samples_per_epoch,
        frame_len=frame_len,
        frame_shift=frame_shift,
        sampling_rate=sampling_rate,
        snr_range=(0, 20),
        shuffle=True,
        cache_dir=cache_dir
    )

    # Callbacks: all monitor the training loss, since no validation data is
    # passed to fit().
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_model.h5'),
            monitor='loss',
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(output_dir, 'logs'),
            histogram_freq=1
        )
    ]

    # Train
    print("\n5. Starting training...")
    print("="*60)
    history = model.fit(
        train_generator,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Save final model
    final_model_path = os.path.join(output_dir, 'dtln_final.h5')
    model.save(final_model_path)

    print(f"\n{'='*60}")
    print("✓ Training Complete!")
    print(f"{'='*60}")
    # The loss list is empty when no epoch completed (e.g. epochs=0);
    # indexing it unconditionally would raise after the model was saved.
    losses = history.history.get('loss', [])
    if losses:
        print(f"Final loss: {losses[-1]:.4f}")
        print(f"Best loss: {min(losses):.4f}")
    else:
        print("No training epochs were completed; no loss to report.")
    print(f"Model saved to: {final_model_path}")
    print(f"{'='*60}\n")

    return model, history
def _positive_int(text):
    """argparse ``type`` callable accepting only integers greater than zero."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        )
    return value


def build_arg_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Train DTLN model using Hugging Face datasets'
    )

    # Dataset arguments
    parser.add_argument(
        '--clean-dataset',
        type=str,
        default='librispeech_asr',
        help='HF dataset for clean speech (default: librispeech_asr)'
    )
    parser.add_argument(
        '--noise-dataset',
        type=str,
        default='dns-challenge/dns-challenge-4',
        help='HF dataset for noise (default: dns-challenge/dns-challenge-4)'
    )
    parser.add_argument(
        '--clean-split',
        type=str,
        default='train.clean.100',
        help='Split for clean dataset'
    )
    parser.add_argument(
        '--noise-split',
        type=str,
        default='train',
        help='Split for noise dataset'
    )

    # Training arguments
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    # Sizes below must be positive: a zero batch size, for example, would
    # only fail later with a division by zero in the data generator.
    parser.add_argument(
        '--batch-size',
        type=_positive_int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--samples-per-epoch',
        type=_positive_int,
        default=1000,
        help='Number of samples per epoch'
    )
    parser.add_argument(
        '--lstm-units',
        type=_positive_int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--cache-dir',
        type=str,
        default=None,
        help='Directory to cache HF datasets'
    )
    return parser


def main(argv=None):
    """
    Parse command-line arguments and run training.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The ``(model, history)`` tuple returned by ``train_model``.
    """
    args = build_arg_parser().parse_args(argv)

    # Train model
    return train_model(
        clean_dataset=args.clean_dataset,
        noise_dataset=args.noise_dataset,
        clean_split=args.clean_split,
        noise_split=args.noise_split,
        output_dir=args.output_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        samples_per_epoch=args.samples_per_epoch,
        lstm_units=args.lstm_units,
        learning_rate=args.learning_rate,
        use_qat=not args.no_qat,
        cache_dir=args.cache_dir
    )


if __name__ == "__main__":
    model, history = main()