Spaces:
Running
Running
| import argparse | |
| import gc | |
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from fastprogress.fastprogress import master_bar, progress_bar | |
| from torch import Tensor | |
| from hubconf import wavlm_large | |
# Number of waveform samples per WavLM feature frame; waveforms are padded to a
# multiple of this before feature extraction (see get_full_features).
DOWNSAMPLE_FACTOR = 320

# Module-level caches mapping utterance Path -> float16 CPU feature tensor.
# `feature_cache` holds matching-layer features, `synthesis_cache` holds
# synthesis-layer features; both are periodically cleared in extract().
# NOTE: the original `global` declarations at module scope were no-ops and
# have been removed — module-level assignment already creates globals.
feature_cache = {}
synthesis_cache = {}
def make_librispeech_df(root_path: Path) -> pd.DataFrame:
    """Index the LibriSpeech subsets under `root_path` into a dataframe.

    Scans the 'train-clean-100' and 'dev-clean' subsets for .flac utterances
    and derives each speaker id from the filename stem (LibriSpeech files are
    named `<speaker>-<chapter>-<utterance>.flac`).

    Args:
        root_path: root of a LibriSpeech download containing the subset folders.

    Returns:
        DataFrame with columns 'path' (pathlib.Path of each utterance) and
        'speaker' (speaker id prefixed with 'ls-').
    """
    all_files = []
    folders = ['train-clean-100', 'dev-clean']
    print(f"[LIBRISPEECH] Computing folders {folders}")
    for f in folders:
        # rglob() is already recursive; the original '**/*.flac' argument
        # expanded to '**/**/*.flac', which can match nested files more than
        # once. sorted() also makes row order filesystem-independent.
        all_files.extend(sorted((root_path / f).rglob('*.flac')))
    speakers = ['ls-' + f.stem.split('-')[0] for f in all_files]
    df = pd.DataFrame({'path': all_files, 'speaker': speakers})
    return df
def main(args):
    """Entry point: extract (optionally pre-matched) WavLM features for LibriSpeech.

    Builds per-layer one-hot weighting vectors for the matching and synthesis
    layers, loads WavLM, seeds numpy/torch, and runs `extract` over the dataset.
    """
    device = torch.device(args.device)

    def layer_one_hot(layer):
        # One-hot over the 25 WavLM layer outputs, shaped (25, 1) for
        # broadcasting against (n_layers, seq_len, dim) feature stacks.
        return F.one_hot(torch.tensor(layer), num_classes=25).float().to(device)[:, None]

    SYNTH_WEIGHTINGS = layer_one_hot(args.synthesis_layer)
    MATCH_WEIGHTINGS = layer_one_hot(args.matching_layer)
    print(f"Matching weightings: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weightings: {SYNTH_WEIGHTINGS.squeeze()}")

    ls_df = make_librispeech_df(Path(args.librispeech_path))

    print(f"Loading wavlm.")
    wavlm = wavlm_large(pretrained=True, progress=True, device=args.device)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
    print("All done!", flush=True)
def path2pools(path: Path, wavlm: nn.Module, match_weights: Tensor, synth_weights: Tensor, device):
    """Given a waveform `path`, compute the matching and synthesis pools.

    Pools are built from every *other* .flac utterance in `path`'s parent
    directory (in LibriSpeech layout: the same speaker+chapter). For each
    utterance the per-layer WavLM features are collapsed to one (seq_len, dim)
    sequence via the layer-weight vectors, then all sequences are concatenated.

    Fix: the annotation was `nn.Module()`, which *instantiated* a throwaway
    module when the `def` statement ran — the class itself is the annotation.

    Args:
        path: utterance whose pools are requested (excluded from the pools).
        wavlm: WavLM model passed through to `get_full_features`.
        match_weights: per-layer weights for the matching representation;
            assumed shape (n_layers, 1) — TODO confirm against caller.
        synth_weights: per-layer weights for the synthesis representation.
        device: device features are extracted on.

    Returns:
        (matching_pool, synth_pool): two (N, dim) float32 CPU tensors.
    """
    # rglob() is already recursive; '**/*.flac' was redundant (see
    # make_librispeech_df). sorted() keeps pool order deterministic.
    uttrs_from_same_spk = sorted(path.parent.rglob('*.flac'))
    uttrs_from_same_spk.remove(path)
    matching_pool = []
    synth_pool = []
    for pth in uttrs_from_same_spk:
        if pth in feature_cache and pth in synthesis_cache:
            # Cached entries are stored float16 on CPU; restore to float32.
            # NOTE(review): the fp16 round-trip means cache hits differ very
            # slightly from freshly computed features — pre-existing behavior.
            matching_feats = feature_cache[pth].float()  # (seq_len, dim)
            synth_feats = synthesis_cache[pth].float()  # (seq_len, dim)
        else:
            feats = get_full_features(pth, wavlm, device)  # (n_layers, seq_len, dim)
            matching_feats = (feats * match_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            synth_feats = (feats * synth_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            feature_cache[pth] = matching_feats.half().cpu()
            synthesis_cache[pth] = synth_feats.half().cpu()
        matching_pool.append(matching_feats.cpu())
        synth_pool.append(synth_feats.cpu())
    matching_pool = torch.concat(matching_pool, dim=0)
    synth_pool = torch.concat(synth_pool, dim=0)
    return matching_pool, synth_pool  # (N, dim)
def get_full_features(path, wavlm, device):
    """Load a 16 kHz waveform and return all WavLM layer representations.

    Args:
        path: audio file readable by torchaudio; must be sampled at 16 kHz.
        wavlm: WavLM model exposing `extract_features(..., ret_layer_results=True)`.
        device: device to run extraction on.

    Returns:
        Tensor of shape (n_layers, seq_len, dim) on `device` — one slice per
        entry of WavLM's `layer_results` (presumably 25 for WavLM-Large,
        matching the `num_classes=25` weightings — TODO confirm).

    Raises:
        ValueError: if the audio is not sampled at 16 kHz. (Previously an
        `assert`, which `python -O` strips silently.)
    """
    x, sr = torchaudio.load(path)
    if sr != 16000:
        raise ValueError(f"expected 16 kHz audio, got {sr} Hz: {path}")

    # Pad to a multiple of DOWNSAMPLE_FACTOR so frame counts line up with the
    # hifigan training; the centered padding below does not, hence disabled.
    # x = F.pad(x, (DOWNSAMPLE_FACTOR//2, DOWNSAMPLE_FACTOR - DOWNSAMPLE_FACTOR//2), value=0)
    # NOTE(review): when the length is already an exact multiple this appends
    # a full extra DOWNSAMPLE_FACTOR of zeros — kept as-is since downstream
    # models were trained against this behavior.
    n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
    x = F.pad(x, (0, n_pad), value=0)

    # Extract the representation of each layer.
    wav_input_16khz = x.to(device)
    rep, layer_results = wavlm.extract_features(wav_input_16khz, output_layer=wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
    # Each layer result is (seq_len, batch, dim); transpose and stack layers.
    features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0)  # (n_layers, seq_len, dim)

    return features
def fast_cosine_dist(source_feats, matching_pool):
    """Pairwise cosine distance between source frames and pool frames.

    Uses the identity  a·b = (|a|^2 + |b|^2 - |a-b|^2) / 2  so the dot
    products come from a single `torch.cdist` call rather than a full matmul.

    Args:
        source_feats: (S, dim) tensor.
        matching_pool: (N, dim) tensor.

    Returns:
        (S, N) tensor of 1 - cosine_similarity values.
    """
    src_norm = source_feats.norm(p=2, dim=-1)
    pool_norm = matching_pool.norm(p=2, dim=-1)
    # Squared Euclidean distances, recovered from cdist (batched form).
    sq_l2 = torch.cdist(source_feats[None], matching_pool[None], p=2)[0] ** 2
    dotprod = (-sq_l2 + src_norm[:, None] ** 2 + pool_norm[None] ** 2) / 2
    cos_sim = dotprod / (src_norm[:, None] * pool_norm[None])
    return 1 - cos_sim
def extract(df: pd.DataFrame, wavlm: nn.Module, device, ls_path: Path, out_path: Path, synth_weights: Tensor, match_weights: Tensor):
    """Compute and save one feature tensor per utterance in `df`.

    For each row, features are saved to `out_path` mirroring the LibriSpeech
    directory layout, with a `.pt` suffix. With `--prematch` the saved features
    are the kNN-averaged synthesis features of the `args.topk` nearest pool
    frames; otherwise the utterance's own matching-weighted features are saved.

    NOTE(review): reads `args` (resume/prematch/topk) from module scope, so
    this only works when the script is run via __main__ — verify if importing.
    NOTE(review): `path2pools` is invoked even when `--prematch` is off, where
    its result is unused except for warming the caches — looks intentional but
    wasteful; confirm before changing.

    Args:
        df: dataframe from `make_librispeech_df` ('path', 'speaker' columns).
        wavlm: WavLM model passed to the feature extractors.
        device: device string/object for feature extraction.
        ls_path: LibriSpeech root, used to relativize utterance paths.
        out_path: output root for the saved `.pt` files.
        synth_weights: per-layer weights for the synthesis representation.
        match_weights: per-layer weights for the matching representation.
    """
    pb = progress_bar(df.iterrows(), total=len(df))

    for i, row in pb:
        rel_path = Path(row.path).relative_to(ls_path)
        targ_path = (out_path/rel_path).with_suffix('.pt')
        # With --resume, skip utterances whose output file already exists.
        if args.resume:
            if targ_path.is_file(): continue
        # if targ_path.is_file(): continue
        os.makedirs(targ_path.parent, exist_ok=True)

        if Path(row.path) in feature_cache:
            # Cache hit: stored float16 on CPU by path2pools; restore float32.
            source_feats = feature_cache[Path(row.path)].float()
        else:
            source_feats = get_full_features(row.path, wavlm, device)
            source_feats = ( source_feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)

        matching_pool, synth_pool = path2pools(row.path, wavlm, match_weights, synth_weights, device)

        if not args.prematch:
            out_feats = source_feats.cpu()
        else:
            # kNN regression: for each source frame, average the synthesis
            # features of the topk closest (cosine) matching-pool frames.
            dists = fast_cosine_dist(source_feats.cpu(), matching_pool.cpu()).cpu()
            best = dists.topk(k=args.topk, dim=-1, largest=False) # (src_len, 4)
            out_feats = synth_pool[best.indices].mean(dim=1) # (N, dim)

        # save matched sequence
        if i < 3: print("Feature has shape: ", out_feats.shape, flush=True)
        # 3. save
        torch.save(out_feats.cpu().half(), str(targ_path))
        # fastprogress exposes a child bar when nested under a master_bar;
        # update whichever bar object we actually have.
        # NOTE(review): `wait_for` is a fastprogress internal (update
        # throttle); capping it at 10 forces more frequent redraws — confirm
        # against the installed fastprogress version.
        if hasattr(pb, 'child'):
            pb.child.comment = str(rel_path)
            pb.child.wait_for = min(pb.child.wait_for, 10)
            pb.main_bar.comment = str(rel_path)
        else:
            pb.wait_for = min(pb.wait_for, 10)
            pb.comment = str(rel_path)

        # Every 1000 utterances: log progress and drop the feature caches to
        # bound memory (gc + short sleep presumably to let memory settle).
        if i % 1000 == 0:
            print(f"Done {i:,d}/{len(df):,d}", flush=True)
            feature_cache.clear()
            synthesis_cache.clear()
            gc.collect()
            time.sleep(4)
if __name__ == '__main__':
    # CLI entry point. `args` is deliberately a module-level name: extract()
    # reads it as a global.
    arg_parser = argparse.ArgumentParser(description="Compute matched wavlm features for a librispeech dataset")
    # Required paths.
    arg_parser.add_argument('--librispeech_path', required=True, type=str)
    arg_parser.add_argument('--out_path', required=True, type=str)
    # Runtime options.
    arg_parser.add_argument('--device', default='cuda', type=str)
    arg_parser.add_argument('--seed', default=123, type=int)
    arg_parser.add_argument('--resume', action='store_true')
    # Matching configuration.
    arg_parser.add_argument('--matching_layer', type=int, default=6)
    arg_parser.add_argument('--synthesis_layer', type=int, default=6)
    arg_parser.add_argument('--topk', type=int, default=4)
    arg_parser.add_argument('--prematch', action='store_true', help='prematch')

    args = arg_parser.parse_args()
    main(args)