import json
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
from torch import Tensor
from torchaudio.sox_effects import apply_effects_tensor

from models import GeneratorNSF as Generator, PitchEncoder as PitchEmbedding
from utils.tools import AttrDict, load_wav
from wavlm.WavLM import WavLM, WavLMConfig

# Per-layer weights over the 25 WavLM-Large hidden states (input embedding + 24
# transformer layers): only layer 6 is used here.
SPEAKER_INFORMATION_WEIGHTS = [
    0, 0, 0, 0, 0, 0,   # layers 0-5
    1.0, 0, 0, 0,       # layers 6-9
    0, 0, 0, 0, 0, 0,   # layers 10-15
    0, 0, 0, 0, 0, 0,   # layers 16-21
    0,                  # layer 22
    0, 0                # layers 23-24
]
SPEAKER_INFORMATION_LAYER = 6

def fast_cosine_dist(source_feats: Tensor, matching_pool: Tensor, device: str = 'cpu') -> Tensor:
    """Like torch.cdist, but with dim=-1 fixed and computing cosine distance."""
    source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)
    # Recover the dot products from squared euclidean distances:
    # a.b = (||a||^2 + ||b||^2 - ||a - b||^2) / 2
    dotprod = -torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
    dotprod /= 2
    dists = 1 - (dotprod / (source_norms[:, None] * matching_norms[None]))
    return dists
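
# Sanity-check sketch (not part of the original module): the expansion above
# should agree with a naively computed cosine distance, e.g.
#
#   a, b = torch.randn(5, 8), torch.randn(7, 8)
#   naive = 1 - (a @ b.T) / (a.norm(dim=-1, keepdim=True) * b.norm(dim=-1)[None])
#   assert torch.allclose(fast_cosine_dist(a, b), naive, atol=1e-5)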

class SVCNN(nn.Module):
    def __init__(self,
                 model_ckpt_path,
                 model_cfg_path='config.json',
                 wavlm_ckpt_path='ckpt/WavLM-Large.pt',
                 device='cpu'
                 ) -> None:
        super().__init__()
        # set which features to extract from wavlm
        self.weighting = torch.tensor(SPEAKER_INFORMATION_WEIGHTS, device=device)[:, None]

        # load model
        with open(model_cfg_path) as f:
            model_cfg = AttrDict(json.load(f))
        pitch_emb = PitchEmbedding(model_cfg).to(device)
        model = Generator(model_cfg).to(device)
        state_dict_g = torch.load(model_ckpt_path, map_location='cpu')
        pitch_emb.load_state_dict(state_dict_g['pitch_encoder'])
        model.load_state_dict(state_dict_g['generator'])
        model.remove_weight_norm()
        self.pitch_emb = pitch_emb.to(device).eval()
        self.model = model.to(device).eval()
        n_params = (sum(p.numel() for p in model.parameters() if p.requires_grad)
                    + sum(p.numel() for p in pitch_emb.parameters() if p.requires_grad))
        print(f"Generator loaded with {n_params:,d} parameters.")
        self.h = model_cfg

        # load wavlm
        wavlm_ckpt = torch.load(wavlm_ckpt_path, map_location='cpu')
        cfg = WavLMConfig(wavlm_ckpt['cfg'])
        wavlm = WavLM(cfg)
        wavlm.load_state_dict(wavlm_ckpt['model'])
        wavlm.to(device)
        self.wavlm = wavlm.eval()
        print('wavlm loaded')

        self.device = torch.device(device)
        self.sr = 16000          # WavLM expects 16 kHz input
        self.hop_length = 320    # WavLM frame hop in samples (20 ms at 16 kHz)

    def get_matching_set(self, p: Path | str, weights=None, vad_trigger_level=0, out_path=None) -> Tensor:
        """Get concatenated WavLM features for the matching set from the reference waveform at `p`,
        loaded at the model's 16 kHz sample rate. The audio is processed in 60-second slices to
        bound memory use. Optionally specify a custom WavLM feature weighting with `weights`,
        and cache the resulting features to `out_path`.
        """
        feats = []
        # only use the audio at `p` as the matching set
        # x, sr = torchaudio.load(p, normalize=True)
        x, sr = load_wav(p, self.sr)
        # x, _, __ = trim_long_silences(x, sr)
        audio_length = len(x)
        slice_length = 60 * self.sr  # process 60 seconds of audio at a time
        for start_pos in range(0, audio_length, slice_length):
            end_pos = start_pos + slice_length
            slice_x = x[start_pos:end_pos]
            slice_x = torch.from_numpy(slice_x).float()
            feats.append(self.get_features(slice_x, weights=self.weighting if weights is None else weights, vad_trigger_level=vad_trigger_level))
        feats = torch.concat(feats, dim=0).cpu()
        if out_path:
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            torch.save(feats, out_path)
        return feats
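
    # Usage sketch (illustrative only; the checkpoint, reference audio, and cache
    # paths below are placeholders, not part of the original file):
    #
    #   svc = SVCNN('ckpt/model.pt', device='cuda')
    #   matching_set = svc.get_matching_set('target_speaker.wav',
    #                                       out_path='cache/target_speaker_feats.pt')
    #   # later runs can reuse the cache: matching_set = torch.load('cache/target_speaker_feats.pt')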

    def vocode(self, c: Tensor, pitch: Tensor) -> Tensor:
        """Synthesize a waveform from features `c` and frame-level pitch with the NSF generator."""
        y_g_hat = self.model(c, pitch)
        y_g_hat = y_g_hat.squeeze(1)
        return y_g_hat

    def get_features(self, path, weights=None, vad_trigger_level=0):
        """Return features of the waveform at `path` (or of an already-loaded waveform tensor)
        as a tensor of shape (seq_len, dim), optionally performing VAD trimming
        on the start/end with `vad_trigger_level`.
        """
        # load audio
        if weights is None:
            weights = self.weighting
        if isinstance(path, (str, Path)):
            # x, sr = torchaudio.load(path, normalize=True)
            x, sr = load_wav(path, self.sr)
            x = torch.from_numpy(x).float()
        else:
            x: Tensor = path
            sr = self.sr
        if x.dim() == 1:
            x = x[None]
        assert sr == self.sr, f"input audio sample rate must be 16kHz. Got {sr}"

        # trim silence from front and back
        if vad_trigger_level > 1e-3:
            transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
            x_front_trim = transform(x)
            waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
            waveform_reversed_front_trim = transform(waveform_reversed)
            waveform_end_trim, sr = apply_effects_tensor(
                waveform_reversed_front_trim, sr, [["reverse"]]
            )
            x = waveform_end_trim

        # extract representations in chunks of at most 30 seconds
        wavs_split = torch.tensor_split(x, (x.shape[1] - 1) // (sr * 30) + 1, dim=1)
        feature_list = []
        for wav_chunk in wavs_split:
            wav_input_16khz = wav_chunk.to(self.device)
            if torch.allclose(weights, self.weighting):
                # fast path: only the output of SPEAKER_INFORMATION_LAYER is needed
                features = self.wavlm.extract_features(wav_input_16khz, output_layer=SPEAKER_INFORMATION_LAYER, ret_layer_results=False)[0]
                features = features.squeeze(0)
                feature_list.append(features)
            else:
                # slower path: weighted sum over all layer outputs
                rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
                features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0)  # (n_layers, seq_len, dim)
                features = (features * weights[:, None]).sum(dim=0)  # (seq_len, dim)
                feature_list.append(features)

        return torch.cat(feature_list)
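
    # Illustrative note (an assumption based on WavLM-Large defaults, not stated in
    # this file): for a 16 kHz waveform of n samples, get_features returns roughly
    # (n // 320, 1024), i.e. one 1024-dim vector every 20 ms, e.g.
    #
    #   query = svc.get_features('source.wav')   # 'source.wav' is a placeholder path
    #   print(query.shape)                       # ~ (seq_len, 1024)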

    def match(self, query_seq: Tensor, pitch: Tensor, pitch_bins: Tensor, synth_set: Tensor,
              topk: int = 4, query_mask: Tensor = None, alpha=0, tgt_loudness_db: float | None = -16,
              target_duration: float | None = None, device: str | None = None) -> Tensor:
        """Given `query_seq` and `synth_set` tensors of shape (N, dim), perform kNN regression matching
        with k=`topk` and vocode the result. Inputs:
            - `query_seq`: Tensor (N1, dim) of the input/source query features.
            - `pitch`, `pitch_bins`: frame-level pitch values for the generator and quantized pitch
               bins for the pitch encoder.
            - `synth_set`: Tensor (N2, dim) of the matching set used as the 'training set' for the kNN
               algorithm. Each query vector is replaced by the mean of its `topk` nearest neighbors
               from this set.
            - `topk`: k in the kNN -- the number of nearest neighbors to average over.
            - `query_mask`: optional mask selecting which query frames use the matched features;
               frames where the mask is False keep the original query features.
            - `alpha`: interpolation weight between the matched features (alpha=0) and the original
               query features (alpha=1).
            - `tgt_loudness_db`: float dB used to normalize the output volume. Set to None to disable.
            - `target_duration`: if set to a float, interpolate the resulting waveform duration to be
               equal to this value in seconds.
            - `device`: if None, uses the default device set at initialization. Otherwise uses the
               specified device.
        Returns:
            - converted waveform of shape (T,)
        """
        device = torch.device(device) if device is not None else self.device
        synth_set = synth_set.to(device)
        query_seq = query_seq.to(device)
        pitch = pitch.to(device)
        pitch_bins = pitch_bins.to(device)

        if target_duration is not None:
            target_samples = int(target_duration * self.sr)
            scale_factor = (target_samples / self.hop_length) / query_seq.shape[0]  # n_targ_feats / n_input_feats
            query_seq = F.interpolate(query_seq.T[None], scale_factor=scale_factor, mode='linear')[0].T

        # kNN matching: replace each query frame with the mean of its topk nearest neighbours
        dists = fast_cosine_dist(query_seq, synth_set, device=device)
        best = dists.topk(k=topk, largest=False, dim=-1)
        out_feats = (1 - alpha) * synth_set[best.indices].mean(dim=1) + alpha * query_seq
        if query_mask is not None:
            query_mask = query_mask[..., None].repeat([1, out_feats.shape[-1]])
            out_feats = out_feats * query_mask + query_seq * (query_mask == False)

        # repeat each feature frame twice along time, condition on pitch, and vocode
        out_feats = torch.repeat_interleave(out_feats, 2, 0)
        out_feats = self.pitch_emb(out_feats, pitch_bins)
        prediction = self.vocode(out_feats[None].to(device), pitch.unsqueeze(0)).cpu().squeeze()

        # loudness normalization
        if tgt_loudness_db is not None:
            src_loudness = torchaudio.functional.loudness(prediction[None], self.h.sampling_rate)
            tgt_loudness = tgt_loudness_db
            pred_wav = torchaudio.functional.gain(prediction, tgt_loudness - src_loudness)
        else:
            pred_wav = prediction
        return pred_wav
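
# End-to-end usage sketch (illustrative only; all paths are placeholders and the
# pitch / pitch_bins tensors are assumed to come from a separate pitch extractor
# that is not part of this file):
#
#   svc = SVCNN('ckpt/model.pt', device='cuda')
#   matching_set = svc.get_matching_set('target_speaker.wav')   # (N2, dim)
#   query = svc.get_features('source.wav')                      # (N1, dim)
#   pitch, pitch_bins = ...  # frame-level f0 and quantized bins for the source
#   wav = svc.match(query, pitch, pitch_bins, matching_set, topk=4)
#   torchaudio.save('converted.wav', wav[None], svc.h.sampling_rate)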