Spaces:
Running
on
Zero
Running
on
Zero
| """ Test the speed of the augmentation """ | |
| import torch | |
| import torchaudio | |
# Device
# Benchmark on the GPU when one is present; otherwise fall back to the CPU so
# the script still runs end to end on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force the CPU measurements

# Input size shared by both input options below.
sample_rate = 16000   # Hz
slice_length = 32767  # samples per example (about 2.05 seconds at 16kHz)
n_slices = 80         # number of examples in the batch

# Music
# Alternative input: cut a real recording into equally sized slices.
# x, _ = torchaudio.load("music.wav")
# slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)]
# x = torch.stack(slices)  # (80, 32767)

# Sine wave
# The time axis is derived from the sample count, so its length always agrees
# with the reshape below instead of relying on a hand-tuned end time.
t = torch.arange(slice_length, dtype=torch.float32) / sample_rate
x = torch.sin(2 * torch.pi * 440 * t) * 0.5  # 440 Hz tone, peak amplitude 0.5
# Repeat the same tone for every example: (n_slices, 1, slice_length).
x = x.reshape(1, 1, slice_length).tile(n_slices, 1, 1)
x = x.to(device)
| ############################################################################################ | |
| # torch-audiomentations: https://github.com/asteroid-team/torch-audiomentations | |
| # | |
| # process time <CPU>: 1.18 s ± 5.35 ms | |
| # process time <GPU>: 58 ms | |
| # GPU memory usage: 3.8 GB per 1 semitone | |
| ############################################################################################ | |
| import torch | |
| from torch_audiomentations import Compose, PitchShift, Gain, PolarityInversion | |
# Pipeline under test: a single PitchShift transform that is always applied
# (p=1.0) and may transpose by anything from 0 to 2.2 semitones.
apply_augmentation = Compose(
    transforms=[
        # Additional transforms, currently disabled for this measurement:
        # Gain(
        #     min_gain_in_db=-15.0,
        #     max_gain_in_db=5.0,
        #     p=0.5,
        # ),
        # PolarityInversion(p=0.5)
        PitchShift(
            min_transpose_semitones=0,
            max_transpose_semitones=2.2,
            mode="per_batch",  # alternative: "per_example"
            p=1.0,
            p_mode="per_batch",
            sample_rate=16000,
            target_rate=16000,
        ),
    ]
)

# Run the pipeline once on the prepared batch `x`.
x_am = apply_augmentation(x, sample_rate=16000)
| ############################################################################################ | |
| # torchaudio: | |
| # | |
| # process time <CPU>: 4.01 s ± 19.6 ms per loop | |
| # process time <GPU>: 25.1 ms ± 161 µs per loop | |
| # memory usage <GPU>: 1.2 (growth to 5.49) GB per 1 semitone | |
| ############################################################################################ | |
| from torchaudio import transforms | |
# torchaudio reference: a fixed +2 semitone shift at 16 kHz.
ta_transform = transforms.PitchShift(sample_rate=16000, n_steps=2)
# Move the transform to the benchmark device before running it.
ta_transform = ta_transform.to(device)
x_ta = ta_transform(x)
| ############################################################################################ | |
| # YourMT3 pitch_shift_layer: | |
| # | |
| # process time <CPU>: 389ms ± 22ms, (stretch=143 ms, resampler=245 ms) | |
| # process time <GPU>: 7.18 ms ± 17.3 µs (stretch=6.47 ms, resampler=0.71 ms) | |
| # memory usage: 16 MB per 1 semitone (average) | |
| ############################################################################################ | |
| from model.pitchshift_layer import PitchShiftLayer | |
# YourMT3 pitch-shift layer, constructed for a shift range of [2, 2] at fs=16000.
ps_ymt3 = PitchShiftLayer(
    pshift_range=[2, 2],
    fs=16000,
    min_gcd=16,
    n_fft=2048,
)
# Move the layer to the benchmark device before running it.
ps_ymt3 = ps_ymt3.to(device)
# Second positional argument is 2, matching the configured shift range.
x_ymt3 = ps_ymt3(x, 2)
| ############################################################################################ | |
| # Plot 1: Comparison of Process Time and GPU Memory Usage for 3 Pitch Shifting Methods | |
| ############################################################################################ | |
| import matplotlib.pyplot as plt | |
# Bar-chart comparison of the three pitch-shifting methods, using the numbers
# recorded in the per-method comment banners earlier in this script.

# Model names
models = ['torch-audiomentation', 'torchaudio', 'YourMT3:PitchShiftLayer']
# Process time (CPU) in seconds
cpu_time = [1.18, 4.01, 0.389]
# Process time (GPU) in milliseconds
gpu_time = [58, 25.1, 7.18]
# GPU memory usage in GB
gpu_memory = [3.8, 5.49, 0.016]

# One colour per model, shared by all three panels.
bar_colors = ['#FFB6C1', '#ADD8E6', '#98FB98']

# Creating subplots
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Creating bar charts
bar1 = axs[0].bar(models, cpu_time, color=bar_colors)
bar2 = axs[1].bar(models, gpu_time, color=bar_colors)
bar3 = axs[2].bar(models, gpu_memory, color=bar_colors)

# Adding labels and titles
axs[0].set_ylabel('Time (s)')
axs[0].set_title('Process Time (CPU) bsz=80')
axs[1].set_ylabel('Time (ms)')
axs[1].set_title('Process Time (GPU) bsz=80')
axs[2].set_ylabel('Memory (GB)')
axs[2].set_title('GPU Memory Usage per semitone')

# Adding grid for better readability of the plots
for ax in axs:
    ax.grid(axis='y')
    ax.set_yscale('log')
    # Style the tick labels that bar() already created. Calling
    # set_xticklabels() without fixing the tick positions first makes
    # Matplotlib warn that the labels may not match the ticks.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

# Adding text labels above the bars. Each panel has its own values and its
# own number format / unit.
panels = zip(
    axs,
    (bar1, bar2, bar3),
    (cpu_time, gpu_time, gpu_memory),
    ('{:.2f} s', '{:.2f} ms', '{:.3f} GB'),
)
for ax, bars, values, label_format in panels:
    for rect, value in zip(bars, values):
        ax.text(
            rect.get_x() + rect.get_width() / 2,  # horizontal centre of the bar
            rect.get_height(),                    # top of the bar
            label_format.format(value),
            ha='center',
            va='bottom',
        )

plt.tight_layout()
plt.show()
| ############################################################################################ | |
| # Plot 2: Stretch and Resampler Processing Time Contribution | |
| ############################################################################################ | |
# Breakdown of the YourMT3 pitch-shift time into its two stages, CPU (left
# panel) versus GPU (right panel).

# Data
processing_type = ['Stretch (Phase Vocoder)', 'Resampler (Conv1D)']
cpu_times = [143, 245]  # [Stretch, Resampler] times for CPU in milliseconds
gpu_times = [6.47, 0.71]  # [Stretch, Resampler] times for GPU in milliseconds

# Creating subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Plotting bar charts
stage_colors = ['#ADD8E6', '#98FB98']
axs[0].bar(processing_type, cpu_times, color=stage_colors)
axs[1].bar(processing_type, gpu_times, color=stage_colors)

# Adding labels and titles. The panels do not share a y-axis, so each one
# carries its own unit label.
axs[0].set_ylabel('Time (ms)')
axs[1].set_ylabel('Time (ms)')
axs[0].set_title('Contribution of CPU Processing Time: YMT3-PS (BSZ=80)')
axs[1].set_title('Contribution of GPU Processing Time: YMT3-PS (BSZ=80)')

# Adding grid for better readability of the plots
for ax in axs:
    ax.grid(axis='y')
    ax.set_yscale('log')  # Log scale to better visualize the smaller values

# Adding values on top of the bars. Bars on a categorical axis sit at
# x = 0, 1, ..., so the enumerate index is the bar's x position.
for ax, times in zip(axs, [cpu_times, gpu_times]):
    # `elapsed_ms` rather than `time`, so the loop variable left behind at
    # module scope does not hide the standard-library `time` module name.
    for idx, elapsed_ms in enumerate(times):
        ax.text(idx, elapsed_ms, f"{elapsed_ms:.2f} ms", ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()