"""
Example: Quick test of DTLN model
This script demonstrates how to:
1. Create a model
2. Generate synthetic noisy audio
3. Process it through the model
"""
import os

import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from scipy import signal

from dtln_ethos_u55 import create_lightweight_model
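# Directory for the generated WAV/PNG files; created up front so the
# sf.write()/plt.savefig() calls below never hit a missing path.
OUTPUT_DIR = 'outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)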
def generate_test_audio(duration=2.0, sample_rate=16000):
"""
Generate synthetic test audio (speech + noise)
Args:
duration: Audio duration in seconds
sample_rate: Sampling rate
Returns:
Tuple of (clean, noisy) audio
"""
t = np.linspace(0, duration, int(duration * sample_rate))
# Generate synthetic "speech" (mixture of frequencies)
speech = (
0.3 * np.sin(2 * np.pi * 200 * t) +
0.2 * np.sin(2 * np.pi * 400 * t) +
0.15 * np.sin(2 * np.pi * 600 * t)
)
# Add envelope to simulate speech patterns
envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
speech = speech * envelope
    # Generate white Gaussian noise, scaled so the mixture has ~10 dB SNR
    # relative to the enveloped speech
    target_snr_db = 10.0
    noise = np.random.randn(len(t))
    noise *= np.sqrt(np.mean(speech ** 2) / (10 ** (target_snr_db / 10) * np.mean(noise ** 2)))
    # Mix speech and noise
    noisy = speech + noise
    # Normalize with a single shared gain so clean and noisy stay on the
    # same scale (independent per-signal gains would skew the metrics later)
    gain = 0.9 / (np.max(np.abs(noisy)) + 1e-8)
    speech = speech * gain
    noisy = noisy * gain
return speech.astype(np.float32), noisy.astype(np.float32)
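# To run the same pipeline on a real recording instead of synthetic tones,
# load any mono 16 kHz WAV with soundfile (the filename below is a
# placeholder; resample beforehand if your file uses a different rate):
#
#     noisy, sr = sf.read('my_noisy_recording.wav', dtype='float32')
#     assert sr == 16000, 'this example assumes 16 kHz input'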
def plot_comparison(clean, noisy, enhanced, sample_rate=16000):
"""Plot waveforms and spectrograms for comparison"""
fig, axes = plt.subplots(3, 2, figsize=(12, 10))
# Time domain plots
t = np.arange(len(clean)) / sample_rate
axes[0, 0].plot(t, clean)
axes[0, 0].set_title('Clean Speech (Waveform)')
axes[0, 0].set_ylabel('Amplitude')
axes[0, 0].grid(True)
axes[1, 0].plot(t, noisy)
axes[1, 0].set_title('Noisy Speech (Waveform)')
axes[1, 0].set_ylabel('Amplitude')
axes[1, 0].grid(True)
axes[2, 0].plot(t, enhanced)
axes[2, 0].set_title('Enhanced Speech (Waveform)')
axes[2, 0].set_xlabel('Time (s)')
axes[2, 0].set_ylabel('Amplitude')
axes[2, 0].grid(True)
    # Frequency domain plots (spectrograms)
for idx, (audio, title) in enumerate([
(clean, 'Clean Speech (Spectrogram)'),
(noisy, 'Noisy Speech (Spectrogram)'),
(enhanced, 'Enhanced Speech (Spectrogram)')
]):
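        # 512-sample windows at 16 kHz give ~32 ms frames with ~31 Hz resolution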
f, t_spec, Sxx = signal.spectrogram(audio, sample_rate, nperseg=512)
axes[idx, 1].pcolormesh(
t_spec, f, 10 * np.log10(Sxx + 1e-10),
shading='gouraud',
cmap='viridis'
)
axes[idx, 1].set_ylabel('Frequency (Hz)')
axes[idx, 1].set_title(title)
if idx == 2:
axes[idx, 1].set_xlabel('Time (s)')
plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'denoising_comparison.png'), dpi=150)
    print("\n✓ Comparison plot saved to: denoising_comparison.png")
plt.close()
def calculate_metrics(clean, enhanced):
"""Calculate quality metrics"""
    # SNR of the enhanced signal, treating the residual (clean - enhanced)
    # as noise
    noise = clean - enhanced
signal_power = np.mean(clean ** 2)
noise_power = np.mean(noise ** 2)
snr = 10 * np.log10(signal_power / (noise_power + 1e-10))
# Mean Squared Error
mse = np.mean((clean - enhanced) ** 2)
# Root Mean Squared Error
rmse = np.sqrt(mse)
return {
'SNR (dB)': snr,
'MSE': mse,
'RMSE': rmse
}
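# These are crude sample-domain metrics. For speech enhancement, perceptual
# scores such as PESQ or STOI are more informative; e.g. the third-party
# pystoi package (optional, not required here) exposes
# stoi(clean, enhanced, sample_rate).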
def main():
"""Main example function"""
print("="*60)
print("DTLN Model Example for Alif E7 Ethos-U55")
print("="*60)
# 1. Create model
print("\n1. Creating DTLN model...")
dtln = create_lightweight_model(target_size_kb=100)
model = dtln.build_model()
print(" βœ“ Model created")
# 2. Generate test audio
print("\n2. Generating test audio...")
clean, noisy = generate_test_audio(duration=2.0)
print(f" βœ“ Generated {len(clean)/16000:.1f}s of audio")
print(f" βœ“ Clean audio range: [{np.min(clean):.3f}, {np.max(clean):.3f}]")
print(f" βœ“ Noisy audio range: [{np.min(noisy):.3f}, {np.max(noisy):.3f}]")
# Save test audio
sf.write('/mnt/user-data/outputs/test_clean.wav', clean, 16000)
sf.write('/mnt/user-data/outputs/test_noisy.wav', noisy, 16000)
print(" βœ“ Saved: test_clean.wav, test_noisy.wav")
# 3. Process through model (random weights, not trained yet)
print("\n3. Processing through model...")
print(" ⚠ Note: Model has random weights (not trained)")
# Expand dims for batch
noisy_batch = np.expand_dims(noisy, 0)
# Forward pass
enhanced = model.predict(noisy_batch, verbose=0)
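    # predict() runs the whole 2 s clip offline; on-device, the stateful
    # per-frame models built in step 7 below are used instead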
    enhanced = np.squeeze(enhanced)  # Remove batch (and any channel) dims
    print(" ✓ Processing complete")
    print(f" ✓ Enhanced audio range: [{np.min(enhanced):.3f}, {np.max(enhanced):.3f}]")
# Save enhanced audio
    sf.write(os.path.join(OUTPUT_DIR, 'test_enhanced.wav'), enhanced, 16000)
    print(" ✓ Saved: test_enhanced.wav")
    # 4. Calculate metrics
    print("\n4. Quality Metrics:")
    # Trim to a common length in case framing changed the output size
    n = min(len(clean), len(enhanced))
    metrics = calculate_metrics(clean[:n], enhanced[:n])
for metric_name, value in metrics.items():
print(f" {metric_name}: {value:.4f}")
print("\n ⚠ Note: These metrics are poor because model is untrained")
print(" After training, expect SNR improvement of 10-15 dB")
# 5. Plot comparison
print("\n5. Creating visualization...")
plot_comparison(clean, noisy, enhanced)
# 6. Show model info
print("\n6. Model Information:")
print(f" Parameters: {model.count_params():,}")
print(f" Layers: {len(model.layers)}")
print(f" Input shape: {model.input_shape}")
print(f" Output shape: {model.output_shape}")
# 7. Build stateful models
print("\n7. Building stateful models for real-time inference...")
stage1, stage2 = dtln.build_stateful_model()
print(f" βœ“ Stage 1 parameters: {stage1.count_params():,}")
print(f" βœ“ Stage 2 parameters: {stage2.count_params():,}")
print("\n" + "="*60)
print("βœ“ Example complete!")
print("\nGenerated files:")
print(" - test_clean.wav")
print(" - test_noisy.wav")
print(" - test_enhanced.wav")
print(" - denoising_comparison.png")
print("\nNext steps:")
print(" 1. Train the model: python train_dtln.py --help")
print(" 2. Convert to TFLite: python convert_to_tflite.py --help")
print(" 3. Deploy to Alif E7")
print("="*60)
if __name__ == "__main__":
main()