"""
Example: Quick test of DTLN model
This script demonstrates how to:
1. Create a model
2. Generate synthetic noisy audio
3. Process it through the model
"""
import numpy as np
import soundfile as sf
from dtln_ethos_u55 import DTLN_Ethos_U55, create_lightweight_model
import matplotlib.pyplot as plt
def generate_test_audio(duration=2.0, sample_rate=16000):
    """
    Generate synthetic test audio (speech + noise).

    Args:
        duration: Audio duration in seconds
        sample_rate: Sampling rate in Hz

    Returns:
        Tuple of (clean, noisy) audio as float32 arrays
    """
    t = np.linspace(0, duration, int(duration * sample_rate))

    # Generate synthetic "speech" (mixture of low-frequency tones)
    speech = (
        0.3 * np.sin(2 * np.pi * 200 * t) +
        0.2 * np.sin(2 * np.pi * 400 * t) +
        0.15 * np.sin(2 * np.pi * 600 * t)
    )

    # Add a slow envelope to mimic the amplitude modulation of speech
    envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
    speech = speech * envelope

    # Generate white Gaussian noise
    noise = np.random.randn(len(t)) * 0.15

    # Mix speech and noise (about 1 dB SNR with these amplitudes)
    noisy = speech + noise

    # Normalize each signal to a peak of 0.9
    speech = speech / (np.max(np.abs(speech)) + 1e-8) * 0.9
    noisy = noisy / (np.max(np.abs(noisy)) + 1e-8) * 0.9

    return speech.astype(np.float32), noisy.astype(np.float32)
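
# Illustrative helper (an addition, not called by main): measures the
# empirical SNR of a speech/noise pair in dB. Applied to the raw `speech`
# and `noise` inside generate_test_audio (before the per-signal peak
# normalization), it can be used to verify the mixing level claimed above.
def mixture_snr_db(speech, noise):
    """Return the power ratio of speech to noise in dB."""
    signal_power = np.mean(speech ** 2)
    noise_power = np.mean(noise ** 2)
    return 10 * np.log10(signal_power / (noise_power + 1e-10))
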
def plot_comparison(clean, noisy, enhanced, sample_rate=16000):
    """Plot waveforms and spectrograms for comparison."""
    fig, axes = plt.subplots(3, 2, figsize=(12, 10))

    # Time-domain plots
    t = np.arange(len(clean)) / sample_rate

    axes[0, 0].plot(t, clean)
    axes[0, 0].set_title('Clean Speech (Waveform)')
    axes[0, 0].set_ylabel('Amplitude')
    axes[0, 0].grid(True)

    axes[1, 0].plot(t, noisy)
    axes[1, 0].set_title('Noisy Speech (Waveform)')
    axes[1, 0].set_ylabel('Amplitude')
    axes[1, 0].grid(True)

    axes[2, 0].plot(t, enhanced)
    axes[2, 0].set_title('Enhanced Speech (Waveform)')
    axes[2, 0].set_xlabel('Time (s)')
    axes[2, 0].set_ylabel('Amplitude')
    axes[2, 0].grid(True)

    # Frequency-domain plots (spectrograms, in dB)
    for idx, (audio, title) in enumerate([
        (clean, 'Clean Speech (Spectrogram)'),
        (noisy, 'Noisy Speech (Spectrogram)'),
        (enhanced, 'Enhanced Speech (Spectrogram)')
    ]):
        f, t_spec, Sxx = signal.spectrogram(audio, sample_rate, nperseg=512)
        axes[idx, 1].pcolormesh(
            t_spec, f, 10 * np.log10(Sxx + 1e-10),
            shading='gouraud',
            cmap='viridis'
        )
        axes[idx, 1].set_ylabel('Frequency (Hz)')
        axes[idx, 1].set_title(title)
        if idx == 2:
            axes[idx, 1].set_xlabel('Time (s)')

    plt.tight_layout()
    plt.savefig('/mnt/user-data/outputs/denoising_comparison.png', dpi=150)
    print("\n✓ Comparison plot saved to: denoising_comparison.png")
    plt.close()
def calculate_metrics(clean, enhanced):
    """Calculate simple quality metrics between clean and enhanced audio."""
    # Signal-to-noise ratio, treating (clean - enhanced) as residual noise
    noise = clean - enhanced
    signal_power = np.mean(clean ** 2)
    noise_power = np.mean(noise ** 2)
    snr = 10 * np.log10(signal_power / (noise_power + 1e-10))

    # Mean squared error
    mse = np.mean((clean - enhanced) ** 2)

    # Root mean squared error
    rmse = np.sqrt(mse)

    return {
        'SNR (dB)': snr,
        'MSE': mse,
        'RMSE': rmse
    }
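
# Optional extra metric (illustrative sketch, not used by main): scale-
# invariant SDR (SI-SDR) is a standard speech-enhancement metric that,
# unlike the plain SNR above, is insensitive to a global gain difference
# between the enhanced signal and the reference.
def si_sdr_db(clean, enhanced, eps=1e-10):
    """Return the scale-invariant SDR of `enhanced` against `clean` in dB."""
    # Project the estimate onto the reference to find the optimal scaling
    alpha = np.dot(enhanced, clean) / (np.dot(clean, clean) + eps)
    target = alpha * clean
    residual = enhanced - target
    return 10 * np.log10(np.sum(target ** 2) / (np.sum(residual ** 2) + eps))
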
def main():
    """Main example function."""
    print("=" * 60)
    print("DTLN Model Example for Alif E7 Ethos-U55")
    print("=" * 60)

    # 1. Create model
    print("\n1. Creating DTLN model...")
    dtln = create_lightweight_model(target_size_kb=100)
    model = dtln.build_model()
    print("   ✓ Model created")

    # 2. Generate test audio
    print("\n2. Generating test audio...")
    clean, noisy = generate_test_audio(duration=2.0)
    print(f"   ✓ Generated {len(clean)/16000:.1f}s of audio")
    print(f"   ✓ Clean audio range: [{np.min(clean):.3f}, {np.max(clean):.3f}]")
    print(f"   ✓ Noisy audio range: [{np.min(noisy):.3f}, {np.max(noisy):.3f}]")

    # Save test audio
    sf.write('/mnt/user-data/outputs/test_clean.wav', clean, 16000)
    sf.write('/mnt/user-data/outputs/test_noisy.wav', noisy, 16000)
    print("   ✓ Saved: test_clean.wav, test_noisy.wav")

    # 3. Process through model (random weights, not trained yet)
    print("\n3. Processing through model...")
    print("   ✓ Note: Model has random weights (not trained)")

    # Add a batch dimension for the forward pass
    noisy_batch = np.expand_dims(noisy, 0)
    enhanced = model.predict(noisy_batch, verbose=0)
    # Drop the batch (and any trailing channel) dimension
    enhanced = np.squeeze(enhanced)
    print("   ✓ Processing complete")
    print(f"   ✓ Enhanced audio range: [{np.min(enhanced):.3f}, {np.max(enhanced):.3f}]")

    # Save enhanced audio
    sf.write('/mnt/user-data/outputs/test_enhanced.wav', enhanced, 16000)
    print("   ✓ Saved: test_enhanced.wav")

    # 4. Calculate metrics
    print("\n4. Quality Metrics:")
    metrics = calculate_metrics(clean, enhanced)
    for metric_name, value in metrics.items():
        print(f"   {metric_name}: {value:.4f}")
    print("\n   ✓ Note: These metrics are poor because the model is untrained")
    print("     After training, expect an SNR improvement of 10-15 dB")

    # 5. Plot comparison
    print("\n5. Creating visualization...")
    plot_comparison(clean, noisy, enhanced)

    # 6. Show model info
    print("\n6. Model Information:")
    print(f"   Parameters: {model.count_params():,}")
    print(f"   Layers: {len(model.layers)}")
    print(f"   Input shape: {model.input_shape}")
    print(f"   Output shape: {model.output_shape}")
    # 7. Build stateful models
    print("\n7. Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()
    print(f"   ✓ Stage 1 parameters: {stage1.count_params():,}")
    print(f"   ✓ Stage 2 parameters: {stage2.count_params():,}")
print("\n" + "="*60)
print("β Example complete!")
print("\nGenerated files:")
print(" - test_clean.wav")
print(" - test_noisy.wav")
print(" - test_enhanced.wav")
print(" - denoising_comparison.png")
print("\nNext steps:")
print(" 1. Train the model: python train_dtln.py --help")
print(" 2. Convert to TFLite: python convert_to_tflite.py --help")
print(" 3. Deploy to Alif E7")
print("="*60)
if __name__ == "__main__":
    main()