Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| # p = os.path.split(os.path.dirname(os.path.abspath(__file__)))[0] | |
| # sys.path.append(p) | |
| import argparse | |
| import numpy as np | |
| import tensorflow as tf | |
| tf.compat.v1.disable_v2_behavior() | |
| from ..utils.hparams import HParams | |
| from ..models import get_model | |
| import torch | |
# Number of rows in every set handed to the model: the observed rows plus the
# zero-padded rows that the model is asked to fill in.
set_size = 200
# Matching sets with fewer rows than this are used whole; larger ones are
# randomly subsampled to exactly this many observed rows.
threshold = 100
def fast_cosine_dist(source_feats, matching_pool):
    """Return the pairwise cosine distance between two sets of row vectors.

    The dot products are recovered from Euclidean distances through the
    identity ``a . b = (|a|^2 + |b|^2 - |a - b|^2) / 2``, so the pairwise work
    is done by one ``torch.cdist`` call.

    Args:
        source_feats: Tensor whose rows are the query vectors (expected 2-D).
        matching_pool: Tensor whose rows are the candidate vectors, with the
            same number of columns as ``source_feats``.

    Returns:
        Tensor of shape ``(len(source_feats), len(matching_pool))`` holding
        ``1 - cosine_similarity`` for every query/candidate pair. There is no
        guard against zero-norm rows; they divide by zero.
    """
    query_norms = torch.norm(source_feats, p=2, dim=-1)
    pool_norms = torch.norm(matching_pool, p=2, dim=-1)

    # cdist is given a leading batch axis of size one, which is dropped again.
    pairwise_l2 = torch.cdist(source_feats[None], matching_pool[None], p=2)[0]

    dot = (-pairwise_l2 ** 2 + query_norms[:, None] ** 2 + pool_norms[None] ** 2) / 2
    norm_products = query_norms[:, None] * pool_norms[None]
    return 1 - dot / norm_products
def evaluate(batch, model):
    """Run the model's sampling op on one batch and return its output.

    Args:
        batch: Feed passed through to the model unchanged; the callers in this
            file pass a dict with the keys ``'x'``, ``'b'`` and ``'m'``.
        model: Object exposing ``execute(op, batch)`` and a ``sample``
            attribute.

    Returns:
        Whatever ``model.execute(model.sample, batch)`` returns.
    """
    return model.execute(model.sample, batch)
def prematch(path, expanded, topk=None):
    """Replace every source frame by the mean of its nearest candidate frames.

    The candidate pool is made of all other ``.pt`` files found (recursively)
    under the folder that contains ``path``, followed by the rows of
    ``expanded``. Nearness is measured with ``fast_cosine_dist``.

    Args:
        path: ``pathlib.Path`` of the source utterance's feature file.
        expanded: Extra candidate rows (anything ``torch.tensor`` accepts,
            e.g. a NumPy array) appended to the pool.
        topk: Number of nearest candidates averaged per source frame. When
            omitted, the command-line value ``args.topk`` is used, which only
            exists when this file is run as a script.

    Returns:
        Tensor with one row per source frame: the mean of that frame's
        ``topk`` nearest candidates.

    Raises:
        RuntimeError: from ``Tensor.topk`` if the pool holds fewer than
            ``topk`` rows.
    """
    if topk is None:
        topk = args.topk

    # rglob() already recurses, so '*.pt' is enough. Filtering (instead of
    # list.remove) drops the source file without failing if it is not listed.
    sibling_paths = sorted(p for p in path.parent.rglob('*.pt') if p != path)

    # One concatenation over everything, so a folder with no other utterance
    # still works: the pool is then just the expanded rows.
    pool = [torch.load(p) for p in sibling_paths]
    pool.append(torch.tensor(expanded))
    candidates = torch.cat(pool, 0)

    source_feats = torch.load(path).to(torch.float32)
    dists = fast_cosine_dist(source_feats.cpu(), candidates.cpu())
    best = dists.topk(k=topk, dim=-1, largest=False)  # (src_len, topk)
    out_feats = candidates[best.indices].mean(dim=1)  # (src_len, dim)
    return out_feats
def single_expand(path, model, num_samples, seed=1234, out_path=None):
    """Generate ``num_samples`` new feature rows from a stored matching set.

    The matching set is loaded from ``path``, divided by 10, and fed to the
    model in sets of ``set_size`` rows: the observed rows first, zero padding
    after. The rows the model returns at the padded positions are multiplied
    by 10 and collected until enough have been produced.

    Args:
        path: File holding the matching-set tensor (loaded with ``torch.load``
            onto the CPU).
        model: Model object passed on to ``evaluate``.
        num_samples: Number of rows to return; must be positive.
        seed: Seed applied to NumPy and TensorFlow before sampling.
        out_path: Optional destination for ``np.save``. Note that ``np.save``
            appends ``.npy`` when the name does not already end with it.

    Returns:
        NumPy array with ``num_samples`` rows.

    Raises:
        ValueError: from ``np.concatenate`` if ``num_samples`` is not
            positive, because no batch is generated.
    """
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    matching_set = torch.load(path, map_location=torch.device('cpu')).numpy()
    matching_set = matching_set / 10
    matching_size = matching_set.shape[0]

    # Small sets are used whole; large ones are subsampled to `threshold` rows.
    use_whole_set = matching_size < threshold
    num_observed = matching_size if use_whole_set else threshold
    num_new_samples = set_size - num_observed
    padded_data = np.zeros((num_new_samples, matching_set.shape[1]))

    new_samples = []
    cur_num_samples = 0
    while cur_num_samples < num_samples:
        if use_whole_set:
            obs_data = matching_set
        else:
            ind = np.random.choice(matching_size, threshold, replace=False)
            obs_data = matching_set[ind]

        batch = {
            'x': np.concatenate([obs_data, padded_data], 0)[None, ...],
            # 1 marks an observed row, 0 a padded row.
            'b': np.concatenate(
                [np.ones_like(obs_data), np.zeros_like(padded_data)], 0
            )[None, ...],
        }
        batch['m'] = np.ones_like(batch['b'])

        sample = evaluate(batch, model)
        # The generated rows start right after the observed ones. The original
        # large-set branch sliced from `num_new_samples`, which only matches
        # while set_size - threshold == threshold.
        new_samples.append(sample[0, num_observed:] * 10)
        cur_num_samples += num_new_samples

    new_samples = np.concatenate(new_samples, 0)[:num_samples]

    if out_path:
        # os.makedirs('') raises, so skip it for a bare file name such as the
        # command-line default.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        np.save(out_path, new_samples)
    return new_samples
def single_expand_fast(path):
    """Generate ``args.num_samples`` new feature rows in a single model call.

    Batched counterpart of ``single_expand``: enough sets are stacked into one
    batch to cover the requested number of rows, the model is evaluated once,
    and the rows returned at the padded positions are multiplied by 10 and
    flattened.

    Relies on the module globals ``args`` (for ``num_samples``) and ``model``,
    which are only defined when this file is run as a script. Unlike
    ``single_expand`` it does not seed any random generator.

    Args:
        path: File holding the matching-set tensor (loaded with
            ``torch.load``).

    Returns:
        NumPy array with at most ``args.num_samples`` rows.
    """
    matching_set = torch.load(path).cpu().numpy()
    matching_set = matching_set / 10
    matching_size = matching_set.shape[0]

    # Small sets are used whole; large ones are subsampled to `threshold` rows.
    use_whole_set = matching_size < threshold
    num_observed = matching_size if use_whole_set else threshold
    num_new_samples = set_size - num_observed

    # True division, so the count is rounded *up* (the original floor-divided
    # first, giving 0 or too few sets). At least one set is always drawn.
    batch_size = max(1, int(np.ceil(args.num_samples / num_new_samples)))

    padded_data = np.zeros((num_new_samples, matching_set.shape[1]))
    if use_whole_set:
        one_set = np.concatenate([matching_set, padded_data], axis=0)[None, ...]
        # Tile the data itself; the original tiled the mask into 'x'.
        x = np.tile(one_set, (batch_size, 1, 1))
    else:
        sets = []
        for _ in range(batch_size):
            ind = np.random.choice(matching_size, threshold, replace=False)
            sets.append(np.concatenate([matching_set[ind], padded_data], 0)[None, ...])
        x = np.concatenate(sets, 0)

    # 1 marks an observed row, 0 a padded row.
    mask = np.zeros_like(x)
    mask[:, :num_observed] = 1

    batch = {'x': x, 'b': mask, 'm': np.ones_like(mask)}
    sample = evaluate(batch, model)

    # Generated rows start right after the observed ones in both cases; the
    # original large-set branch sliced from `matching_size` instead.
    new_samples = sample[:, num_observed:, :] * 10
    new_samples = new_samples.reshape((-1, new_samples.shape[-1]))
    return new_samples[:args.num_samples]
if __name__ == '__main__':
    # Command-line entry point: load the model, then expand one matching set.
    # NOTE: `args` and `model` are deliberately left as module-level names;
    # other functions in this file read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', type=str)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--path', type=str, default="matching_set.pt")
    parser.add_argument('--out_path', type=str, default="expanded_set.pt")
    parser.add_argument('--topk', type=int, default=4)
    args = parser.parse_args()

    # CUDA reads this variable once, when it is initialised, so it has to be
    # set before the model is built. It used to be set after model.load().
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    params = HParams(args.cfg_file)

    t0 = time.time()
    model = get_model(params)
    model.load()
    t1 = time.time()
    print(f"{t1-t0:.2f}s to load the model")

    path = args.path
    # Only a single `.pt` file is handled; any other path is ignored.
    if path.endswith(".pt"):
        t0 = time.time()
        expanded = single_expand(path, model, args.num_samples, args.seed, args.out_path)
        t1 = time.time()
        print(f"{t1-t0:.2f}s to expand the set")