Commit
·
5cf7b18
1
Parent(s):
6eb9ea3
[Experimental] Gruut support
Browse files- app.py +4 -3
- gruut_phonemize.py +10 -0
- requirements.txt +2 -1
- styletts2importable.py +72 -58
app.py
CHANGED
|
@@ -16,13 +16,13 @@ voices = {}
|
|
| 16 |
# else:
|
| 17 |
for v in voicelist:
|
| 18 |
voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
|
| 19 |
-
def synthesize(text, voice):
|
| 20 |
if text.strip() == "":
|
| 21 |
raise gr.Error("You must enter some text")
|
| 22 |
if len(text) > 300:
|
| 23 |
raise gr.Error("Text must be under 300 characters")
|
| 24 |
v = voice.lower()
|
| 25 |
-
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
|
| 26 |
def clsynthesize(text, voice):
|
| 27 |
if text.strip() == "":
|
| 28 |
raise gr.Error("You must enter some text")
|
|
@@ -43,10 +43,11 @@ with gr.Blocks() as vctk:
|
|
| 43 |
with gr.Column(scale=1):
|
| 44 |
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
| 45 |
voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
|
|
|
|
| 46 |
with gr.Column(scale=1):
|
| 47 |
btn = gr.Button("Synthesize", variant="primary")
|
| 48 |
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
| 49 |
-
btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
|
| 50 |
with gr.Blocks() as clone:
|
| 51 |
with gr.Row():
|
| 52 |
with gr.Column(scale=1):
|
|
|
|
| 16 |
# else:
|
| 17 |
for v in voicelist:
|
| 18 |
voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
|
| 19 |
+
def synthesize(text, voice, use_gruut):
|
| 20 |
if text.strip() == "":
|
| 21 |
raise gr.Error("You must enter some text")
|
| 22 |
if len(text) > 300:
|
| 23 |
raise gr.Error("Text must be under 300 characters")
|
| 24 |
v = voice.lower()
|
| 25 |
+
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1, use_gruut=use_gruut))
|
| 26 |
def clsynthesize(text, voice):
|
| 27 |
if text.strip() == "":
|
| 28 |
raise gr.Error("You must enter some text")
|
|
|
|
| 43 |
with gr.Column(scale=1):
|
| 44 |
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
| 45 |
voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
|
| 46 |
+
use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
|
| 47 |
with gr.Column(scale=1):
|
| 48 |
btn = gr.Button("Synthesize", variant="primary")
|
| 49 |
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
| 50 |
+
btn.click(synthesize, inputs=[inp, voice, use_gruut], outputs=[audio], concurrency_limit=4)
|
| 51 |
with gr.Blocks() as clone:
|
| 52 |
with gr.Row():
|
| 53 |
with gr.Column(scale=1):
|
gruut_phonemize.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gruut import sentences
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def gphonemize(text):
|
| 5 |
+
phonemes = ''
|
| 6 |
+
for sent in sentences(text, lang="en-us"):
|
| 7 |
+
for word in sent:
|
| 8 |
+
if word.phonemes:
|
| 9 |
+
phonemes += ''.join(word.phonemes)
|
| 10 |
+
return phonemes
|
requirements.txt
CHANGED
|
@@ -18,4 +18,5 @@ git+https://github.com/resemble-ai/monotonic_align.git
|
|
| 18 |
scipy
|
| 19 |
phonemizer
|
| 20 |
cached-path
|
| 21 |
-
gradio
|
|
|
|
|
|
| 18 |
scipy
|
| 19 |
phonemizer
|
| 20 |
cached-path
|
| 21 |
+
gradio
|
| 22 |
+
gruut
|
styletts2importable.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
from cached_path import cached_path
|
|
|
|
|
|
|
| 2 |
|
| 3 |
# from dp.phonemizer import Phonemizer
|
| 4 |
print("NLTK")
|
|
@@ -131,9 +133,12 @@ sampler = DiffusionSampler(
|
|
| 131 |
clamp=False
|
| 132 |
)
|
| 133 |
|
| 134 |
-
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
| 135 |
text = text.strip()
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
ps = word_tokenize(ps[0])
|
| 138 |
ps = ' '.join(ps)
|
| 139 |
tokens = textclenaer(ps)
|
|
@@ -200,86 +205,92 @@ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding
|
|
| 200 |
|
| 201 |
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
| 202 |
|
| 203 |
-
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
|
| 223 |
-
|
| 224 |
embedding=bert_dur,
|
| 225 |
embedding_scale=embedding_scale,
|
| 226 |
-
|
| 227 |
num_steps=diffusion_steps).squeeze(1)
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
|
| 239 |
-
|
| 240 |
|
| 241 |
-
|
| 242 |
s, input_lengths, text_mask)
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
|
| 247 |
-
|
| 248 |
-
|
| 249 |
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
|
| 265 |
-
|
| 266 |
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
|
| 277 |
|
| 278 |
-
|
| 279 |
|
| 280 |
-
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
| 281 |
text = text.strip()
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
| 283 |
ps = word_tokenize(ps[0])
|
| 284 |
ps = ' '.join(ps)
|
| 285 |
|
|
@@ -288,7 +299,10 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
|
|
| 288 |
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
| 289 |
|
| 290 |
ref_text = ref_text.strip()
|
| 291 |
-
|
|
|
|
|
|
|
|
|
|
| 292 |
ps = word_tokenize(ps[0])
|
| 293 |
ps = ' '.join(ps)
|
| 294 |
|
|
|
|
| 1 |
from cached_path import cached_path
|
| 2 |
+
print("GRUUT")
|
| 3 |
+
from gruut_phonemize import gphonemize
|
| 4 |
|
| 5 |
# from dp.phonemizer import Phonemizer
|
| 6 |
print("NLTK")
|
|
|
|
| 133 |
clamp=False
|
| 134 |
)
|
| 135 |
|
| 136 |
+
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
| 137 |
text = text.strip()
|
| 138 |
+
if use_gruut:
|
| 139 |
+
ps = gphonemize(text)
|
| 140 |
+
else:
|
| 141 |
+
ps = global_phonemizer.phonemize([text])
|
| 142 |
ps = word_tokenize(ps[0])
|
| 143 |
ps = ' '.join(ps)
|
| 144 |
tokens = textclenaer(ps)
|
|
|
|
| 205 |
|
| 206 |
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
| 207 |
|
| 208 |
+
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
| 209 |
+
text = text.strip()
|
| 210 |
+
if use_gruut:
|
| 211 |
+
ps = gphonemize(text)
|
| 212 |
+
else:
|
| 213 |
+
ps = global_phonemizer.phonemize([text])
|
| 214 |
+
ps = word_tokenize(ps[0])
|
| 215 |
+
ps = ' '.join(ps)
|
| 216 |
+
ps = ps.replace('``', '"')
|
| 217 |
+
ps = ps.replace("''", '"')
|
| 218 |
|
| 219 |
+
tokens = textclenaer(ps)
|
| 220 |
+
tokens.insert(0, 0)
|
| 221 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
| 222 |
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
| 225 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
| 226 |
|
| 227 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
| 228 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
| 229 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
| 230 |
|
| 231 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
| 232 |
embedding=bert_dur,
|
| 233 |
embedding_scale=embedding_scale,
|
| 234 |
+
features=ref_s, # reference from the same speaker as the embedding
|
| 235 |
num_steps=diffusion_steps).squeeze(1)
|
| 236 |
|
| 237 |
+
if s_prev is not None:
|
| 238 |
+
# convex combination of previous and current style
|
| 239 |
+
s_pred = t * s_prev + (1 - t) * s_pred
|
| 240 |
|
| 241 |
+
s = s_pred[:, 128:]
|
| 242 |
+
ref = s_pred[:, :128]
|
| 243 |
|
| 244 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
| 245 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
| 246 |
|
| 247 |
+
s_pred = torch.cat([ref, s], dim=-1)
|
| 248 |
|
| 249 |
+
d = model.predictor.text_encoder(d_en,
|
| 250 |
s, input_lengths, text_mask)
|
| 251 |
|
| 252 |
+
x, _ = model.predictor.lstm(d)
|
| 253 |
+
duration = model.predictor.duration_proj(x)
|
| 254 |
|
| 255 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
| 256 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
| 257 |
|
| 258 |
|
| 259 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
| 260 |
+
c_frame = 0
|
| 261 |
+
for i in range(pred_aln_trg.size(0)):
|
| 262 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
| 263 |
+
c_frame += int(pred_dur[i].data)
|
| 264 |
|
| 265 |
+
# encode prosody
|
| 266 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
| 267 |
+
if model_params.decoder.type == "hifigan":
|
| 268 |
+
asr_new = torch.zeros_like(en)
|
| 269 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
| 270 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
| 271 |
+
en = asr_new
|
| 272 |
|
| 273 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
| 274 |
|
| 275 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
| 276 |
+
if model_params.decoder.type == "hifigan":
|
| 277 |
+
asr_new = torch.zeros_like(asr)
|
| 278 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
| 279 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
| 280 |
+
asr = asr_new
|
| 281 |
|
| 282 |
+
out = model.decoder(asr,
|
| 283 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
| 284 |
|
| 285 |
|
| 286 |
+
return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
|
| 287 |
|
| 288 |
+
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
| 289 |
text = text.strip()
|
| 290 |
+
if use_gruut:
|
| 291 |
+
ps = gphonemize(text)
|
| 292 |
+
else:
|
| 293 |
+
ps = global_phonemizer.phonemize([text])
|
| 294 |
ps = word_tokenize(ps[0])
|
| 295 |
ps = ' '.join(ps)
|
| 296 |
|
|
|
|
| 299 |
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
| 300 |
|
| 301 |
ref_text = ref_text.strip()
|
| 302 |
+
if use_gruut:
|
| 303 |
+
ps = gphonemize(text)
|
| 304 |
+
else:
|
| 305 |
+
ps = global_phonemizer.phonemize([ref_text])
|
| 306 |
ps = word_tokenize(ps[0])
|
| 307 |
ps = ' '.join(ps)
|
| 308 |
|