Spaces:
Sleeping
Sleeping
| """ | |
| Example: Quick test of DTLN model | |
| This script demonstrates how to: | |
| 1. Create a model | |
| 2. Generate synthetic noisy audio | |
| 3. Process it through the model | |
| """ | |
| import numpy as np | |
| import soundfile as sf | |
| from dtln_ethos_u55 import DTLN_Ethos_U55, create_lightweight_model | |
| import matplotlib.pyplot as plt | |
def generate_test_audio(duration=2.0, sample_rate=16000, noise_std=0.15, seed=None):
    """
    Generate synthetic test audio (speech-like tones + white noise).

    The "speech" is a sum of 200/400/600 Hz sinusoids modulated by a 3 Hz
    envelope; the noise is zero-mean Gaussian. Both returned signals are
    scaled by the SAME factor (chosen so the noisy mixture peaks at 0.9),
    so ``clean`` is exactly the speech component of ``noisy`` and can be
    used directly as the reference for quality metrics.

    Args:
        duration: Audio duration in seconds.
        sample_rate: Sampling rate in Hz.
        noise_std: Standard deviation of the Gaussian noise before scaling.
            With the default 0.15 the mixture SNR is roughly 1 dB.
        seed: Optional seed for reproducible noise. When None, the global
            ``np.random`` state is used.

    Returns:
        Tuple of (clean, noisy) float32 arrays, each of length
        ``int(duration * sample_rate)``.
    """
    num_samples = int(duration * sample_rate)
    if num_samples <= 0:
        # np.max on an empty array would raise; return empty signals instead.
        empty = np.zeros(0, dtype=np.float32)
        return empty, empty.copy()

    # Sample instants spaced exactly 1/sample_rate apart. np.linspace with
    # the endpoint included would stretch the spacing to duration/(N-1).
    t = np.arange(num_samples) / sample_rate

    # Synthetic "speech": mixture of harmonically related tones.
    speech = (
        0.3 * np.sin(2 * np.pi * 200 * t) +
        0.2 * np.sin(2 * np.pi * 400 * t) +
        0.15 * np.sin(2 * np.pi * 600 * t)
    )
    # Slow amplitude envelope to mimic syllable-like modulation.
    envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 3 * t)
    speech = speech * envelope

    if seed is None:
        noise = np.random.randn(num_samples) * noise_std
    else:
        noise = np.random.default_rng(seed).standard_normal(num_samples) * noise_std

    noisy = speech + noise

    # One shared gain for both signals keeps `clean` aligned in level with
    # the speech inside `noisy`; the epsilon guards against all-zero input.
    scale = 0.9 / (np.max(np.abs(noisy)) + 1e-8)
    clean = speech * scale
    noisy = noisy * scale
    return clean.astype(np.float32), noisy.astype(np.float32)
def plot_comparison(clean, noisy, enhanced, sample_rate=16000,
                    output_path='/mnt/user-data/outputs/denoising_comparison.png'):
    """
    Plot waveforms and spectrograms of clean, noisy and enhanced audio.

    Produces a 3x2 grid: the left column holds waveforms, the right column
    holds log-power spectrograms (``scipy.signal.spectrogram``, nperseg=512).
    The figure is written to ``output_path`` and then closed.

    Args:
        clean: Reference signal (array-like; flattened to 1-D).
        noisy: Noisy input signal (array-like; flattened to 1-D).
        enhanced: Model output signal (array-like; flattened to 1-D). May
            differ in length from the others; each row gets its own time axis.
        sample_rate: Sampling rate in Hz, used for the time/frequency axes.
        output_path: Destination image path. Missing parent directories are
            created.
    """
    from pathlib import Path
    from scipy import signal

    rows = [
        (clean, 'Clean Speech'),
        (noisy, 'Noisy Speech'),
        (enhanced, 'Enhanced Speech'),
    ]
    last_row = len(rows) - 1
    out_path = Path(output_path)

    fig, axes = plt.subplots(3, 2, figsize=(12, 10))
    try:
        for idx, (audio, label) in enumerate(rows):
            audio = np.asarray(audio).ravel()

            # Waveform: per-signal time axis so unequal lengths still plot.
            wave_ax = axes[idx, 0]
            t = np.arange(len(audio)) / sample_rate
            wave_ax.plot(t, audio)
            wave_ax.set_title(f'{label} (Waveform)')
            wave_ax.set_ylabel('Amplitude')
            wave_ax.grid(True)

            # Spectrogram in dB; the epsilon avoids log10(0).
            spec_ax = axes[idx, 1]
            f, t_spec, Sxx = signal.spectrogram(audio, sample_rate, nperseg=512)
            spec_ax.pcolormesh(
                t_spec, f, 10 * np.log10(Sxx + 1e-10),
                shading='gouraud',
                cmap='viridis'
            )
            spec_ax.set_ylabel('Frequency (Hz)')
            spec_ax.set_title(f'{label} (Spectrogram)')

            if idx == last_row:
                wave_ax.set_xlabel('Time (s)')
                spec_ax.set_xlabel('Time (s)')

        fig.tight_layout()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=150)
    finally:
        # Close this specific figure even if saving fails, to avoid leaking it.
        plt.close(fig)
    print(f"\nβ Comparison plot saved to: {out_path}")
def calculate_metrics(clean, enhanced):
    """
    Calculate quality metrics of ``enhanced`` against the reference ``clean``.

    The residual ``clean - enhanced`` is treated as noise:
    SNR = 10*log10(P_clean / (P_residual + 1e-10)), MSE = P_residual,
    RMSE = sqrt(MSE).

    Args:
        clean: Reference signal (array-like; flattened to 1-D).
        enhanced: Estimated signal (array-like; flattened to 1-D). Must have
            the same number of samples as ``clean``.

    Returns:
        Dict with keys 'SNR (dB)', 'MSE' and 'RMSE' mapped to floats.

    Raises:
        ValueError: if the signals are empty or differ in sample count.
    """
    # Flatten so an (N, 1) model output is not silently broadcast against an
    # (N,) reference into an (N, N) residual.
    clean = np.asarray(clean, dtype=np.float64).ravel()
    enhanced = np.asarray(enhanced, dtype=np.float64).ravel()
    if clean.size == 0:
        raise ValueError("calculate_metrics: signals must be non-empty")
    if clean.shape != enhanced.shape:
        raise ValueError(
            f"calculate_metrics: length mismatch "
            f"(clean={clean.size}, enhanced={enhanced.size})"
        )

    residual = clean - enhanced
    signal_power = np.mean(clean ** 2)
    # The residual power is the mean squared error; compute it once.
    mse = np.mean(residual ** 2)
    snr = 10 * np.log10(signal_power / (mse + 1e-10))

    return {
        'SNR (dB)': float(snr),
        'MSE': float(mse),
        'RMSE': float(np.sqrt(mse)),
    }
def main():
    """
    Run the end-to-end DTLN demo.

    Steps: build a lightweight DTLN model (target_size_kb=100), generate 2 s
    of synthetic clean/noisy audio, run the (untrained) model on the noisy
    signal, write the three WAV files, print quality metrics, save a
    comparison plot, print model information, and build the stateful
    two-stage models.

    Output files are written under ``/mnt/user-data/outputs``, which is
    created if it does not already exist.
    """
    import os

    sample_rate = 16000
    output_dir = '/mnt/user-data/outputs'
    # sf.write (and the comparison plot) fail if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("DTLN Model Example for Alif E7 Ethos-U55")
    print("="*60)

    # 1. Create model
    print("\n1. Creating DTLN model...")
    dtln = create_lightweight_model(target_size_kb=100)
    model = dtln.build_model()
    print(" β Model created")

    # 2. Generate test audio
    print("\n2. Generating test audio...")
    clean, noisy = generate_test_audio(duration=2.0, sample_rate=sample_rate)
    print(f" β Generated {len(clean)/sample_rate:.1f}s of audio")
    print(f" β Clean audio range: [{np.min(clean):.3f}, {np.max(clean):.3f}]")
    print(f" β Noisy audio range: [{np.min(noisy):.3f}, {np.max(noisy):.3f}]")

    # Save test audio
    sf.write(os.path.join(output_dir, 'test_clean.wav'), clean, sample_rate)
    sf.write(os.path.join(output_dir, 'test_noisy.wav'), noisy, sample_rate)
    print(" β Saved: test_clean.wav, test_noisy.wav")

    # 3. Process through model (random weights, not trained yet)
    print("\n3. Processing through model...")
    print(" β Note: Model has random weights (not trained)")
    # Add a leading batch axis for model.predict, then strip it again.
    noisy_batch = np.expand_dims(noisy, 0)
    enhanced = model.predict(noisy_batch, verbose=0)
    enhanced = enhanced[0]
    print(" β Processing complete")
    print(f" β Enhanced audio range: [{np.min(enhanced):.3f}, {np.max(enhanced):.3f}]")

    # Save enhanced audio
    sf.write(os.path.join(output_dir, 'test_enhanced.wav'), enhanced, sample_rate)
    print(" β Saved: test_enhanced.wav")

    # 4. Calculate metrics
    print("\n4. Quality Metrics:")
    metrics = calculate_metrics(clean, enhanced)
    for metric_name, value in metrics.items():
        print(f" {metric_name}: {value:.4f}")
    print("\n β Note: These metrics are poor because model is untrained")
    print(" After training, expect SNR improvement of 10-15 dB")

    # 5. Plot comparison
    print("\n5. Creating visualization...")
    plot_comparison(clean, noisy, enhanced, sample_rate=sample_rate)

    # 6. Show model info
    print("\n6. Model Information:")
    print(f" Parameters: {model.count_params():,}")
    print(f" Layers: {len(model.layers)}")
    print(f" Input shape: {model.input_shape}")
    print(f" Output shape: {model.output_shape}")

    # 7. Build stateful models
    print("\n7. Building stateful models for real-time inference...")
    stage1, stage2 = dtln.build_stateful_model()
    print(f" β Stage 1 parameters: {stage1.count_params():,}")
    print(f" β Stage 2 parameters: {stage2.count_params():,}")

    print("\n" + "="*60)
    print("β Example complete!")
    print("\nGenerated files:")
    print(" - test_clean.wav")
    print(" - test_noisy.wav")
    print(" - test_enhanced.wav")
    print(" - denoising_comparison.png")
    print("\nNext steps:")
    print(" 1. Train the model: python train_dtln.py --help")
    print(" 2. Convert to TFLite: python convert_to_tflite.py --help")
    print(" 3. Deploy to Alif E7")
    print("="*60)
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == '__main__':
    main()