grgsaliba committed
Commit e5d5706 · verified · 1 Parent(s): 3a19dc4

Upload train_dtln.py with huggingface_hub

Files changed (1): train_dtln.py (+445, -0)
train_dtln.py ADDED
@@ -0,0 +1,445 @@
"""
Training script for DTLN model with Quantization-Aware Training (QAT)
Optimized for deployment on Alif E7 Ethos-U55 NPU
"""
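
# Example usage (paths are hypothetical placeholders; flags match the
# argparse definitions at the bottom of this file):
#   python train_dtln.py --clean-dir data/clean_speech --noise-dir data/noise \
#       --output-dir ./models --epochs 50 --batch-size 16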

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np
import soundfile as sf
import librosa
from pathlib import Path
import argparse
from dtln_ethos_u55 import DTLN_Ethos_U55
import os


class AudioDataGenerator(tf.keras.utils.Sequence):
    """
    Data generator for training audio denoising models
    Loads clean and noisy audio pairs
    """

    def __init__(
        self,
        clean_audio_dir,
        noise_audio_dir,
        batch_size=16,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    ):
        """
        Args:
            clean_audio_dir: Directory containing clean speech files
            noise_audio_dir: Directory containing noise files
            batch_size: Batch size for training
            frame_len: Frame length in samples
            frame_shift: Frame shift in samples
            sampling_rate: Target sampling rate
            snr_range: Range of SNR for mixing (min, max) in dB
            shuffle: Whether to shuffle data each epoch
        """
        self.clean_files = list(Path(clean_audio_dir).glob('**/*.wav'))
        self.noise_files = list(Path(noise_audio_dir).glob('**/*.wav'))

        self.batch_size = batch_size
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.sampling_rate = sampling_rate
        self.snr_range = snr_range
        self.shuffle = shuffle

        # Segment length for training (1 second)
        self.segment_len = sampling_rate

        self.on_epoch_end()

    def __len__(self):
        """Return number of batches per epoch"""
        return len(self.clean_files) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data"""
        # Select files for this batch
        batch_indices = self.indices[
            index * self.batch_size:(index + 1) * self.batch_size
        ]

        batch_clean = []
        batch_noisy = []

        for idx in batch_indices:
            clean_audio = self._load_audio(self.clean_files[idx])
            noise_audio = self._load_random_noise()

            # Mix clean and noise at random SNR
            noisy_audio = self._mix_audio(clean_audio, noise_audio)

            batch_clean.append(clean_audio)
            batch_noisy.append(noisy_audio)

        return np.array(batch_noisy), np.array(batch_clean)

    def on_epoch_end(self):
        """Update indices after each epoch"""
        self.indices = np.arange(len(self.clean_files))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def _load_audio(self, file_path):
        """Load and preprocess audio file"""
        audio, sr = sf.read(file_path)

        # Convert to mono if stereo (must happen before resampling:
        # librosa.resample operates on the last axis, and sf.read returns
        # (samples, channels) for stereo files)
        if len(audio.shape) > 1:
            audio = np.mean(audio, axis=1)

        # Resample if needed
        if sr != self.sampling_rate:
            audio = librosa.resample(
                audio,
                orig_sr=sr,
                target_sr=self.sampling_rate
            )

        # Trim or pad to segment length
        if len(audio) > self.segment_len:
            start = np.random.randint(0, len(audio) - self.segment_len)
            audio = audio[start:start + self.segment_len]
        else:
            audio = np.pad(audio, (0, self.segment_len - len(audio)))

        # Normalize
        audio = audio / (np.max(np.abs(audio)) + 1e-8)

        return audio.astype(np.float32)

    def _load_random_noise(self):
        """Load random noise file"""
        noise_file = np.random.choice(self.noise_files)
        return self._load_audio(noise_file)

    def _mix_audio(self, clean, noise):
        """Mix clean audio with noise at random SNR"""
        snr = np.random.uniform(*self.snr_range)

        # Compute signal and noise power
        clean_power = np.mean(clean ** 2)
        noise_power = np.mean(noise ** 2)

        # Compute noise scaling factor for the target SNR
        snr_linear = 10 ** (snr / 10)
        noise_scale = np.sqrt(clean_power / (snr_linear * noise_power + 1e-8))
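        # Worked example (illustrative): at snr = 10 dB, snr_linear = 10, so
        # the scaled noise carries clean_power / 10 of power, i.e. it sits
        # 10 dB below the speech after mixing.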

        # Mix
        noisy = clean + noise_scale * noise

        # Normalize to prevent clipping
        noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.95

        return noisy.astype(np.float32)
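

# Quick sanity check for the generator (hypothetical paths):
#   gen = AudioDataGenerator('data/clean', 'data/noise', batch_size=4)
#   noisy, clean = gen[0]  # both arrays have shape (4, 16000), dtype float32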


def apply_quantization_aware_training(model):
    """
    Apply quantization-aware training for 8-bit deployment

    Args:
        model: Keras model to quantize

    Returns:
        Quantization-aware model
    """
    # Quantize the entire model
    quantize_model = tfmot.quantization.keras.quantize_model

    # Use default quantization config
    q_aware_model = quantize_model(model)

    return q_aware_model
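

# A possible alternative (an assumption, not used by the original flow):
# full-model quantize_model() can reject LSTM layers in some TFMOT versions.
# The documented workaround is to annotate only well-supported layer types
# and apply quantization selectively.
def apply_selective_qat(model):
    """Sketch: quantize only Dense/Conv1D layers, leaving the rest in float."""
    def annotate(layer):
        if isinstance(layer, (tf.keras.layers.Dense, tf.keras.layers.Conv1D)):
            return tfmot.quantization.keras.quantize_annotate_layer(layer)
        return layer

    annotated = tf.keras.models.clone_model(model, clone_function=annotate)
    return tfmot.quantization.keras.quantize_apply(annotated)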


def create_loss_function():
    """
    Create custom loss function combining time and frequency domain losses
    """
    def combined_loss(y_true, y_pred):
        # Time domain MSE
        time_loss = tf.reduce_mean(tf.square(y_true - y_pred))

        # Frequency domain loss (STFT-based)
        stft_true = tf.signal.stft(
            y_true,
            frame_length=512,
            frame_step=128
        )
        stft_pred = tf.signal.stft(
            y_pred,
            frame_length=512,
            frame_step=128
        )

        mag_true = tf.abs(stft_true)
        mag_pred = tf.abs(stft_pred)

        freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))

        # Combined loss (weighted)
        return 0.7 * time_loss + 0.3 * freq_loss

    return combined_loss
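

# Sanity check (illustrative): identical signals give zero combined loss.
#   loss_fn = create_loss_function()
#   x = tf.random.normal([2, 16000])
#   print(float(loss_fn(x, x)))  # 0.0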


def train_model(
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=50,
    batch_size=16,
    lstm_units=128,
    learning_rate=0.001,
    use_qat=True,
    pretrained_weights_path=None
):
    """
    Main training function

    Args:
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Directory to save models
        epochs: Number of training epochs
        batch_size: Training batch size
        lstm_units: Number of LSTM units
        learning_rate: Learning rate for Adam optimizer
        use_qat: Whether to use quantization-aware training
        pretrained_weights_path: Optional weights to warm-start from
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print("Training DTLN for Alif E7 Ethos-U55")
    print("=" * 60)

    # Create model
    print("\n1. Building model...")
    dtln = DTLN_Ethos_U55(
        frame_len=512,
        frame_shift=128,
        lstm_units=lstm_units,
        sampling_rate=16000
    )

    model = dtln.build_model()
    model.summary()

    # Warm-start from pretrained weights before any QAT wrapping
    if pretrained_weights_path:
        try:
            model.load_weights(pretrained_weights_path, by_name=True)
            print("  ✓ Pretrained weights loaded")
        except Exception as e:
            print(f"  ⚠ Could not load pretrained weights ({e}), "
                  "training from scratch")

    # Apply QAT if requested
    if use_qat:
        print("\n2. Applying Quantization-Aware Training...")
        model = apply_quantization_aware_training(model)
        print("  ✓ QAT applied")

    # Compile model
    print("\n3. Compiling model...")
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=create_loss_function(),
        metrics=['mae']
    )
    print("  ✓ Model compiled")

    # Create data generator
    print("\n4. Creating data generator...")
    train_generator = AudioDataGenerator(
        clean_audio_dir=clean_dir,
        noise_audio_dir=noise_dir,
        batch_size=batch_size,
        frame_len=512,
        frame_shift=128,
        sampling_rate=16000,
        snr_range=(0, 20),
        shuffle=True
    )
    print(f"  ✓ Training samples per epoch: {len(train_generator) * batch_size}")

    # Callbacks
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, 'best_model.h5'),
            monitor='loss',
            save_best_only=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='loss',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(output_dir, 'logs'),
            histogram_freq=1
        )
    ]

    # Train
    print("\n5. Starting training...")
    print("=" * 60)
    history = model.fit(
        train_generator,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    # Save final model
    final_model_path = os.path.join(
        output_dir,
        'dtln_ethos_u55_final.h5'
    )
    model.save(final_model_path)
    print(f"\n✓ Training complete! Model saved to {final_model_path}")

    return model, history


def train_with_pretrained_dtln(
    pretrained_weights_path,
    clean_dir,
    noise_dir,
    output_dir='./models',
    epochs=20,
    batch_size=16
):
    """
    Fine-tune from pre-trained DTLN weights

    Args:
        pretrained_weights_path: Path to pretrained DTLN weights
        clean_dir: Directory with clean speech
        noise_dir: Directory with noise files
        output_dir: Output directory
        epochs: Number of fine-tuning epochs
        batch_size: Training batch size
    """
    print("Fine-tuning from pretrained DTLN weights...")

    # Delegate to train_model, which loads the weights into the freshly
    # built model before QAT wrapping. (Loading into a throwaway model
    # here and then rebuilding inside train_model would silently discard
    # the pretrained weights.)
    return train_model(
        clean_dir=clean_dir,
        noise_dir=noise_dir,
        output_dir=output_dir,
        epochs=epochs,
        batch_size=batch_size,
        use_qat=True,
        pretrained_weights_path=pretrained_weights_path
    )
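

# Optional refinement (an assumption, not part of the original script): when
# fine-tuning, it can help to freeze all but the last few layers for the
# first epochs, then unfreeze and recompile with a lower learning rate, e.g.
#   for layer in model.layers[:-4]:
#       layer.trainable = False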


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Train DTLN model for Alif E7 Ethos-U55'
    )
    parser.add_argument(
        '--clean-dir',
        type=str,
        required=True,
        help='Directory containing clean speech files'
    )
    parser.add_argument(
        '--noise-dir',
        type=str,
        required=True,
        help='Directory containing noise files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./models',
        help='Output directory for models'
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=50,
        help='Number of training epochs'
    )
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        help='Training batch size'
    )
    parser.add_argument(
        '--lstm-units',
        type=int,
        default=128,
        help='Number of LSTM units'
    )
    parser.add_argument(
        '--learning-rate',
        type=float,
        default=0.001,
        help='Learning rate'
    )
    parser.add_argument(
        '--no-qat',
        action='store_true',
        help='Disable quantization-aware training'
    )
    parser.add_argument(
        '--pretrained',
        type=str,
        default=None,
        help='Path to pretrained weights for fine-tuning'
    )

    args = parser.parse_args()

    # Train model
    if args.pretrained:
        model, history = train_with_pretrained_dtln(
            pretrained_weights_path=args.pretrained,
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size
        )
    else:
        model, history = train_model(
            clean_dir=args.clean_dir,
            noise_dir=args.noise_dir,
            output_dir=args.output_dir,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lstm_units=args.lstm_units,
            learning_rate=args.learning_rate,
            use_qat=not args.no_qat
        )

    print("\n" + "=" * 60)
    print("Training Summary:")
    print(f"  Final loss: {history.history['loss'][-1]:.4f}")
    print(f"  Best loss: {min(history.history['loss']):.4f}")
    print(f"  Model saved to: {args.output_dir}")
    print("=" * 60)
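
Deployment note: the QAT model saved above still has to be converted to a fully-quantized int8 TFLite flatbuffer before it can be compiled (e.g. with Arm's Vela compiler) for the Ethos-U55. A minimal conversion sketch, assuming the standard TF/TFMOT flow; the output path is a placeholder and details may vary by TensorFlow version:

    converter = tf.lite.TFLiteConverter.from_keras_model(model)  # QAT-wrapped model
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Integer-only kernels, as the NPU requires
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    tflite_model = converter.convert()
    with open('models/dtln_ethos_u55_int8.tflite', 'wb') as f:
        f.write(tflite_model)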