# Visualizes pairwise language distances as a force-directed graph.
# (Notebook-export banner "Spaces: Running on T4" removed.)
import os
import pickle

import matplotlib.pyplot as plt
import networkx as nx
import torch
from tqdm import tqdm

from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path

# --- Configuration -------------------------------------------------------
# Available distance measures between languages and available plot modes.
distance_types = ["tree", "asp", "map", "learned", "l1"]
modes = ["plot_all", "plot_neighbors"]
# Full name of the focus language whose nearest neighbors are highlighted
# in "plot_neighbors" mode.
neighbor = "Latin"
# How many nearest neighbors of the focus language to display.
num_neighbors = 12
distance_type = distance_types[0]  # switch here
mode = modes[1]
# Only pairs closer than this (after 0-1 normalization) get an edge in
# "plot_all" mode.
edge_threshold = 0.01
# TODO histograms to figure out a good threshold
# Directory containing the cached lookup files (JSON / pickle).
cache_root = "."
# ISO codes of all languages that were seen with supervision during training
# (used later to filter which languages appear in the plot).
supervised_iso_codes = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))

if distance_type == "l1":
    # Distance in the learned language-embedding space of a trained model.
    # NOTE(review): despite the "l1" name this uses mse_loss (mean squared
    # error, i.e. squared L2) — confirm which metric is actually intended.
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    model_path = "../../Models/ToucanTTS_Meta/best.pt"
    checkpoint = torch.load(model_path, map_location='cpu')
    # Only the language-embedding table of the model is needed.
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)
    l1_dist = dict()
    seen_langs = set()
    for lang_1 in supervised_iso_codes:
        if lang_1 not in seen_langs:
            seen_langs.add(lang_1)
            l1_dist[lang_1] = dict()
            for lang_2 in supervised_iso_codes:
                if lang_2 not in seen_langs:  # it's symmetric
                    l1_dist[lang_1][lang_2] = torch.nn.functional.mse_loss(embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze(), embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze())
    # Normalize all distances into [0, 1] by dividing by the global maximum.
    largest_value_l1_dist = 0.0
    for _, values in l1_dist.items():
        for _, value in values.items():
            largest_value_l1_dist = max(largest_value_l1_dist, value)
    for key1 in l1_dist:
        for key2 in l1_dist[key1]:
            l1_dist[key1][key2] = l1_dist[key1][key2] / largest_value_l1_dist
    distance_measure = l1_dist
if distance_type == "tree":
    # Distance along the phylogenetic language-family tree
    # (cache is assumed to be pre-normalized — no scaling applied here).
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    tree_dist = load_json_from_path(tree_lookup_path)
    distance_measure = tree_dist
if distance_type == "map":
    # Geographic distance between the languages' locations, scaled to [0, 1].
    map_dist = load_json_from_path(os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json"))
    # Global maximum over the nested dict; the 0.0 floor mirrors the
    # starting value of a running max.
    all_map_values = [value for inner in map_dist.values() for value in inner.values()]
    largest_value_map_dist = max([0.0, *all_map_values])
    # Scale every entry in place by the global maximum.
    for inner in map_dist.values():
        for lang in inner:
            inner[lang] = inner[lang] / largest_value_map_dist
    distance_measure = map_dist
if distance_type == "learned":
    # Distances predicted by the trained distance-metric model.
    # BUG FIX: this branch previously loaded the *map* distance cache
    # ("lang_1_to_lang_2_to_map_dist.json"), so selecting "learned"
    # silently plotted geographic distances. Load the learned-distance
    # cache instead (filename assumed to follow the pattern of the other
    # caches — confirm against the cache-creation script).
    learned_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_learned_dist.json")
    learned_dist = load_json_from_path(learned_lookup_path)
    # Normalize all distances into [0, 1] by dividing by the global maximum.
    largest_value_learned_dist = 0.0
    for _, values in learned_dist.items():
        for _, value in values.items():
            largest_value_learned_dist = max(largest_value_learned_dist, value)
    for key1 in learned_dist:
        for key2 in learned_dist[key1]:
            learned_dist[key1][key2] = learned_dist[key1][key2] / largest_value_learned_dist
    distance_measure = learned_dist
if distance_type == "asp":
    # Precomputed ASP similarity cache: maps each language to a vector of
    # similarities to every language in key order. Convert similarity to
    # distance via (1 - similarity), keeping only one direction per pair.
    with open(os.path.join(cache_root, "asp_dict.pkl"), 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    lang_list = list(asp_sim.keys())
    asp_dist = dict()
    seen_langs = set()
    for lang_1 in lang_list:
        if lang_1 in seen_langs:
            continue
        seen_langs.add(lang_1)
        asp_dist[lang_1] = dict()
        for index, lang_2 in enumerate(lang_list):
            if lang_2 not in seen_langs:  # it's symmetric
                asp_dist[lang_1][lang_2] = 1 - asp_sim[lang_1][index]
    distance_measure = asp_dist
# Lookup from ISO code to human-readable full language name.
iso_codes_to_names = load_json_from_path(os.path.join(cache_root, "iso_to_fullname.json"))

# Collect (fullname_1, fullname_2, distance) triples, keeping only pairs where
# both languages are supervised or are the focus language itself.
distances = list()
for lang_1, inner in distance_measure.items():
    if lang_1 not in iso_codes_to_names:
        continue
    if not (lang_1 in supervised_iso_codes or iso_codes_to_names[lang_1] == neighbor):
        continue
    name_1 = iso_codes_to_names[lang_1]
    for lang_2, dist in inner.items():
        try:
            # A code that is neither supervised nor nameable is skipped.
            keep = lang_2 in supervised_iso_codes or iso_codes_to_names[lang_2] == neighbor
        except KeyError:
            continue
        if keep:
            distances.append((name_1, iso_codes_to_names[lang_2], dist))
# Create a graph
G = nx.Graph()

# Min-max normalize all distances into [0, 1] so edge weights are comparable
# across the different distance types.
if not distances:
    # FIX: min()/max() on an empty list would raise an opaque ValueError.
    raise ValueError("No language pairs survived filtering; cannot build a graph.")
min_dist = min(d for _, _, d in distances)
max_dist = max(d for _, _, d in distances)
dist_span = max_dist - min_dist
if dist_span == 0:
    # FIX: all distances identical — avoid division by zero; everything
    # normalizes to 0 (maximally close).
    normalized_distances = [(entity1, entity2, 0.0) for entity1, entity2, _ in distances]
else:
    normalized_distances = [(entity1, entity2, (d - min_dist) / dist_span) for entity1, entity2, d in distances]
if mode == "plot_neighbors":
    # `distances` uses full names, so convert the supervised ISO codes to
    # full names (the focus language is already a full name).
    fullnames = list()
    fullnames.append(neighbor)
    for code in supervised_iso_codes:
        fullnames.append(iso_codes_to_names[code])
    supervised_iso_codes = fullnames

    # Collect the distances from the focus language to every other
    # supervised language.
    d_dist = list()
    for entity1, entity2, d in tqdm(normalized_distances):
        if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2:
                d_dist.append(d)
    # The (num_neighbors+1)-smallest distance serves as the cutoff.
    # NOTE(review): raises IndexError if fewer than num_neighbors+1 pairs
    # involve the focus language — confirm the caches always provide enough.
    thresh = sorted(d_dist)[num_neighbors]
    # distance_scores = sorted(d_dist)[:num_neighbors]

    # All languages that fall under the cutoff w.r.t. the focus language.
    neighbors = list()
    for entity1, entity2, d in tqdm(normalized_distances):
        if (d < thresh and (neighbor == entity2 or neighbor == entity1)) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            neighbors.append(entity1)
            neighbors.append(entity2)
    unique_neighbors = list(set(neighbors))
    unique_neighbors.remove(neighbor)

    # Edges between the focus language and its neighbors: smaller distance
    # -> much stronger spring -> drawn closer in the layout.
    for entity1, entity2, d in tqdm(normalized_distances):
        if (neighbor == entity2 or neighbor == entity1) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2 and d < thresh:
                spring_tension = ((thresh - d) ** 2) * 20000  # for vis purposes
                print(f"{d}-->{spring_tension}")
                G.add_edge(entity1, entity2, weight=spring_tension)
    # Edges among the neighbors themselves (much weaker springs).
    for entity1, entity2, d in tqdm(normalized_distances):
        if (entity2 in unique_neighbors and entity1 in unique_neighbors) and (entity1 in supervised_iso_codes and entity2 in supervised_iso_codes):
            if entity1 != entity2:
                spring_tension = 1 - d
                G.add_edge(entity1, entity2, weight=spring_tension)

    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    edges = G.edges(data=True)
    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    # Draw edges with labels; edges touching the focus language are red.
    edges_connected_to_specific_node = [(u, v) for u, v in G.edges() if u == neighbor or v == neighbor]
    # nx.draw_networkx_edges(G, pos, alpha=0.1)
    nx.draw_networkx_edges(G, pos, edgelist=edges_connected_to_specific_node, edge_color='red', alpha=0.3, width=3)
    for u, v, d in edges:
        if u == neighbor or v == neighbor:
            # Label focus edges with (an approximation of) the original
            # distance by inverting the spring-tension transformation above.
            nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): round((thresh - (d['weight'] / 20000) ** (1 / 2)) * 10, 2)}, font_color="red", alpha=0.3)  # reverse modifications
        else:
            pass
            # nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['weight']})
    # Draw node labels; the focus language is re-labeled in red on top.
    nx.draw_networkx_labels(G, pos, font_size=14, font_family='sans-serif', font_color='green')
    nx.draw_networkx_labels(G, pos, labels={neighbor: neighbor}, font_size=14, font_family='sans-serif', font_color='red')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.savefig("avg.png", dpi=300)
    plt.show()

elif mode == "plot_all":
    # Connect every sufficiently close pair; closer -> stronger spring.
    for entity1, entity2, d in tqdm(normalized_distances):
        if d < edge_threshold and entity1 != entity2:
            spring_tension = edge_threshold - d
            G.add_edge(entity1, entity2, weight=spring_tension)
    # Draw the graph
    pos = nx.spring_layout(G, weight="weight")  # Positions for all nodes
    edges = G.edges(data=True)
    # Draw nodes
    nx.draw_networkx_nodes(G, pos, node_size=1, alpha=0.01)
    # Draw edges with labels
    nx.draw_networkx_edges(G, pos, alpha=0.1, edge_color="blue")
    # nx.draw_networkx_edge_labels(G, pos, edge_labels={(u, v): d['weight'] for u, v, d in edges})
    # Draw node labels
    nx.draw_networkx_labels(G, pos, font_size=10, font_family='sans-serif')
    plt.title(f'Graph of {distance_type} Distances')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.tight_layout(pad=0)
    plt.show()