# Source: voice-denoising / dtln_ethos_u55.py
# Uploaded by grgsaliba ("Upload dtln_ethos_u55.py with huggingface_hub"),
# commit 3a19dc4 (verified); web views: raw | history | blame; size 9.14 kB.
"""
DTLN Model Optimized for Alif E7 Ethos-U55 NPU
Lightweight voice denoising with 8-bit quantization support
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
LSTM, Dense, Input, Multiply, Lambda, Concatenate
)
from tensorflow.keras.models import Model
import numpy as np
class DTLN_Ethos_U55:
    """
    Dual-signal Transformation LSTM Network optimized for Ethos-U55.

    Two LSTM + sigmoid-mask stages are applied in sequence to the STFT
    magnitude; the noisy input phase is reused for resynthesis.

    Key optimizations:
    - Reduced parameters for <1MB model size (the defaults give
      2 * (LSTM + Dense) = 461,570 weights, i.e. ~451 KB at 1 byte/weight)
    - 8-bit quantization aware architecture
    - Stateful inference for real-time processing
    - Memory-efficient for DTCM constraints
    """

    def __init__(
        self,
        frame_len=512,
        frame_shift=128,
        lstm_units=128,
        sampling_rate=16000,
        use_stft=True,
    ):
        """
        Args:
            frame_len: Length of audio frame in samples (default 512 = 32ms
                @ 16kHz). Also used as the FFT length.
            frame_shift: Frame hop size in samples (default 128 = 8ms @ 16kHz).
            lstm_units: Number of LSTM units per stage (reduced to 128 for NPU).
            sampling_rate: Audio sampling rate in Hz. Stored for reference;
                no method of this class reads it.
            use_stft: Use STFT domain processing. NOTE: stored but currently
                not consulted - build_model() always uses the STFT path.

        Raises:
            ValueError: If frame_len, frame_shift or lstm_units is not positive.
        """
        for arg_name, value in (
            ("frame_len", frame_len),
            ("frame_shift", frame_shift),
            ("lstm_units", lstm_units),
        ):
            if value <= 0:
                raise ValueError(f"{arg_name} must be positive, got {value!r}")

        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.lstm_units = lstm_units
        self.sampling_rate = sampling_rate
        self.use_stft = use_stft
        # Number of non-redundant bins of a real FFT of length frame_len
        # (stft_layer uses fft_length == frame_len).
        self.freq_bins = frame_len // 2 + 1

    def stft_layer(self, x):
        """Compute magnitude and phase spectra with tf.signal.stft.

        Args:
            x: Float waveform tensor whose last axis is time (build_model
                feeds it shape (batch, samples)).

        Returns:
            (mag, phase), each of shape (..., frames, freq_bins).
        """
        stft = tf.signal.stft(
            x,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len,
            # Explicit (this is also TF's default) so that istft_layer's
            # inverse window is derived from the same analysis window.
            window_fn=tf.signal.hann_window,
        )
        return tf.abs(stft), tf.math.angle(stft)

    def istft_layer(self, mag, phase):
        """Resynthesize a waveform from magnitude and phase spectra.

        The synthesis window comes from tf.signal.inverse_stft_window_fn,
        matched to the Hann analysis window in stft_layer. Without it,
        inverse_stft would apply a second Hann window with no overlap-add
        normalization, so even an unmodified spectrum would come back
        distorted.

        Args:
            mag: Magnitude spectrum, shape (..., frames, freq_bins).
            phase: Phase spectrum in radians, same shape as mag.

        Returns:
            Time-domain signal tensor.
        """
        # Cast to float32 so tf.complex yields complex64, as the original
        # complex64 casts did.
        mag = tf.cast(mag, tf.float32)
        phase = tf.cast(phase, tf.float32)
        complex_spec = tf.complex(mag * tf.cos(phase), mag * tf.sin(phase))
        return tf.signal.inverse_stft(
            complex_spec,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len,
            window_fn=tf.signal.inverse_stft_window_fn(self.frame_shift),
        )

    def build_model(self, training=True):
        """
        Build the full DTLN model for training.

        NOTE(review): tf.signal.stft uses pad_end=False by default, so
        trailing samples that do not fill a whole frame are dropped. The
        output waveform can therefore be shorter than the input; confirm
        that the training loss accounts for this.

        Args:
            training: Kept for API compatibility. Currently ignored: the
                same graph is returned for either value.

        Returns:
            Keras Model mapping a waveform of shape (batch, samples) to the
            enhanced waveform.
        """
        # Input: raw waveform of arbitrary length.
        input_audio = Input(shape=(None,), name='input_audio')

        # STFT transformation -> magnitude/phase of shape
        # (batch, frames, freq_bins).
        mag, phase = Lambda(
            lambda x: self.stft_layer(x),
            name='stft'
        )(input_audio)

        # === First Processing Stage ===
        # The LSTM runs over the frame axis; the mask is per frequency bin.
        lstm_1 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_1'
        )(mag)
        mask_1 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_1'
        )(lstm_1)
        enhanced_mag_1 = Multiply(name='apply_mask_1')([mag, mask_1])

        # === Second Processing Stage ===
        lstm_2 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_2'
        )(enhanced_mag_1)
        mask_2 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_2'
        )(lstm_2)
        enhanced_mag = Multiply(name='apply_mask_2')([enhanced_mag_1, mask_2])

        # Inverse STFT using the unmodified noisy phase.
        enhanced_audio = Lambda(
            lambda x: self.istft_layer(x[0], x[1]),
            name='istft'
        )([enhanced_mag, phase])

        return Model(
            inputs=input_audio,
            outputs=enhanced_audio,
            name='DTLN_Ethos_U55'
        )

    def build_stateful_model(self, batch_size=1):
        """
        Build stateful models for frame-by-frame inference.

        This is more memory efficient for real-time processing. Each stage
        takes one magnitude frame of shape (batch_size, 1, freq_bins) plus
        LSTM hidden and cell states of shape (batch_size, lstm_units). It
        returns the masked frame and the updated states, which the caller
        feeds back in on the next frame. The LSTM/Dense layer names match
        those in build_model(), which presumably lets trained weights be
        transferred by name (verify against the weight-loading code).

        Args:
            batch_size: Static batch dimension baked into every input.

        Returns:
            Two models (stage1, stage2) for sequential processing. stage1's
            masked frame is the magnitude input of stage2.
        """
        model_1 = self._build_stage(1, 'magnitude_input', batch_size)
        model_2 = self._build_stage(2, 'enhanced_magnitude_input', batch_size)
        return model_1, model_2

    def _build_stage(self, stage, mag_input_name, batch_size):
        """Build one single-frame LSTM + mask stage with explicit state I/O."""
        mag_input = Input(
            batch_shape=(batch_size, 1, self.freq_bins),
            name=mag_input_name
        )
        state_h = Input(
            batch_shape=(batch_size, self.lstm_units),
            name=f'lstm_{stage}_state_h'
        )
        state_c = Input(
            batch_shape=(batch_size, self.lstm_units),
            name=f'lstm_{stage}_state_c'
        )

        # stateful=False: recurrent state is carried through the model's
        # inputs/outputs rather than stored inside the layer.
        lstm = LSTM(
            self.lstm_units,
            return_sequences=True,
            return_state=True,
            stateful=False,
            name=f'lstm_{stage}'
        )
        lstm_out, state_h_out, state_c_out = lstm(
            mag_input,
            initial_state=[state_h, state_c]
        )

        mask = Dense(
            self.freq_bins,
            activation='sigmoid',
            name=f'mask_{stage}'
        )(lstm_out)
        enhanced_mag = Multiply(name=f'apply_mask_{stage}')([mag_input, mask])

        return Model(
            inputs=[mag_input, state_h, state_c],
            outputs=[enhanced_mag, state_h_out, state_c_out],
            name=f'DTLN_Stage{stage}'
        )

    def get_model_summary(self):
        """Print model architecture and parameter count.

        Builds the training graph via build_model() and prints Keras'
        summary. It then prints size estimates at 4 bytes/param (FP32) and
        1 byte/param (INT8); quantization metadata and runtime overhead are
        not included.

        Returns:
            The built Keras Model.
        """
        model = self.build_model()
        model.summary()
        total_params = model.count_params()
        print(f"\nTotal parameters: {total_params:,}")
        print(f"Estimated model size (FP32): {total_params * 4 / 1024:.2f} KB")
        print(f"Estimated model size (INT8): {total_params / 1024:.2f} KB")
        return model
def estimate_int8_size_kb(lstm_units, frame_len=512):
    """Estimate the INT8 weight size (KB) of a DTLN_Ethos_U55 network.

    Counts the weights of the two stages the class builds. Each stage is a
    Keras LSTM (kernel + recurrent kernel + bias =
    4 * units * (freq_bins + units + 1)) followed by a Dense mask layer
    ((units + 1) * freq_bins), where freq_bins = frame_len // 2 + 1. The
    estimate assumes one byte per weight and excludes quantization
    parameters and runtime/Vela overhead.

    Args:
        lstm_units: LSTM units per stage.
        frame_len: STFT frame/FFT length in samples.

    Returns:
        Estimated size in KB (float).
    """
    freq_bins = frame_len // 2 + 1
    lstm_params = 4 * lstm_units * (freq_bins + lstm_units + 1)
    dense_params = (lstm_units + 1) * freq_bins
    return 2 * (lstm_params + dense_params) / 1024


def create_lightweight_model(
    target_size_kb=100,
    frame_len=512,
    frame_shift=128,
    sampling_rate=16000,
):
    """
    Factory function to create a lightweight model that fits in target size.

    Picks the largest LSTM width in [64, 256] whose estimated INT8 weight
    size (see estimate_int8_size_kb) does not exceed target_size_kb. If even
    64 units exceed the target, it falls back to 64 units and prints a
    warning.

    Args:
        target_size_kb: Target model size in KB.
        frame_len: STFT frame length in samples (affects size via freq_bins).
        frame_shift: STFT hop size in samples.
        sampling_rate: Audio sampling rate in Hz.

    Returns:
        DTLN_Ethos_U55 instance configured for target size
    """
    min_units, max_units = 64, 256
    # Size grows monotonically with lstm_units, so the largest fitting
    # count is the best choice for the budget.
    fitting = [
        units
        for units in range(min_units, max_units + 1)
        if estimate_int8_size_kb(units, frame_len) <= target_size_kb
    ]
    estimated_units = max(fitting) if fitting else min_units
    estimated_kb = estimate_int8_size_kb(estimated_units, frame_len)

    print(f"Creating model with {estimated_units} LSTM units")
    print(f"Target size: {target_size_kb} KB")
    print(f"Estimated size (INT8 weights): {estimated_kb:.2f} KB")
    if not fitting:
        print(
            f"Warning: even {min_units} LSTM units "
            f"({estimated_kb:.2f} KB) exceed the target size"
        )

    return DTLN_Ethos_U55(
        frame_len=frame_len,
        frame_shift=frame_shift,
        lstm_units=estimated_units,
        sampling_rate=sampling_rate
    )
if __name__ == "__main__":
    # Demo: build the default network, report its size, then build the two
    # single-frame stage models intended for streaming inference.
    print("Creating DTLN model for Alif E7 Ethos-U55...")

    dtln = DTLN_Ethos_U55(frame_len=512, frame_shift=128, lstm_units=128)
    model = dtln.get_model_summary()

    print("\n" + "=" * 50)
    print("Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()
    for heading, stage_model in (("\nStage 1:", stage1), ("\nStage 2:", stage2)):
        print(heading)
        stage_model.summary()

    print("\n✓ Model creation successful!")
    print("Next steps:")
    for step in (
        "1. Train the model with quantization-aware training",
        "2. Convert to TensorFlow Lite INT8 format",
        "3. Use Vela compiler to optimize for Ethos-U55",
    ):
        print(step)