"""
Example: Quick test of DTLN model
This script demonstrates how to:
1. Create a model
2. Generate synthetic noisy audio
3. Process it through the model
"""

import os

import numpy as np
import soundfile as sf
import matplotlib

matplotlib.use('Agg')  # non-interactive backend; this script only saves figures
import matplotlib.pyplot as plt
from scipy import signal

from dtln_ethos_u55 import DTLN_Ethos_U55, create_lightweight_model


def generate_test_audio(duration=2.0, sample_rate=16000):
    """
    Generate synthetic test audio (speech + noise)
    
    Args:
        duration: Audio duration in seconds
        sample_rate: Sampling rate
    
    Returns:
        Tuple of (clean, noisy) audio
    """
    t = np.linspace(0, duration, int(duration * sample_rate), endpoint=False)
    
    # Generate synthetic "speech" (mixture of frequencies)
    speech = (
        0.3 * np.sin(2 * np.pi * 200 * t) +
        0.2 * np.sin(2 * np.pi * 400 * t) +
        0.15 * np.sin(2 * np.pi * 600 * t)
    )
    
    # Add envelope to simulate speech patterns
    envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
    speech = speech * envelope
    
    # Generate noise
    noise = np.random.randn(len(t)) * 0.15
    
    # Mix speech and noise (SNR ≈ 1 dB)
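    # (the sinusoids carry (0.3**2 + 0.2**2 + 0.15**2) / 2 ≈ 0.076 of power,
    #  scaled by mean(envelope**2) = 0.375 to ≈ 0.029; the noise carries
    #  0.15**2 ≈ 0.023, so 10 * log10(0.029 / 0.023) ≈ 1 dB)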
    noisy = speech + noise
    
    # Normalize
    speech = speech / (np.max(np.abs(speech)) + 1e-8) * 0.9
    noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.9
    
    return speech.astype(np.float32), noisy.astype(np.float32)


def plot_comparison(clean, noisy, enhanced, sample_rate=16000):
    """Plot waveforms and spectrograms for comparison"""
    
    fig, axes = plt.subplots(3, 2, figsize=(12, 10))
    
    # Time domain plots
    t = np.arange(len(clean)) / sample_rate
    
    axes[0, 0].plot(t, clean)
    axes[0, 0].set_title('Clean Speech (Waveform)')
    axes[0, 0].set_ylabel('Amplitude')
    axes[0, 0].grid(True)
    
    axes[1, 0].plot(t, noisy)
    axes[1, 0].set_title('Noisy Speech (Waveform)')
    axes[1, 0].set_ylabel('Amplitude')
    axes[1, 0].grid(True)
    
    axes[2, 0].plot(t, enhanced)
    axes[2, 0].set_title('Enhanced Speech (Waveform)')
    axes[2, 0].set_xlabel('Time (s)')
    axes[2, 0].set_ylabel('Amplitude')
    axes[2, 0].grid(True)
    
    # Frequency domain plots (spectrograms)
    
    for idx, (audio, title) in enumerate([
        (clean, 'Clean Speech (Spectrogram)'),
        (noisy, 'Noisy Speech (Spectrogram)'),
        (enhanced, 'Enhanced Speech (Spectrogram)')
    ]):
        f, t_spec, Sxx = signal.spectrogram(audio, sample_rate, nperseg=512)
        axes[idx, 1].pcolormesh(
            t_spec, f, 10 * np.log10(Sxx + 1e-10),
            shading='gouraud',
            cmap='viridis'
        )
        axes[idx, 1].set_ylabel('Frequency (Hz)')
        axes[idx, 1].set_title(title)
        if idx == 2:
            axes[idx, 1].set_xlabel('Time (s)')
    
    plt.tight_layout()
    plt.savefig('/mnt/user-data/outputs/denoising_comparison.png', dpi=150)
    print("\n✓ Comparison plot saved to: denoising_comparison.png")
    plt.close()


def calculate_metrics(clean, enhanced):
    """Calculate quality metrics"""
    
    # SNR of the enhanced signal, treating the residual (clean - enhanced) as noise
    noise = clean - enhanced
    signal_power = np.mean(clean ** 2)
    noise_power = np.mean(noise ** 2)
    snr = 10 * np.log10(signal_power / (noise_power + 1e-10))
    
    # Mean Squared Error
    mse = np.mean((clean - enhanced) ** 2)
    
    # Root Mean Squared Error
    rmse = np.sqrt(mse)
    
    return {
        'SNR (dB)': snr,
        'MSE': mse,
        'RMSE': rmse
    }
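

# Illustrative helper (an addition mirroring calculate_metrics, not part of the
# pipeline above): report the SNR gain of the enhanced output over the noisy input.
def calculate_snr_improvement(clean, noisy, enhanced):
    """Residual-based SNR improvement in dB (positive once the model is trained)"""

    def snr(reference, estimate):
        residual = reference - estimate
        return 10 * np.log10(
            np.mean(reference ** 2) / (np.mean(residual ** 2) + 1e-10)
        )

    return snr(clean, enhanced) - snr(clean, noisy)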


def main():
    """Main example function"""
    
    print("="*60)
    print("DTLN Model Example for Alif E7 Ethos-U55")
    print("="*60)
    
    # Ensure the output directory exists before any files are written
    os.makedirs('/mnt/user-data/outputs', exist_ok=True)

    # 1. Create model
    print("\n1. Creating DTLN model...")
    dtln = create_lightweight_model(target_size_kb=100)
    model = dtln.build_model()
    print("   ✓ Model created")
    
    # 2. Generate test audio
    print("\n2. Generating test audio...")
    clean, noisy = generate_test_audio(duration=2.0)
    print(f"   ✓ Generated {len(clean)/16000:.1f}s of audio")
    print(f"   ✓ Clean audio range: [{np.min(clean):.3f}, {np.max(clean):.3f}]")
    print(f"   ✓ Noisy audio range: [{np.min(noisy):.3f}, {np.max(noisy):.3f}]")
    
    # Save test audio
    sf.write('/mnt/user-data/outputs/test_clean.wav', clean, 16000)
    sf.write('/mnt/user-data/outputs/test_noisy.wav', noisy, 16000)
    print("   ✓ Saved: test_clean.wav, test_noisy.wav")
    
    # 3. Process through model (random weights, not trained yet)
    print("\n3. Processing through model...")
    print("   ⚠ Note: Model has random weights (not trained)")
    
    # Expand dims for batch
    noisy_batch = np.expand_dims(noisy, 0)
    
    # Forward pass
    enhanced = model.predict(noisy_batch, verbose=0)
    enhanced = enhanced[0]  # Remove batch dimension
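    # Defensive reshape (an assumption about the model's output layout): drop
    # any trailing channel axis so sf.write and the plots receive a 1-D signal
    enhanced = np.squeeze(enhanced)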
    
    print("   ✓ Processing complete")
    print(f"   ✓ Enhanced audio range: [{np.min(enhanced):.3f}, {np.max(enhanced):.3f}]")
    
    # Save enhanced audio
    sf.write('/mnt/user-data/outputs/test_enhanced.wav', enhanced, 16000)
    print("   ✓ Saved: test_enhanced.wav")
    
    # 4. Calculate metrics
    print("\n4. Quality Metrics:")
    metrics = calculate_metrics(clean, enhanced)
    for metric_name, value in metrics.items():
        print(f"   {metric_name}: {value:.4f}")
    
    print("\n   ⚠ Note: These metrics are poor because the model is untrained")
    print("   After training, expect an SNR improvement of 10-15 dB")
    
    # 5. Plot comparison
    print("\n5. Creating visualization...")
    plot_comparison(clean, noisy, enhanced)
    
    # 6. Show model info
    print("\n6. Model Information:")
    print(f"   Parameters: {model.count_params():,}")
    print(f"   Layers: {len(model.layers)}")
    print(f"   Input shape: {model.input_shape}")
    print(f"   Output shape: {model.output_shape}")
    
    # 7. Build stateful models
    print("\n7. Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()
    print(f"   ✓ Stage 1 parameters: {stage1.count_params():,}")
    print(f"   ✓ Stage 2 parameters: {stage2.count_params():,}")
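
    # The stateful pair targets frame-by-frame streaming. The loop is sketched
    # as pseudocode since the exact signatures live in dtln_ethos_u55; the block
    # sizes and two-stage layout below follow the original DTLN convention and
    # are assumptions, not confirmed by this module:
    #
    #   block_len, block_shift = 512, 128
    #   for each new block_shift-sample chunk:
    #       shift the chunk into a block_len-sample analysis buffer
    #       mag, phase = magnitude and phase of rfft(buffer)
    #       mask, states_1 = stage1([mag, states_1])    # stage 1 masks STFT magnitudes
    #       frame = irfft(mag * mask * exp(1j * phase))
    #       out, states_2 = stage2([frame, states_2])   # stage 2 refines in time domain
    #       overlap-add `out` into the output stream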
    
    print("\n" + "="*60)
    print("✓ Example complete!")
    print("\nGenerated files:")
    print("  - test_clean.wav")
    print("  - test_noisy.wav")
    print("  - test_enhanced.wav")
    print("  - denoising_comparison.png")
    print("\nNext steps:")
    print("  1. Train the model: python train_dtln.py --help")
    print("  2. Convert to TFLite: python convert_to_tflite.py --help")
    print("  3. Deploy to Alif E7")
    print("="*60)


if __name__ == "__main__":
    main()