| """ | |
| DTLN Model Optimized for Alif E7 Ethos-U55 NPU | |
| Lightweight voice denoising with 8-bit quantization support | |
| """ | |
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
    LSTM, Dense, Input, Multiply, Lambda
)
from tensorflow.keras.models import Model
import numpy as np


class DTLN_Ethos_U55:
    """
    Dual-signal Transformation LSTM Network optimized for Ethos-U55

    Key optimizations:
    - Reduced parameters for <1MB model size
    - 8-bit quantization aware architecture
    - Stateful inference for real-time processing
    - Memory-efficient for DTCM constraints
    """

    def __init__(
        self,
        frame_len=512,
        frame_shift=128,
        lstm_units=128,
        sampling_rate=16000,
        use_stft=True
    ):
        """
        Args:
            frame_len: Length of audio frame (default 512 = 32ms @ 16kHz)
            frame_shift: Frame hop size (default 128 = 8ms @ 16kHz)
            lstm_units: Number of LSTM units (reduced to 128 for NPU)
            sampling_rate: Audio sampling rate in Hz
            use_stft: Use STFT domain processing
        """
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.lstm_units = lstm_units
        self.sampling_rate = sampling_rate
        self.use_stft = use_stft
        # Frequency bins for STFT
        self.freq_bins = frame_len // 2 + 1
    def stft_layer(self, x):
        """Custom STFT layer using tf.signal"""
        stft = tf.signal.stft(
            x,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len
        )
        mag = tf.abs(stft)
        phase = tf.math.angle(stft)
        return mag, phase
    def istft_layer(self, mag, phase):
        """Custom inverse STFT layer"""
        complex_spec = tf.cast(mag, tf.complex64) * tf.exp(
            1j * tf.cast(phase, tf.complex64)
        )
        signal = tf.signal.inverse_stft(
            complex_spec,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len
        )
        return signal
    def build_model(self):
        """
        Build the full DTLN model (used for training and offline inference)

        Returns:
            Keras Model mapping a raw noisy waveform to an enhanced waveform
        """
        # Input: raw waveform
        input_audio = Input(shape=(None,), name='input_audio')

        # STFT transformation (magnitude is processed, phase is reused as-is)
        mag, phase = Lambda(
            lambda x: self.stft_layer(x),
            name='stft'
        )(input_audio)
        # === First Processing Stage ===
        # Process magnitude spectrum
        lstm_1 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_1'
        )(mag)

        # Estimate magnitude mask
        mask_1 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_1'
        )(lstm_1)

        # Apply mask
        enhanced_mag_1 = Multiply(name='apply_mask_1')([mag, mask_1])

        # === Second Processing Stage ===
        lstm_2 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_2'
        )(enhanced_mag_1)

        # Second mask estimation
        mask_2 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_2'
        )(lstm_2)

        # Apply second mask
        enhanced_mag = Multiply(name='apply_mask_2')([enhanced_mag_1, mask_2])

        # Inverse STFT
        enhanced_audio = Lambda(
            lambda x: self.istft_layer(x[0], x[1]),
            name='istft'
        )([enhanced_mag, phase])

        # Build model
        model = Model(
            inputs=input_audio,
            outputs=enhanced_audio,
            name='DTLN_Ethos_U55'
        )
        return model
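    # NOTE: the Ethos-U55 NPU has no FFT support, so in a typical deployment the
    # STFT/iSTFT Lambda layers above stay on the Cortex-M55 CPU (e.g. via
    # CMSIS-DSP) and only the quantized LSTM/Dense core is offloaded to the NPU.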
    def build_stateful_model(self, batch_size=1):
        """
        Build stateful models for frame-by-frame inference.
        This is more memory efficient for real-time processing.

        Returns:
            Two models (stage1, stage2) for sequential processing
        """
        # === Stage 1 Model ===
        # Inputs
        mag_input = Input(
            batch_shape=(batch_size, 1, self.freq_bins),
            name='magnitude_input'
        )
        state_h_1 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_1_state_h'
        )
        state_c_1 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_1_state_c'
        )

        # LSTM with explicitly managed state (stateful=False: the caller feeds
        # the previous states back in as inputs on every frame)
        lstm_1 = LSTM(
            self.lstm_units,
            return_sequences=True,
            return_state=True,
            stateful=False,
            name='lstm_1'
        )
        lstm_out_1, state_h_1_out, state_c_1_out = lstm_1(
            mag_input,
            initial_state=[state_h_1, state_c_1]
        )

        # Mask estimation
        mask_1 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_1'
        )(lstm_out_1)

        # Apply mask
        enhanced_mag_1 = Multiply()([mag_input, mask_1])

        model_1 = Model(
            inputs=[mag_input, state_h_1, state_c_1],
            outputs=[enhanced_mag_1, state_h_1_out, state_c_1_out],
            name='DTLN_Stage1'
        )

        # === Stage 2 Model ===
        mag_input_2 = Input(
            batch_shape=(batch_size, 1, self.freq_bins),
            name='enhanced_magnitude_input'
        )
        state_h_2 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_2_state_h'
        )
        state_c_2 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_2_state_c'
        )

        # LSTM with explicitly managed state
        lstm_2 = LSTM(
            self.lstm_units,
            return_sequences=True,
            return_state=True,
            stateful=False,
            name='lstm_2'
        )
        lstm_out_2, state_h_2_out, state_c_2_out = lstm_2(
            mag_input_2,
            initial_state=[state_h_2, state_c_2]
        )

        # Final mask
        mask_2 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_2'
        )(lstm_out_2)

        # Apply mask
        enhanced_mag = Multiply()([mag_input_2, mask_2])

        model_2 = Model(
            inputs=[mag_input_2, state_h_2, state_c_2],
            outputs=[enhanced_mag, state_h_2_out, state_c_2_out],
            name='DTLN_Stage2'
        )

        return model_1, model_2
    def get_model_summary(self):
        """Print model architecture and parameter count"""
        model = self.build_model()
        model.summary()
        total_params = model.count_params()
        print(f"\nTotal parameters: {total_params:,}")
        print(f"Estimated model size (FP32): {total_params * 4 / 1024:.2f} KB")
        print(f"Estimated model size (INT8): {total_params / 1024:.2f} KB")
        return model
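
# --- Example (sketch): frame-by-frame streaming inference ---
# Minimal illustration of how the two stage models returned by
# build_stateful_model() could be driven in a real-time loop, carrying the LSTM
# states forward between frames. `denoise_streaming` is a hypothetical helper,
# not part of the original model code; windowing and overlap-add normalization
# are simplified, and model.predict() would normally be replaced by the
# quantized TFLite interpreters on the target hardware.
def denoise_streaming(dtln, audio):
    """Run stage 1 / stage 2 frame by frame with explicit LSTM states (sketch)."""
    stage1, stage2 = dtln.build_stateful_model(batch_size=1)
    frame_len, frame_shift = dtln.frame_len, dtln.frame_shift

    # Zero initial LSTM states for both stages
    h1 = np.zeros((1, dtln.lstm_units), dtype=np.float32)
    c1 = np.zeros_like(h1)
    h2 = np.zeros_like(h1)
    c2 = np.zeros_like(h1)

    out = np.zeros(len(audio), dtype=np.float32)
    n_frames = (len(audio) - frame_len) // frame_shift + 1

    for i in range(n_frames):
        start = i * frame_shift
        frame = audio[start:start + frame_len].astype(np.float32)

        # Magnitude/phase of the current frame
        spec = np.fft.rfft(frame)
        mag = np.abs(spec).astype(np.float32)[None, None, :]  # (1, 1, freq_bins)
        phase = np.angle(spec)

        # Stage 1: estimate and apply the first mask, carry its state forward
        enh1, h1, c1 = stage1.predict([mag, h1, c1], verbose=0)
        # Stage 2: refine the stage-1 magnitude with its own state
        enh2, h2, c2 = stage2.predict([enh1, h2, c2], verbose=0)

        # Back to the time domain and overlap-add into the output buffer
        enhanced = np.fft.irfft(enh2[0, 0] * np.exp(1j * phase), n=frame_len)
        out[start:start + frame_len] += enhanced.astype(np.float32)

    return out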

def create_lightweight_model(target_size_kb=100):
    """
    Factory function to create a lightweight model that fits in a target size

    Args:
        target_size_kb: Target model size in KB

    Returns:
        DTLN_Ethos_U55 instance configured for the target size
    """
    # Estimate LSTM units for target size
    # Rough estimate: each LSTM unit adds ~2KB for INT8
    estimated_units = int((target_size_kb * 0.8) / 2)
    estimated_units = min(max(estimated_units, 64), 256)  # Clamp to 64-256

    print(f"Creating model with {estimated_units} LSTM units")
    print(f"Target size: {target_size_kb} KB")

    model = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=estimated_units,
        sampling_rate=16000
    )
    return model
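
# --- Example (sketch): post-training INT8 quantization with TFLite ---
# Hedged illustration of step 2 from the "Next steps" printed below: full-integer
# quantization of one of the stateful stage models. `convert_stage_to_int8` and
# `representative_frames` are hypothetical names; the representative dataset must
# be replaced with real magnitude frames from the training data, and the same
# procedure is applied to both stage models.
def convert_stage_to_int8(stage_model, representative_frames, output_path):
    """Convert a Keras stage model to a fully integer-quantized .tflite file."""
    def representative_dataset():
        # Yield float32 inputs matching the model's three inputs:
        # magnitude frame, LSTM hidden state, LSTM cell state.
        for mag, h, c in representative_frames:
            yield [
                mag.astype(np.float32),
                h.astype(np.float32),
                c.astype(np.float32),
            ]

    converter = tf.lite.TFLiteConverter.from_keras_model(stage_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    # Restrict to INT8 ops so the Vela compiler can map them onto the Ethos-U55
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)
    return output_path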

if __name__ == "__main__":
    # Example usage
    print("Creating DTLN model for Alif E7 Ethos-U55...")

    # Create model
    dtln = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=128
    )

    # Get model summary
    model = dtln.get_model_summary()

    # Build stateful models for inference
    print("\n" + "=" * 50)
    print("Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()

    print("\nStage 1:")
    stage1.summary()
    print("\nStage 2:")
    stage2.summary()

    print("\n✓ Model creation successful!")
    print("Next steps:")
    print("1. Train the model with quantization-aware training")
    print("2. Convert to TensorFlow Lite INT8 format")
    print("3. Use Vela compiler to optimize for Ethos-U55")