Spaces:
Sleeping
Sleeping
Upload dtln_ethos_u55.py with huggingface_hub
Browse files- dtln_ethos_u55.py +318 -0
dtln_ethos_u55.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DTLN Model Optimized for Alif E7 Ethos-U55 NPU
|
| 3 |
+
Lightweight voice denoising with 8-bit quantization support
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from tensorflow import keras
|
| 8 |
+
from tensorflow.keras.layers import (
|
| 9 |
+
LSTM, Dense, Input, Multiply, Lambda, Concatenate
|
| 10 |
+
)
|
| 11 |
+
from tensorflow.keras.models import Model
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DTLN_Ethos_U55:
|
| 16 |
+
"""
|
| 17 |
+
Dual-signal Transformation LSTM Network optimized for Ethos-U55
|
| 18 |
+
|
| 19 |
+
Key optimizations:
|
| 20 |
+
- Reduced parameters for <1MB model size
|
| 21 |
+
- 8-bit quantization aware architecture
|
| 22 |
+
- Stateful inference for real-time processing
|
| 23 |
+
- Memory-efficient for DTCM constraints
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
frame_len=512,
|
| 29 |
+
frame_shift=128,
|
| 30 |
+
lstm_units=128,
|
| 31 |
+
sampling_rate=16000,
|
| 32 |
+
use_stft=True
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
frame_len: Length of audio frame (default 512 = 32ms @ 16kHz)
|
| 37 |
+
frame_shift: Frame hop size (default 128 = 8ms @ 16kHz)
|
| 38 |
+
lstm_units: Number of LSTM units (reduced to 128 for NPU)
|
| 39 |
+
sampling_rate: Audio sampling rate in Hz
|
| 40 |
+
use_stft: Use STFT domain processing
|
| 41 |
+
"""
|
| 42 |
+
self.frame_len = frame_len
|
| 43 |
+
self.frame_shift = frame_shift
|
| 44 |
+
self.lstm_units = lstm_units
|
| 45 |
+
self.sampling_rate = sampling_rate
|
| 46 |
+
self.use_stft = use_stft
|
| 47 |
+
|
| 48 |
+
# Frequency bins for STFT
|
| 49 |
+
self.freq_bins = frame_len // 2 + 1
|
| 50 |
+
|
| 51 |
+
def stft_layer(self, x):
|
| 52 |
+
"""Custom STFT layer using tf.signal"""
|
| 53 |
+
stft = tf.signal.stft(
|
| 54 |
+
x,
|
| 55 |
+
frame_length=self.frame_len,
|
| 56 |
+
frame_step=self.frame_shift,
|
| 57 |
+
fft_length=self.frame_len
|
| 58 |
+
)
|
| 59 |
+
mag = tf.abs(stft)
|
| 60 |
+
phase = tf.math.angle(stft)
|
| 61 |
+
return mag, phase
|
| 62 |
+
|
| 63 |
+
def istft_layer(self, mag, phase):
|
| 64 |
+
"""Custom inverse STFT layer"""
|
| 65 |
+
complex_spec = tf.cast(mag, tf.complex64) * tf.exp(
|
| 66 |
+
1j * tf.cast(phase, tf.complex64)
|
| 67 |
+
)
|
| 68 |
+
signal = tf.signal.inverse_stft(
|
| 69 |
+
complex_spec,
|
| 70 |
+
frame_length=self.frame_len,
|
| 71 |
+
frame_step=self.frame_shift,
|
| 72 |
+
fft_length=self.frame_len
|
| 73 |
+
)
|
| 74 |
+
return signal
|
| 75 |
+
|
| 76 |
+
def build_model(self, training=True):
|
| 77 |
+
"""
|
| 78 |
+
Build the full DTLN model for training
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
training: If True, builds training model. If False, builds inference model.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
Keras Model
|
| 85 |
+
"""
|
| 86 |
+
# Input: raw waveform
|
| 87 |
+
input_audio = Input(shape=(None,), name='input_audio')
|
| 88 |
+
|
| 89 |
+
# Reshape for processing
|
| 90 |
+
audio_reshaped = Lambda(
|
| 91 |
+
lambda x: tf.expand_dims(x, -1)
|
| 92 |
+
)(input_audio)
|
| 93 |
+
|
| 94 |
+
# STFT transformation
|
| 95 |
+
mag, phase = Lambda(
|
| 96 |
+
lambda x: self.stft_layer(x),
|
| 97 |
+
name='stft'
|
| 98 |
+
)(input_audio)
|
| 99 |
+
|
| 100 |
+
# === First Processing Stage ===
|
| 101 |
+
# Process magnitude spectrum
|
| 102 |
+
lstm_1 = LSTM(
|
| 103 |
+
self.lstm_units,
|
| 104 |
+
return_sequences=True,
|
| 105 |
+
name='lstm_1'
|
| 106 |
+
)(mag)
|
| 107 |
+
|
| 108 |
+
# Estimate magnitude mask
|
| 109 |
+
mask_1 = Dense(
|
| 110 |
+
self.freq_bins,
|
| 111 |
+
activation='sigmoid',
|
| 112 |
+
name='mask_1'
|
| 113 |
+
)(lstm_1)
|
| 114 |
+
|
| 115 |
+
# Apply mask
|
| 116 |
+
enhanced_mag_1 = Multiply(name='apply_mask_1')([mag, mask_1])
|
| 117 |
+
|
| 118 |
+
# === Second Processing Stage ===
|
| 119 |
+
lstm_2 = LSTM(
|
| 120 |
+
self.lstm_units,
|
| 121 |
+
return_sequences=True,
|
| 122 |
+
name='lstm_2'
|
| 123 |
+
)(enhanced_mag_1)
|
| 124 |
+
|
| 125 |
+
# Second mask estimation
|
| 126 |
+
mask_2 = Dense(
|
| 127 |
+
self.freq_bins,
|
| 128 |
+
activation='sigmoid',
|
| 129 |
+
name='mask_2'
|
| 130 |
+
)(lstm_2)
|
| 131 |
+
|
| 132 |
+
# Apply second mask
|
| 133 |
+
enhanced_mag = Multiply(name='apply_mask_2')([enhanced_mag_1, mask_2])
|
| 134 |
+
|
| 135 |
+
# Inverse STFT
|
| 136 |
+
enhanced_audio = Lambda(
|
| 137 |
+
lambda x: self.istft_layer(x[0], x[1]),
|
| 138 |
+
name='istft'
|
| 139 |
+
)([enhanced_mag, phase])
|
| 140 |
+
|
| 141 |
+
# Build model
|
| 142 |
+
model = Model(
|
| 143 |
+
inputs=input_audio,
|
| 144 |
+
outputs=enhanced_audio,
|
| 145 |
+
name='DTLN_Ethos_U55'
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return model
|
| 149 |
+
|
| 150 |
+
def build_stateful_model(self, batch_size=1):
|
| 151 |
+
"""
|
| 152 |
+
Build stateful model for frame-by-frame inference
|
| 153 |
+
This is more memory efficient for real-time processing
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Two models (stage1, stage2) for sequential processing
|
| 157 |
+
"""
|
| 158 |
+
# === Stage 1 Model ===
|
| 159 |
+
# Inputs
|
| 160 |
+
mag_input = Input(
|
| 161 |
+
batch_shape=(batch_size, 1, self.freq_bins),
|
| 162 |
+
name='magnitude_input'
|
| 163 |
+
)
|
| 164 |
+
state_h_1 = Input(
|
| 165 |
+
batch_shape=(batch_size, self.lstm_units),
|
| 166 |
+
name='lstm_1_state_h'
|
| 167 |
+
)
|
| 168 |
+
state_c_1 = Input(
|
| 169 |
+
batch_shape=(batch_size, self.lstm_units),
|
| 170 |
+
name='lstm_1_state_c'
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# LSTM with state
|
| 174 |
+
lstm_1 = LSTM(
|
| 175 |
+
self.lstm_units,
|
| 176 |
+
return_sequences=True,
|
| 177 |
+
return_state=True,
|
| 178 |
+
stateful=False,
|
| 179 |
+
name='lstm_1'
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
lstm_out_1, state_h_1_out, state_c_1_out = lstm_1(
|
| 183 |
+
mag_input,
|
| 184 |
+
initial_state=[state_h_1, state_c_1]
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Mask estimation
|
| 188 |
+
mask_1 = Dense(
|
| 189 |
+
self.freq_bins,
|
| 190 |
+
activation='sigmoid',
|
| 191 |
+
name='mask_1'
|
| 192 |
+
)(lstm_out_1)
|
| 193 |
+
|
| 194 |
+
# Apply mask
|
| 195 |
+
enhanced_mag_1 = Multiply()([mag_input, mask_1])
|
| 196 |
+
|
| 197 |
+
model_1 = Model(
|
| 198 |
+
inputs=[mag_input, state_h_1, state_c_1],
|
| 199 |
+
outputs=[enhanced_mag_1, state_h_1_out, state_c_1_out],
|
| 200 |
+
name='DTLN_Stage1'
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# === Stage 2 Model ===
|
| 204 |
+
mag_input_2 = Input(
|
| 205 |
+
batch_shape=(batch_size, 1, self.freq_bins),
|
| 206 |
+
name='enhanced_magnitude_input'
|
| 207 |
+
)
|
| 208 |
+
state_h_2 = Input(
|
| 209 |
+
batch_shape=(batch_size, self.lstm_units),
|
| 210 |
+
name='lstm_2_state_h'
|
| 211 |
+
)
|
| 212 |
+
state_c_2 = Input(
|
| 213 |
+
batch_shape=(batch_size, self.lstm_units),
|
| 214 |
+
name='lstm_2_state_c'
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# LSTM with state
|
| 218 |
+
lstm_2 = LSTM(
|
| 219 |
+
self.lstm_units,
|
| 220 |
+
return_sequences=True,
|
| 221 |
+
return_state=True,
|
| 222 |
+
stateful=False,
|
| 223 |
+
name='lstm_2'
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
lstm_out_2, state_h_2_out, state_c_2_out = lstm_2(
|
| 227 |
+
mag_input_2,
|
| 228 |
+
initial_state=[state_h_2, state_c_2]
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Final mask
|
| 232 |
+
mask_2 = Dense(
|
| 233 |
+
self.freq_bins,
|
| 234 |
+
activation='sigmoid',
|
| 235 |
+
name='mask_2'
|
| 236 |
+
)(lstm_out_2)
|
| 237 |
+
|
| 238 |
+
# Apply mask
|
| 239 |
+
enhanced_mag = Multiply()([mag_input_2, mask_2])
|
| 240 |
+
|
| 241 |
+
model_2 = Model(
|
| 242 |
+
inputs=[mag_input_2, state_h_2, state_c_2],
|
| 243 |
+
outputs=[enhanced_mag, state_h_2_out, state_c_2_out],
|
| 244 |
+
name='DTLN_Stage2'
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return model_1, model_2
|
| 248 |
+
|
| 249 |
+
def get_model_summary(self):
|
| 250 |
+
"""Print model architecture and parameter count"""
|
| 251 |
+
model = self.build_model()
|
| 252 |
+
model.summary()
|
| 253 |
+
|
| 254 |
+
total_params = model.count_params()
|
| 255 |
+
print(f"\nTotal parameters: {total_params:,}")
|
| 256 |
+
print(f"Estimated model size (FP32): {total_params * 4 / 1024:.2f} KB")
|
| 257 |
+
print(f"Estimated model size (INT8): {total_params / 1024:.2f} KB")
|
| 258 |
+
|
| 259 |
+
return model
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def create_lightweight_model(target_size_kb=100):
|
| 263 |
+
"""
|
| 264 |
+
Factory function to create a lightweight model that fits in target size
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
target_size_kb: Target model size in KB
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
DTLN_Ethos_U55 instance configured for target size
|
| 271 |
+
"""
|
| 272 |
+
# Estimate LSTM units for target size
|
| 273 |
+
# Rough estimate: each LSTM unit adds ~2KB for INT8
|
| 274 |
+
estimated_units = int((target_size_kb * 0.8) / 2)
|
| 275 |
+
estimated_units = min(max(estimated_units, 64), 256) # Clamp to 64-256
|
| 276 |
+
|
| 277 |
+
print(f"Creating model with {estimated_units} LSTM units")
|
| 278 |
+
print(f"Target size: {target_size_kb} KB")
|
| 279 |
+
|
| 280 |
+
model = DTLN_Ethos_U55(
|
| 281 |
+
frame_len=512,
|
| 282 |
+
frame_shift=128,
|
| 283 |
+
lstm_units=estimated_units,
|
| 284 |
+
sampling_rate=16000
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
return model
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
if __name__ == "__main__":
|
| 291 |
+
# Example usage
|
| 292 |
+
print("Creating DTLN model for Alif E7 Ethos-U55...")
|
| 293 |
+
|
| 294 |
+
# Create model
|
| 295 |
+
dtln = DTLN_Ethos_U55(
|
| 296 |
+
frame_len=512,
|
| 297 |
+
frame_shift=128,
|
| 298 |
+
lstm_units=128
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Get model summary
|
| 302 |
+
model = dtln.get_model_summary()
|
| 303 |
+
|
| 304 |
+
# Build stateful models for inference
|
| 305 |
+
print("\n" + "="*50)
|
| 306 |
+
print("Building stateful models for real-time inference...")
|
| 307 |
+
stage1, stage2 = dtln.build_stateful_model()
|
| 308 |
+
|
| 309 |
+
print("\nStage 1:")
|
| 310 |
+
stage1.summary()
|
| 311 |
+
print("\nStage 2:")
|
| 312 |
+
stage2.summary()
|
| 313 |
+
|
| 314 |
+
print("\n✓ Model creation successful!")
|
| 315 |
+
print("Next steps:")
|
| 316 |
+
print("1. Train the model with quantization-aware training")
|
| 317 |
+
print("2. Convert to TensorFlow Lite INT8 format")
|
| 318 |
+
print("3. Use Vela compiler to optimize for Ethos-U55")
|