Spaces:
Running
Running
| import matplotlib | |
| import numpy | |
| import soundfile as sf | |
| from matplotlib import pyplot as plt | |
| from matplotlib import cm | |
| matplotlib.use("tkAgg") | |
| from sklearn.manifold import TSNE | |
| from sklearn.decomposition import PCA | |
| from tqdm import tqdm | |
| from Preprocessing.ProsodicConditionExtractor import ProsodicConditionExtractor | |
class Visualizer:
    """Plot 2-D projections of prosodic condition embeddings extracted from audio files.

    Each audio file is turned into one embedding by a ProsodicConditionExtractor;
    the embeddings are projected to 2-D with t-SNE and, optionally, PCA, and drawn
    as a scatter plot coloured by label.
    """

    def __init__(self, sr=48000, device="cpu"):
        """
        Args:
            sr: The sampling rate of the audios you want to visualize.
            device: Device handed to the ProsodicConditionExtractor. It is kept so
                that an extractor rebuilt for a different sampling rate is created
                on the same device.
        """
        self.tsne = TSNE(n_jobs=-1)
        self.pca = PCA(n_components=2)
        self.device = device
        self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=device)
        self.sr = sr

    def visualize_speaker_embeddings(self, label_to_filepaths, title_of_plot, save_file_path=None, include_pca=True, legend=True):
        """Embed every audio file, project the embeddings to 2-D and plot them.

        Args:
            label_to_filepaths: Mapping from a label (used for colouring and the
                legend) to an iterable of audio file paths readable by soundfile.
            title_of_plot: Base title of the plot(s). " t-SNE" / " PCA" is appended
                when both projections are plotted.
            save_file_path: If given, the t-SNE plot is written to this path and the
                PCA plot to the same path with "_pca" inserted before the extension.
                If None, the plots are shown interactively instead.
            include_pca: Whether to additionally plot a PCA projection.
            legend: Whether to draw a legend.

        Raises:
            ValueError: If no audio file was at least one second long, so there is
                nothing to project.
        """
        label_list = list()
        embedding_list = list()
        for label in tqdm(label_to_filepaths):
            for filepath in tqdm(label_to_filepaths[label]):
                wave, sr = sf.read(filepath)
                if len(wave) / sr < 1:
                    # clips shorter than one second are skipped entirely
                    continue
                if self.sr != sr:
                    print("One of the Audios you included doesn't match the sampling rate of this visualizer object, "
                          "creating a new condition extractor. Results will be correct, but if there are too many cases "
                          "of changing samplingrate, this will run very slowly.")
                    # keep the device chosen in __init__ instead of falling back to the extractor's default
                    self.pros_cond_ext = ProsodicConditionExtractor(sr=sr, device=self.device)
                    self.sr = sr
                # NOTE(review): `.numpy()` presumably requires a CPU tensor; confirm behaviour for non-CPU devices
                embedding_list.append(self.pros_cond_ext.extract_condition_from_reference_wave(wave).squeeze().numpy())
                label_list.append(label)

        if not embedding_list:
            raise ValueError("None of the provided audio files is at least one second long, nothing to visualize.")
        embeddings_as_array = numpy.array(embedding_list)

        dimensionality_reduced_embeddings_tsne = self.tsne.fit_transform(embeddings_as_array)
        self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_tsne,
                              labels=label_list,
                              title=(title_of_plot + " t-SNE") if include_pca else title_of_plot,
                              save_file_path=save_file_path,
                              legend=legend)

        if include_pca:
            dimensionality_reduced_embeddings_pca = self.pca.fit_transform(embeddings_as_array)
            # write to a separate file so the PCA figure does not overwrite the t-SNE figure
            self._plot_embeddings(projected_data=dimensionality_reduced_embeddings_pca,
                                  labels=label_list,
                                  title=title_of_plot + " PCA",
                                  save_file_path=self._pca_save_path(save_file_path),
                                  legend=legend)

    @staticmethod
    def _pca_save_path(save_file_path):
        """Return save_file_path with "_pca" inserted before its extension (None stays None)."""
        import os

        if save_file_path is None:
            return None
        root, extension = os.path.splitext(save_file_path)
        return f"{root}_pca{extension}"

    @staticmethod
    def _group_points_by_label(projected_data, labels):
        """Group projected points by label.

        Labels are kept in order of first appearance, which makes the colour
        assignment reproducible across runs.

        Returns:
            dict mapping each distinct label to a pair (x, y) of numpy arrays holding
            the first and second coordinate of that label's points.
        """
        points_by_label = dict()
        for point, label in zip(projected_data, labels):
            xs, ys = points_by_label.setdefault(label, (list(), list()))
            xs.append(point[0])
            ys.append(point[1])
        return {label: (numpy.array(xs), numpy.array(ys)) for label, (xs, ys) in points_by_label.items()}

    def _plot_embeddings(self, projected_data, labels, title, save_file_path, legend):
        """Draw one scatter series per label and either save or show the figure.

        Args:
            projected_data: Sequence of points whose first two entries are the x/y coordinates.
            labels: One label per point, aligned with projected_data.
            title: Title of the axes.
            save_file_path: Destination file; if None the figure is shown with plt.show().
            legend: Whether to draw a legend.
        """
        points_by_label = self._group_points_by_label(projected_data, labels)
        colors = cm.gist_rainbow(numpy.linspace(0, 1, len(points_by_label)))

        fig, ax = plt.subplots()
        for color, (label, (x, y)) in zip(colors, points_by_label.items()):
            # `color=` rather than `c=`: a single RGBA row passed as `c` is treated as
            # values to colormap whenever a label happens to have exactly 4 points.
            ax.scatter(x=x,
                       y=y,
                       color=color,
                       label=label,
                       alpha=0.9)
        if legend:
            ax.legend()
        fig.tight_layout()
        ax.axis('off')
        fig.subplots_adjust(top=0.9, bottom=0.0, right=1.0, left=0.0)
        ax.set_title(title)
        if save_file_path is not None:
            fig.savefig(save_file_path)
        else:
            plt.show()
        plt.close(fig)