grgsaliba committed on
Commit 0f01c60 · verified · 1 Parent(s): 2d3c3ce

Upload example_usage.py with huggingface_hub
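For reference, fetching this file back from the Hub is a single call with huggingface_hub. A minimal sketch; the repo id below is a placeholder, not taken from this commit:

from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the repository this commit belongs to
local_path = hf_hub_download(repo_id="<user>/<repo>", filename="example_usage.py")
print(local_path)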

Files changed (1)
  1. example_usage.py +207 -0
example_usage.py ADDED
@@ -0,0 +1,207 @@
"""
Example: Quick test of DTLN model

This script demonstrates how to:
1. Create a model
2. Generate synthetic noisy audio
3. Process it through the model
"""

import os

import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from scipy import signal

from dtln_ethos_u55 import DTLN_Ethos_U55, create_lightweight_model

OUTPUT_DIR = '/mnt/user-data/outputs'


def generate_test_audio(duration=2.0, sample_rate=16000):
    """
    Generate synthetic test audio (speech + noise).

    Args:
        duration: Audio duration in seconds
        sample_rate: Sampling rate in Hz

    Returns:
        Tuple of (clean, noisy) audio
    """
    t = np.linspace(0, duration, int(duration * sample_rate))

    # Generate synthetic "speech" (mixture of frequencies)
    speech = (
        0.3 * np.sin(2 * np.pi * 200 * t) +
        0.2 * np.sin(2 * np.pi * 400 * t) +
        0.15 * np.sin(2 * np.pi * 600 * t)
    )

    # Add an amplitude envelope to simulate speech-like modulation
    envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
    speech = speech * envelope

    # Generate noise
    noise = np.random.randn(len(t)) * 0.15

    # Mix speech and noise (SNR ~1 dB given the amplitudes above)
    noisy = speech + noise

    # Normalize
    speech = speech / (np.max(np.abs(speech)) + 1e-8) * 0.9
    noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.9

    return speech.astype(np.float32), noisy.astype(np.float32)


def plot_comparison(clean, noisy, enhanced, sample_rate=16000):
    """Plot waveforms and spectrograms for comparison."""
    fig, axes = plt.subplots(3, 2, figsize=(12, 10))

    # Time-domain plots
    t = np.arange(len(clean)) / sample_rate

    axes[0, 0].plot(t, clean)
    axes[0, 0].set_title('Clean Speech (Waveform)')
    axes[0, 0].set_ylabel('Amplitude')
    axes[0, 0].grid(True)

    axes[1, 0].plot(t, noisy)
    axes[1, 0].set_title('Noisy Speech (Waveform)')
    axes[1, 0].set_ylabel('Amplitude')
    axes[1, 0].grid(True)

    axes[2, 0].plot(t, enhanced)
    axes[2, 0].set_title('Enhanced Speech (Waveform)')
    axes[2, 0].set_xlabel('Time (s)')
    axes[2, 0].set_ylabel('Amplitude')
    axes[2, 0].grid(True)

    # Frequency-domain plots (spectrograms)
    for idx, (audio, title) in enumerate([
        (clean, 'Clean Speech (Spectrogram)'),
        (noisy, 'Noisy Speech (Spectrogram)'),
        (enhanced, 'Enhanced Speech (Spectrogram)')
    ]):
        f, t_spec, Sxx = signal.spectrogram(audio, sample_rate, nperseg=512)
        axes[idx, 1].pcolormesh(
            t_spec, f, 10 * np.log10(Sxx + 1e-10),
            shading='gouraud',
            cmap='viridis'
        )
        axes[idx, 1].set_ylabel('Frequency (Hz)')
        axes[idx, 1].set_title(title)
        if idx == 2:
            axes[idx, 1].set_xlabel('Time (s)')

    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'denoising_comparison.png'), dpi=150)
    print("\n✓ Comparison plot saved to: denoising_comparison.png")
    plt.close(fig)


def calculate_metrics(clean, enhanced):
    """Calculate quality metrics."""
    # Signal-to-Noise Ratio (SNR), treating the residual as noise
    noise = clean - enhanced
    signal_power = np.mean(clean ** 2)
    noise_power = np.mean(noise ** 2)
    snr = 10 * np.log10(signal_power / (noise_power + 1e-10))

    # Mean Squared Error
    mse = np.mean((clean - enhanced) ** 2)

    # Root Mean Squared Error
    rmse = np.sqrt(mse)

    return {
        'SNR (dB)': snr,
        'MSE': mse,
        'RMSE': rmse
    }


def main():
    """Main example function."""
    print("=" * 60)
    print("DTLN Model Example for Alif E7 Ethos-U55")
    print("=" * 60)

    # Ensure the output directory exists before writing files
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 1. Create model
    print("\n1. Creating DTLN model...")
    dtln = create_lightweight_model(target_size_kb=100)
    model = dtln.build_model()
    print("   ✓ Model created")

    # 2. Generate test audio
    print("\n2. Generating test audio...")
    clean, noisy = generate_test_audio(duration=2.0)
    print(f"   ✓ Generated {len(clean)/16000:.1f}s of audio")
    print(f"   ✓ Clean audio range: [{np.min(clean):.3f}, {np.max(clean):.3f}]")
    print(f"   ✓ Noisy audio range: [{np.min(noisy):.3f}, {np.max(noisy):.3f}]")

    # Save test audio
    sf.write(os.path.join(OUTPUT_DIR, 'test_clean.wav'), clean, 16000)
    sf.write(os.path.join(OUTPUT_DIR, 'test_noisy.wav'), noisy, 16000)
    print("   ✓ Saved: test_clean.wav, test_noisy.wav")

    # 3. Process through model (random weights, not trained yet)
    print("\n3. Processing through model...")
    print("   ⚠ Note: Model has random weights (not trained)")

    # Add batch dimension
    noisy_batch = np.expand_dims(noisy, 0)

    # Forward pass
    enhanced = model.predict(noisy_batch, verbose=0)
    enhanced = np.squeeze(enhanced[0])  # Remove batch (and any channel) dimension

    print("   ✓ Processing complete")
    print(f"   ✓ Enhanced audio range: [{np.min(enhanced):.3f}, {np.max(enhanced):.3f}]")

    # Save enhanced audio
    sf.write(os.path.join(OUTPUT_DIR, 'test_enhanced.wav'), enhanced, 16000)
    print("   ✓ Saved: test_enhanced.wav")

    # 4. Calculate metrics
    print("\n4. Quality Metrics:")
    metrics = calculate_metrics(clean, enhanced)
    for metric_name, value in metrics.items():
        print(f"   {metric_name}: {value:.4f}")

    print("\n   ⚠ Note: These metrics are poor because the model is untrained")
    print("   After training, expect an SNR improvement of 10-15 dB")

    # 5. Plot comparison
    print("\n5. Creating visualization...")
    plot_comparison(clean, noisy, enhanced)

    # 6. Show model info
    print("\n6. Model Information:")
    print(f"   Parameters: {model.count_params():,}")
    print(f"   Layers: {len(model.layers)}")
    print(f"   Input shape: {model.input_shape}")
    print(f"   Output shape: {model.output_shape}")

    # 7. Build stateful models
    print("\n7. Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()
    print(f"   ✓ Stage 1 parameters: {stage1.count_params():,}")
    print(f"   ✓ Stage 2 parameters: {stage2.count_params():,}")

    print("\n" + "=" * 60)
    print("✓ Example complete!")
    print("\nGenerated files:")
    print("   - test_clean.wav")
    print("   - test_noisy.wav")
    print("   - test_enhanced.wav")
    print("   - denoising_comparison.png")
    print("\nNext steps:")
    print("   1. Train the model: python train_dtln.py --help")
    print("   2. Convert to TFLite: python convert_to_tflite.py --help")
    print("   3. Deploy to Alif E7")
    print("=" * 60)


if __name__ == "__main__":
    main()
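Step 7 builds the two stateful stages but the script never exercises them. Below is a minimal sketch of what a frame-by-frame loop over those stages could look like, assuming the conventional DTLN streaming layout (512-sample blocks with a 128-sample shift; stage 1 predicts a magnitude mask, stage 2 refines the time-domain block) and assuming hypothetical single-frame input shapes for stage1 and stage2. The actual signatures depend on how build_stateful_model() is defined.

import numpy as np

def enhance_streaming(stage1, stage2, noisy, block_len=512, block_shift=128):
    """Frame-by-frame enhancement sketch (input shapes are assumptions)."""
    out = np.zeros(len(noisy), dtype=np.float32)
    in_buf = np.zeros(block_len, dtype=np.float32)
    out_buf = np.zeros(block_len, dtype=np.float32)
    n_blocks = (len(noisy) - block_len) // block_shift + 1
    for b in range(n_blocks):
        # Slide new samples into the analysis buffer
        in_buf = np.roll(in_buf, -block_shift)
        in_buf[-block_shift:] = noisy[b * block_shift : (b + 1) * block_shift]
        # Stage 1: mask the magnitude spectrum (assumed (1, 1, fft_bins) input)
        spec = np.fft.rfft(in_buf)
        mag = np.abs(spec).astype(np.float32)
        mask = stage1.predict(mag[np.newaxis, np.newaxis, :], verbose=0)[0, 0]
        block = np.fft.irfft(spec * mask).astype(np.float32)
        # Stage 2: refine the time-domain block (assumed (1, 1, block_len) input)
        block = stage2.predict(block[np.newaxis, np.newaxis, :], verbose=0)[0, 0]
        # Overlap-add: shift the synthesis buffer and emit block_shift samples
        out_buf = np.roll(out_buf, -block_shift)
        out_buf[-block_shift:] = 0.0
        out_buf += block
        out[b * block_shift : (b + 1) * block_shift] = out_buf[:block_shift]
    return out

In main(), this could be called after step 7 as enhanced_rt = enhance_streaming(stage1, stage2, noisy); on the Ethos-U55 target the same loop would drive the two converted TFLite interpreters instead of model.predict.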