grgsaliba committed on
Commit
3a19dc4
·
verified ·
1 Parent(s): 5f473ab

Upload dtln_ethos_u55.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dtln_ethos_u55.py +318 -0
dtln_ethos_u55.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DTLN Model Optimized for Alif E7 Ethos-U55 NPU
3
+ Lightweight voice denoising with 8-bit quantization support
4
+ """
5
+
6
+ import tensorflow as tf
7
+ from tensorflow import keras
8
+ from tensorflow.keras.layers import (
9
+ LSTM, Dense, Input, Multiply, Lambda, Concatenate
10
+ )
11
+ from tensorflow.keras.models import Model
12
+ import numpy as np
13
+
14
+
15
class DTLN_Ethos_U55:
    """
    Dual-signal Transformation LSTM Network optimized for Ethos-U55

    Key optimizations:
    - Reduced parameters for <1MB model size
    - 8-bit quantization aware architecture
    - Stateful inference for real-time processing
    - Memory-efficient for DTCM constraints
    """

    def __init__(
        self,
        frame_len=512,
        frame_shift=128,
        lstm_units=128,
        sampling_rate=16000,
        use_stft=True
    ):
        """
        Args:
            frame_len: Length of audio frame (default 512 = 32ms @ 16kHz)
            frame_shift: Frame hop size (default 128 = 8ms @ 16kHz)
            lstm_units: Number of LSTM units (reduced to 128 for NPU)
            sampling_rate: Audio sampling rate in Hz
            use_stft: Use STFT domain processing.
                NOTE(review): currently stored but not read by any method
                in this class — confirm intended use before removing.
        """
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.lstm_units = lstm_units
        self.sampling_rate = sampling_rate
        self.use_stft = use_stft

        # Number of bins in a one-sided spectrum of an frame_len-point FFT.
        self.freq_bins = frame_len // 2 + 1

    def stft_layer(self, x):
        """Custom STFT layer using tf.signal.

        Args:
            x: Waveform tensor; frames are taken along the last axis.

        Returns:
            Tuple (mag, phase): magnitude spectrum and phase (radians),
            each with self.freq_bins values per frame.
        """
        stft = tf.signal.stft(
            x,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len
        )
        mag = tf.abs(stft)
        phase = tf.math.angle(stft)
        return mag, phase

    def istft_layer(self, mag, phase):
        """Custom inverse STFT layer.

        Args:
            mag: Real-valued magnitude spectrogram.
            phase: Phase spectrogram in radians.

        Returns:
            Reconstructed time-domain signal tensor.
        """
        # Recombine polar components into a complex spectrum: mag * e^(j*phase).
        complex_spec = tf.cast(mag, tf.complex64) * tf.exp(
            1j * tf.cast(phase, tf.complex64)
        )
        signal = tf.signal.inverse_stft(
            complex_spec,
            frame_length=self.frame_len,
            frame_step=self.frame_shift,
            fft_length=self.frame_len
        )
        return signal

    def build_model(self, training=True):
        """
        Build the full DTLN model for training

        Args:
            training: Kept for interface compatibility. The same graph is
                currently built regardless of its value.

        Returns:
            Keras Model mapping a raw waveform to an enhanced waveform.
        """
        # Input: raw waveform of variable length.
        input_audio = Input(shape=(None,), name='input_audio')

        # NOTE: a previous revision built an expand_dims Lambda here whose
        # output was never consumed; that dead graph node has been removed.

        # STFT transformation: per-frame magnitude and phase.
        mag, phase = Lambda(
            lambda x: self.stft_layer(x),
            name='stft'
        )(input_audio)

        # === First Processing Stage ===
        # Process magnitude spectrum.
        lstm_1 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_1'
        )(mag)

        # Estimate magnitude mask in [0, 1] per frequency bin.
        mask_1 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_1'
        )(lstm_1)

        # Apply mask.
        enhanced_mag_1 = Multiply(name='apply_mask_1')([mag, mask_1])

        # === Second Processing Stage ===
        lstm_2 = LSTM(
            self.lstm_units,
            return_sequences=True,
            name='lstm_2'
        )(enhanced_mag_1)

        # Second mask estimation.
        mask_2 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_2'
        )(lstm_2)

        # Apply second mask.
        enhanced_mag = Multiply(name='apply_mask_2')([enhanced_mag_1, mask_2])

        # Inverse STFT, reusing the noisy phase (standard for mask-based
        # enhancement: only the magnitude is modified).
        enhanced_audio = Lambda(
            lambda x: self.istft_layer(x[0], x[1]),
            name='istft'
        )([enhanced_mag, phase])

        # Build model.
        model = Model(
            inputs=input_audio,
            outputs=enhanced_audio,
            name='DTLN_Ethos_U55'
        )

        return model

    def build_stateful_model(self, batch_size=1):
        """
        Build stateful model for frame-by-frame inference
        This is more memory efficient for real-time processing

        LSTM states are exposed as explicit inputs/outputs (rather than
        using `stateful=True`) so the caller — e.g. a TFLite runtime —
        carries the recurrent state between single-frame invocations.

        Args:
            batch_size: Fixed batch size baked into the input shapes.

        Returns:
            Two models (stage1, stage2) for sequential processing
        """
        # === Stage 1 Model ===
        # Inputs: one spectral frame plus the LSTM hidden/cell state.
        mag_input = Input(
            batch_shape=(batch_size, 1, self.freq_bins),
            name='magnitude_input'
        )
        state_h_1 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_1_state_h'
        )
        state_c_1 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_1_state_c'
        )

        # LSTM with state threaded through explicitly.
        lstm_1 = LSTM(
            self.lstm_units,
            return_sequences=True,
            return_state=True,
            stateful=False,
            name='lstm_1'
        )

        lstm_out_1, state_h_1_out, state_c_1_out = lstm_1(
            mag_input,
            initial_state=[state_h_1, state_c_1]
        )

        # Mask estimation.
        mask_1 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_1'
        )(lstm_out_1)

        # Apply mask.
        enhanced_mag_1 = Multiply()([mag_input, mask_1])

        model_1 = Model(
            inputs=[mag_input, state_h_1, state_c_1],
            outputs=[enhanced_mag_1, state_h_1_out, state_c_1_out],
            name='DTLN_Stage1'
        )

        # === Stage 2 Model ===
        mag_input_2 = Input(
            batch_shape=(batch_size, 1, self.freq_bins),
            name='enhanced_magnitude_input'
        )
        state_h_2 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_2_state_h'
        )
        state_c_2 = Input(
            batch_shape=(batch_size, self.lstm_units),
            name='lstm_2_state_c'
        )

        # LSTM with state threaded through explicitly.
        lstm_2 = LSTM(
            self.lstm_units,
            return_sequences=True,
            return_state=True,
            stateful=False,
            name='lstm_2'
        )

        lstm_out_2, state_h_2_out, state_c_2_out = lstm_2(
            mag_input_2,
            initial_state=[state_h_2, state_c_2]
        )

        # Final mask.
        mask_2 = Dense(
            self.freq_bins,
            activation='sigmoid',
            name='mask_2'
        )(lstm_out_2)

        # Apply mask.
        enhanced_mag = Multiply()([mag_input_2, mask_2])

        model_2 = Model(
            inputs=[mag_input_2, state_h_2, state_c_2],
            outputs=[enhanced_mag, state_h_2_out, state_c_2_out],
            name='DTLN_Stage2'
        )

        return model_1, model_2

    def get_model_summary(self):
        """Print model architecture and parameter count.

        Returns:
            The freshly built (non-stateful) Keras model.
        """
        model = self.build_model()
        model.summary()

        total_params = model.count_params()
        # FP32 weights are 4 bytes each; INT8 roughly 1 byte each.
        print(f"\nTotal parameters: {total_params:,}")
        print(f"Estimated model size (FP32): {total_params * 4 / 1024:.2f} KB")
        print(f"Estimated model size (INT8): {total_params / 1024:.2f} KB")

        return model
260
+
261
+
262
def create_lightweight_model(target_size_kb=100):
    """
    Factory function to create a lightweight model that fits in target size

    Args:
        target_size_kb: Target model size in KB

    Returns:
        DTLN_Ethos_U55 instance configured for target size
    """
    # Rough sizing rule: ~2 KB per LSTM unit at INT8, with 20% of the
    # budget reserved for non-LSTM weights; clamp the result to 64-256.
    units = int(target_size_kb * 0.8 / 2)
    if units < 64:
        units = 64
    elif units > 256:
        units = 256

    print(f"Creating model with {units} LSTM units")
    print(f"Target size: {target_size_kb} KB")

    return DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=units,
        sampling_rate=16000
    )
288
+
289
+
290
if __name__ == "__main__":
    # Example usage: build the full training model, report its size, then
    # build the two stateful single-frame models used for real-time inference.
    print("Creating DTLN model for Alif E7 Ethos-U55...")

    # Create model with the default 128 LSTM units per stage.
    dtln = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=128
    )

    # Get model summary (prints architecture, parameter count, size estimates).
    model = dtln.get_model_summary()

    # Build stateful models for inference (stage 1 and stage 2 run
    # sequentially per frame, with LSTM state carried by the caller).
    print("\n" + "="*50)
    print("Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()

    print("\nStage 1:")
    stage1.summary()
    print("\nStage 2:")
    stage2.summary()

    print("\n✓ Model creation successful!")
    print("Next steps:")
    print("1. Train the model with quantization-aware training")
    print("2. Convert to TensorFlow Lite INT8 format")
    print("3. Use Vela compiler to optimize for Ethos-U55")