grgsaliba commited on
Commit
2a37d6d
·
verified ·
1 Parent(s): 0e7454e

Upload train_with_hf_datasets.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_with_hf_datasets.py +516 -0
train_with_hf_datasets.py ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train DTLN model using Hugging Face datasets
3
+ Uses real speech and noise datasets for production-quality training
4
+ """
5
+
6
+ import tensorflow as tf
7
+ import tensorflow_model_optimization as tfmot
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import librosa
11
+ from pathlib import Path
12
+ import argparse
13
+ from dtln_ethos_u55 import DTLN_Ethos_U55
14
+ import os
15
+ from datasets import load_dataset, Audio
16
+ from tqdm import tqdm
17
+
18
+
19
+ class HuggingFaceAudioDataGenerator(tf.keras.utils.Sequence):
20
+ """
21
+ Data generator using Hugging Face datasets
22
+ Loads clean speech and noise from HF Hub
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ clean_dataset_name="librispeech_asr",
28
+ noise_dataset_name="dns-challenge/dns-challenge-4",
29
+ clean_split="train.clean.100",
30
+ noise_split="train",
31
+ batch_size=16,
32
+ samples_per_epoch=1000,
33
+ frame_len=512,
34
+ frame_shift=128,
35
+ sampling_rate=16000,
36
+ snr_range=(0, 20),
37
+ shuffle=True,
38
+ cache_dir=None
39
+ ):
40
+ """
41
+ Args:
42
+ clean_dataset_name: HF dataset for clean speech (default: LibriSpeech)
43
+ noise_dataset_name: HF dataset for noise (default: DNS Challenge)
44
+ clean_split: Split to use from clean dataset
45
+ noise_split: Split to use from noise dataset
46
+ batch_size: Batch size for training
47
+ samples_per_epoch: Number of samples per epoch
48
+ frame_len: Frame length in samples
49
+ frame_shift: Frame shift in samples
50
+ sampling_rate: Target sampling rate
51
+ snr_range: Range of SNR for mixing (min, max) in dB
52
+ shuffle: Whether to shuffle data each epoch
53
+ cache_dir: Directory to cache datasets
54
+ """
55
+ print(f"\n{'='*60}")
56
+ print("Initializing Hugging Face Dataset Generator")
57
+ print(f"{'='*60}")
58
+
59
+ self.batch_size = batch_size
60
+ self.samples_per_epoch = samples_per_epoch
61
+ self.frame_len = frame_len
62
+ self.frame_shift = frame_shift
63
+ self.sampling_rate = sampling_rate
64
+ self.snr_range = snr_range
65
+ self.shuffle = shuffle
66
+
67
+ # Segment length for training (1 second)
68
+ self.segment_len = sampling_rate
69
+
70
+ # Load datasets
71
+ print(f"\n1. Loading clean speech dataset: {clean_dataset_name}")
72
+ print(f" Split: {clean_split}")
73
+
74
+ try:
75
+ self.clean_dataset = load_dataset(
76
+ clean_dataset_name,
77
+ split=clean_split,
78
+ streaming=True, # Stream for large datasets
79
+ cache_dir=cache_dir
80
+ )
81
+ # Cast audio to correct sampling rate
82
+ self.clean_dataset = self.clean_dataset.cast_column(
83
+ "audio",
84
+ Audio(sampling_rate=sampling_rate)
85
+ )
86
+ print(f" ✓ Clean speech dataset loaded")
87
+ except Exception as e:
88
+ print(f" ⚠ Error loading clean dataset: {e}")
89
+ print(f" Using fallback: common_voice")
90
+ self.clean_dataset = load_dataset(
91
+ "mozilla-foundation/common_voice_11_0",
92
+ "en",
93
+ split="train",
94
+ streaming=True,
95
+ cache_dir=cache_dir
96
+ )
97
+ self.clean_dataset = self.clean_dataset.cast_column(
98
+ "audio",
99
+ Audio(sampling_rate=sampling_rate)
100
+ )
101
+
102
+ print(f"\n2. Loading noise dataset: {noise_dataset_name}")
103
+ print(f" Split: {noise_split}")
104
+
105
+ try:
106
+ self.noise_dataset = load_dataset(
107
+ noise_dataset_name,
108
+ split=noise_split,
109
+ streaming=True,
110
+ cache_dir=cache_dir
111
+ )
112
+ self.noise_dataset = self.noise_dataset.cast_column(
113
+ "audio",
114
+ Audio(sampling_rate=sampling_rate)
115
+ )
116
+ print(f" ✓ Noise dataset loaded")
117
+ except Exception as e:
118
+ print(f" ⚠ Error loading noise dataset: {e}")
119
+ print(f" Using synthetic noise instead")
120
+ self.noise_dataset = None
121
+
122
+ # Create iterators
123
+ self.clean_iter = iter(self.clean_dataset)
124
+ if self.noise_dataset:
125
+ self.noise_iter = iter(self.noise_dataset)
126
+
127
+ self.on_epoch_end()
128
+
129
+ print(f"\n{'='*60}")
130
+ print(f"Dataset Generator Ready")
131
+ print(f" Batch size: {batch_size}")
132
+ print(f" Samples per epoch: {samples_per_epoch}")
133
+ print(f" Batches per epoch: {len(self)}")
134
+ print(f"{'='*60}\n")
135
+
136
+ def __len__(self):
137
+ """Return number of batches per epoch"""
138
+ return self.samples_per_epoch // self.batch_size
139
+
140
+ def __getitem__(self, index):
141
+ """Generate one batch of data"""
142
+ batch_clean = []
143
+ batch_noisy = []
144
+
145
+ for _ in range(self.batch_size):
146
+ try:
147
+ # Get clean audio
148
+ clean_audio = self._load_next_clean_audio()
149
+
150
+ # Get noise
151
+ if self.noise_dataset:
152
+ noise_audio = self._load_next_noise_audio()
153
+ else:
154
+ # Generate synthetic noise
155
+ noise_audio = np.random.randn(self.segment_len).astype(np.float32) * 0.1
156
+
157
+ # Mix clean and noise at random SNR
158
+ noisy_audio = self._mix_audio(clean_audio, noise_audio)
159
+
160
+ batch_clean.append(clean_audio)
161
+ batch_noisy.append(noisy_audio)
162
+
163
+ except Exception as e:
164
+ # If error, use white noise as fallback
165
+ print(f" Warning: Error loading sample: {e}")
166
+ noise = np.random.randn(self.segment_len).astype(np.float32) * 0.01
167
+ batch_clean.append(noise)
168
+ batch_noisy.append(noise)
169
+
170
+ return np.array(batch_noisy), np.array(batch_clean)
171
+
172
+ def on_epoch_end(self):
173
+ """Reset iterators at epoch end"""
174
+ pass
175
+
176
+ def _load_next_clean_audio(self):
177
+ """Load next clean audio sample"""
178
+ try:
179
+ sample = next(self.clean_iter)
180
+ audio = sample['audio']['array']
181
+ except StopIteration:
182
+ # Restart iterator
183
+ self.clean_iter = iter(self.clean_dataset)
184
+ sample = next(self.clean_iter)
185
+ audio = sample['audio']['array']
186
+
187
+ return self._preprocess_audio(audio)
188
+
189
+ def _load_next_noise_audio(self):
190
+ """Load next noise sample"""
191
+ try:
192
+ sample = next(self.noise_iter)
193
+ if 'audio' in sample:
194
+ audio = sample['audio']['array']
195
+ elif 'noise' in sample:
196
+ audio = sample['noise']['array']
197
+ else:
198
+ # Fallback to white noise
199
+ audio = np.random.randn(self.segment_len).astype(np.float32) * 0.1
200
+ except StopIteration:
201
+ # Restart iterator
202
+ self.noise_iter = iter(self.noise_dataset)
203
+ sample = next(self.noise_iter)
204
+ audio = sample['audio']['array']
205
+
206
+ return self._preprocess_audio(audio)
207
+
208
+ def _preprocess_audio(self, audio):
209
+ """Preprocess audio to target length and format"""
210
+ # Convert to float32
211
+ audio = audio.astype(np.float32)
212
+
213
+ # Trim or pad to segment length
214
+ if len(audio) > self.segment_len:
215
+ start = np.random.randint(0, len(audio) - self.segment_len)
216
+ audio = audio[start:start + self.segment_len]
217
+ else:
218
+ audio = np.pad(audio, (0, self.segment_len - len(audio)))
219
+
220
+ # Normalize
221
+ max_val = np.max(np.abs(audio))
222
+ if max_val > 1e-8:
223
+ audio = audio / max_val
224
+
225
+ return audio
226
+
227
+ def _mix_audio(self, clean, noise):
228
+ """Mix clean audio with noise at random SNR"""
229
+ snr = np.random.uniform(*self.snr_range)
230
+
231
+ # Calculate noise power
232
+ clean_power = np.mean(clean ** 2)
233
+ noise_power = np.mean(noise ** 2)
234
+
235
+ # Calculate noise scaling factor
236
+ if noise_power > 1e-8:
237
+ snr_linear = 10 ** (snr / 10)
238
+ noise_scale = np.sqrt(clean_power / (snr_linear * noise_power))
239
+ else:
240
+ noise_scale = 0.1
241
+
242
+ # Mix
243
+ noisy = clean + noise_scale * noise
244
+
245
+ # Normalize to prevent clipping
246
+ max_val = np.max(np.abs(noisy))
247
+ if max_val > 1e-8:
248
+ noisy = noisy / max_val * 0.95
249
+
250
+ return noisy.astype(np.float32)
251
+
252
+
253
+ def create_loss_function():
254
+ """
255
+ Create custom loss function combining time and frequency domain losses
256
+ """
257
+ def combined_loss(y_true, y_pred):
258
+ # Time domain MSE
259
+ time_loss = tf.reduce_mean(tf.square(y_true - y_pred))
260
+
261
+ # Frequency domain loss (STFT-based)
262
+ stft_true = tf.signal.stft(
263
+ y_true,
264
+ frame_length=512,
265
+ frame_step=128
266
+ )
267
+ stft_pred = tf.signal.stft(
268
+ y_pred,
269
+ frame_length=512,
270
+ frame_step=128
271
+ )
272
+
273
+ mag_true = tf.abs(stft_true)
274
+ mag_pred = tf.abs(stft_pred)
275
+
276
+ freq_loss = tf.reduce_mean(tf.square(mag_true - mag_pred))
277
+
278
+ # Combined loss (weighted)
279
+ return 0.7 * time_loss + 0.3 * freq_loss
280
+
281
+ return combined_loss
282
+
283
+
284
+ def train_model(
285
+ clean_dataset="librispeech_asr",
286
+ noise_dataset="dns-challenge/dns-challenge-4",
287
+ clean_split="train.clean.100",
288
+ noise_split="train",
289
+ output_dir='./models',
290
+ epochs=50,
291
+ batch_size=16,
292
+ samples_per_epoch=1000,
293
+ lstm_units=128,
294
+ learning_rate=0.001,
295
+ use_qat=True,
296
+ cache_dir=None
297
+ ):
298
+ """
299
+ Main training function using HF datasets
300
+
301
+ Args:
302
+ clean_dataset: HF dataset name for clean speech
303
+ noise_dataset: HF dataset name for noise
304
+ clean_split: Split for clean dataset
305
+ noise_split: Split for noise dataset
306
+ output_dir: Directory to save models
307
+ epochs: Number of training epochs
308
+ batch_size: Training batch size
309
+ samples_per_epoch: Samples per epoch
310
+ lstm_units: Number of LSTM units
311
+ learning_rate: Learning rate for Adam optimizer
312
+ use_qat: Whether to use quantization-aware training
313
+ cache_dir: Directory to cache datasets
314
+ """
315
+ # Create output directory
316
+ os.makedirs(output_dir, exist_ok=True)
317
+
318
+ print("="*60)
319
+ print("Training DTLN with Hugging Face Datasets")
320
+ print("="*60)
321
+
322
+ # Create model
323
+ print("\n1. Building DTLN model...")
324
+ dtln = DTLN_Ethos_U55(
325
+ frame_len=512,
326
+ frame_shift=128,
327
+ lstm_units=lstm_units,
328
+ sampling_rate=16000
329
+ )
330
+
331
+ model = dtln.build_model()
332
+ print(" ✓ Model built")
333
+ print(f" Parameters: {model.count_params():,}")
334
+
335
+ # Apply QAT if requested
336
+ if use_qat:
337
+ print("\n2. Applying Quantization-Aware Training...")
338
+ quantize_model = tfmot.quantization.keras.quantize_model
339
+ model = quantize_model(model)
340
+ print(" ✓ QAT applied (INT8 optimized)")
341
+
342
+ # Compile model
343
+ print("\n3. Compiling model...")
344
+ model.compile(
345
+ optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
346
+ loss=create_loss_function(),
347
+ metrics=['mae']
348
+ )
349
+ print(" ✓ Model compiled")
350
+
351
+ # Create data generator
352
+ print("\n4. Creating Hugging Face data generator...")
353
+ train_generator = HuggingFaceAudioDataGenerator(
354
+ clean_dataset_name=clean_dataset,
355
+ noise_dataset_name=noise_dataset,
356
+ clean_split=clean_split,
357
+ noise_split=noise_split,
358
+ batch_size=batch_size,
359
+ samples_per_epoch=samples_per_epoch,
360
+ frame_len=512,
361
+ frame_shift=128,
362
+ sampling_rate=16000,
363
+ snr_range=(0, 20),
364
+ shuffle=True,
365
+ cache_dir=cache_dir
366
+ )
367
+
368
+ # Callbacks
369
+ callbacks = [
370
+ tf.keras.callbacks.ModelCheckpoint(
371
+ filepath=os.path.join(output_dir, 'best_model.h5'),
372
+ monitor='loss',
373
+ save_best_only=True,
374
+ verbose=1
375
+ ),
376
+ tf.keras.callbacks.ReduceLROnPlateau(
377
+ monitor='loss',
378
+ factor=0.5,
379
+ patience=5,
380
+ min_lr=1e-6,
381
+ verbose=1
382
+ ),
383
+ tf.keras.callbacks.EarlyStopping(
384
+ monitor='loss',
385
+ patience=10,
386
+ restore_best_weights=True,
387
+ verbose=1
388
+ ),
389
+ tf.keras.callbacks.TensorBoard(
390
+ log_dir=os.path.join(output_dir, 'logs'),
391
+ histogram_freq=1
392
+ )
393
+ ]
394
+
395
+ # Train
396
+ print("\n5. Starting training...")
397
+ print("="*60)
398
+ history = model.fit(
399
+ train_generator,
400
+ epochs=epochs,
401
+ callbacks=callbacks,
402
+ verbose=1
403
+ )
404
+
405
+ # Save final model
406
+ final_model_path = os.path.join(output_dir, 'dtln_final.h5')
407
+ model.save(final_model_path)
408
+
409
+ print(f"\n{'='*60}")
410
+ print("✓ Training Complete!")
411
+ print(f"{'='*60}")
412
+ print(f"Final loss: {history.history['loss'][-1]:.4f}")
413
+ print(f"Best loss: {min(history.history['loss']):.4f}")
414
+ print(f"Model saved to: {final_model_path}")
415
+ print(f"{'='*60}\n")
416
+
417
+ return model, history
418
+
419
+
420
+ if __name__ == "__main__":
421
+ parser = argparse.ArgumentParser(
422
+ description='Train DTLN model using Hugging Face datasets'
423
+ )
424
+
425
+ # Dataset arguments
426
+ parser.add_argument(
427
+ '--clean-dataset',
428
+ type=str,
429
+ default='librispeech_asr',
430
+ help='HF dataset for clean speech (default: librispeech_asr)'
431
+ )
432
+ parser.add_argument(
433
+ '--noise-dataset',
434
+ type=str,
435
+ default='dns-challenge/dns-challenge-4',
436
+ help='HF dataset for noise (default: dns-challenge)'
437
+ )
438
+ parser.add_argument(
439
+ '--clean-split',
440
+ type=str,
441
+ default='train.clean.100',
442
+ help='Split for clean dataset'
443
+ )
444
+ parser.add_argument(
445
+ '--noise-split',
446
+ type=str,
447
+ default='train',
448
+ help='Split for noise dataset'
449
+ )
450
+
451
+ # Training arguments
452
+ parser.add_argument(
453
+ '--output-dir',
454
+ type=str,
455
+ default='./models',
456
+ help='Output directory for models'
457
+ )
458
+ parser.add_argument(
459
+ '--epochs',
460
+ type=int,
461
+ default=50,
462
+ help='Number of training epochs'
463
+ )
464
+ parser.add_argument(
465
+ '--batch-size',
466
+ type=int,
467
+ default=16,
468
+ help='Training batch size'
469
+ )
470
+ parser.add_argument(
471
+ '--samples-per-epoch',
472
+ type=int,
473
+ default=1000,
474
+ help='Number of samples per epoch'
475
+ )
476
+ parser.add_argument(
477
+ '--lstm-units',
478
+ type=int,
479
+ default=128,
480
+ help='Number of LSTM units'
481
+ )
482
+ parser.add_argument(
483
+ '--learning-rate',
484
+ type=float,
485
+ default=0.001,
486
+ help='Learning rate'
487
+ )
488
+ parser.add_argument(
489
+ '--no-qat',
490
+ action='store_true',
491
+ help='Disable quantization-aware training'
492
+ )
493
+ parser.add_argument(
494
+ '--cache-dir',
495
+ type=str,
496
+ default=None,
497
+ help='Directory to cache HF datasets'
498
+ )
499
+
500
+ args = parser.parse_args()
501
+
502
+ # Train model
503
+ model, history = train_model(
504
+ clean_dataset=args.clean_dataset,
505
+ noise_dataset=args.noise_dataset,
506
+ clean_split=args.clean_split,
507
+ noise_split=args.noise_split,
508
+ output_dir=args.output_dir,
509
+ epochs=args.epochs,
510
+ batch_size=args.batch_size,
511
+ samples_per_epoch=args.samples_per_epoch,
512
+ lstm_units=args.lstm_units,
513
+ learning_rate=args.learning_rate,
514
+ use_qat=not args.no_qat,
515
+ cache_dir=args.cache_dir
516
+ )