Hugging Face Space (status: Running) — commit "Update inference.py"
Files changed: inference.py (+18 −5)
inference.py — CHANGED
|
@@ -10,6 +10,10 @@ from SVCNN import SVCNN
[old-side view — lines before the change]

 10   from utils.tools import extract_voiced_area
 11   from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
 12
 13   SPEAKER_INFORMATION_WEIGHTS = [
 14       0, 0, 0, 0, 0, 0, # layer 0-5
 15       1.0, 0, 0, 0,

@@ -51,12 +55,21 @@ def svc(model, src_wav_path, ref_wav_path, synth_set_path=None, f0_factor=0., sp
[old-side view — lines before the change]

 51       if synth_set_path:
 52           synth_set = torch.load(synth_set_path).to(device)
 53       else:
 54 -     (removed line — content not captured in this export)
 55 -     (removed line — content not captured in this export)
 56 -     (removed line — content not captured in this export)
 57 -     (removed line — content not captured in this export)
 58 -     (removed line — content not captured in this export)
 59
 60       query_len = query_seq.shape[0]
 61       if len(query_mask) > query_len:
 62           query_mask = query_mask[:query_len]
|
|
|
[new-side view — lines after the change]

 10       from utils.tools import extract_voiced_area
 11       from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
 12
 13 +     from Phoneme_Hallucinator_v2.utils.hparams import HParams
 14 +     from Phoneme_Hallucinator_v2.models import get_model as get_hallucinator
 15 +     from Phoneme_Hallucinator_v2.scripts.speech_expansion_ins import single_expand
 16 +
 17       SPEAKER_INFORMATION_WEIGHTS = [
 18           0, 0, 0, 0, 0, 0, # layer 0-5
 19           1.0, 0, 0, 0,

[new-side view — lines after the change]

 55       if synth_set_path:
 56           synth_set = torch.load(synth_set_path).to(device)
 57       else:
 58 +         synth_set_path = f"matching_set/{ref_name}.pt"
 59 +         synth_set = model.get_matching_set(ref_wav_path, out_path=synth_set_path).to(device)
 60 +
 61 +     if hallucinated_set_path is None:
 62 +         params = HParams('Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json')
 63 +         Hallucinator = get_hallucinator(params)
 64 +         Hallucinator.load()
 65 +         hallucinated_set = single_expand(synth_set_path, Hallucinator, 15000)
 66 +     else:
 67 +         hallucinated_set = np.load(hallucinated_set_path)
 68
 69 +     hallucinated_set = torch.from_numpy(hallucinated_set).to(device)
 70 +
 71 +     synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
 72 +
 73       query_len = query_seq.shape[0]
 74       if len(query_mask) > query_len:
 75           query_mask = query_mask[:query_len]

(NOTE: indentation of the reconstructed code lines is inferred from Python context — the export stripped leading whitespace; verify against the repository file.)