Spaces:
Running
on
T4
Running
on
T4
overwrite some pitch values at the start and end to make it sound more lively
Browse files
Modules/ToucanTTS/InferenceToucanTTS.py
CHANGED
|
@@ -219,32 +219,42 @@ class ToucanTTS(torch.nn.Module):
|
|
| 219 |
encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
|
| 220 |
|
| 221 |
# predicting pitch, energy and durations
|
| 222 |
-
reduced_pitch_space =
|
| 223 |
pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space,
|
| 224 |
mask=text_masks.float(),
|
| 225 |
-
n_timesteps=
|
| 226 |
temperature=prosody_creativity,
|
| 227 |
c=utterance_embedding) if gold_pitch is None else gold_pitch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
|
| 229 |
embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2)
|
| 230 |
|
| 231 |
-
reduced_energy_space =
|
| 232 |
energy_predictions = self.energy_predictor(mu=reduced_energy_space,
|
| 233 |
mask=text_masks.float(),
|
| 234 |
-
n_timesteps=
|
| 235 |
temperature=prosody_creativity,
|
| 236 |
c=utterance_embedding) if gold_energy is None else gold_energy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
|
| 238 |
embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2)
|
| 239 |
|
| 240 |
-
reduced_duration_space =
|
| 241 |
predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space,
|
| 242 |
mask=text_masks.float(),
|
| 243 |
-
n_timesteps=
|
| 244 |
temperature=prosody_creativity,
|
| 245 |
-
c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations
|
| 246 |
|
| 247 |
# modifying the predictions with control parameters
|
|
|
|
| 248 |
for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
|
| 249 |
if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
|
| 250 |
predicted_durations[0][phoneme_index] = 0
|
|
@@ -267,8 +277,8 @@ class ToucanTTS(torch.nn.Module):
|
|
| 267 |
|
| 268 |
refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2),
|
| 269 |
mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2),
|
| 270 |
-
n_timesteps=
|
| 271 |
-
temperature=0.
|
| 272 |
c=None).transpose(1, 2)
|
| 273 |
|
| 274 |
return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
|
|
@@ -326,19 +336,19 @@ class ToucanTTS(torch.nn.Module):
|
|
| 326 |
lang_id = lang_id.to(text.device)
|
| 327 |
|
| 328 |
outs, \
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
|
| 343 |
if return_duration_pitch_energy:
|
| 344 |
return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
|
|
|
|
| 219 |
encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
|
| 220 |
|
| 221 |
# predicting pitch, energy and durations
|
| 222 |
+
reduced_pitch_space = self.pitch_latent_reduction(encoded_texts).transpose(1, 2)
|
| 223 |
pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space,
|
| 224 |
mask=text_masks.float(),
|
| 225 |
+
n_timesteps=20,
|
| 226 |
temperature=prosody_creativity,
|
| 227 |
c=utterance_embedding) if gold_pitch is None else gold_pitch
|
| 228 |
+
# Because of the way we process the data, the last few elements of a sequence always receive an unnaturally low pitch value. To fix this, we overwrite the last two values with the third-to-last one; the first value is likewise replaced by the second.
|
| 229 |
+
pitch_predictions[0][0][0] = pitch_predictions[0][0][1]
|
| 230 |
+
pitch_predictions[0][0][-1] = pitch_predictions[0][0][-3]
|
| 231 |
+
pitch_predictions[0][0][-2] = pitch_predictions[0][0][-3]
|
| 232 |
pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
|
| 233 |
embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2)
|
| 234 |
|
| 235 |
+
reduced_energy_space = self.energy_latent_reduction(encoded_texts + embedded_pitch_curve).transpose(1, 2)
|
| 236 |
energy_predictions = self.energy_predictor(mu=reduced_energy_space,
|
| 237 |
mask=text_masks.float(),
|
| 238 |
+
n_timesteps=20,
|
| 239 |
temperature=prosody_creativity,
|
| 240 |
c=utterance_embedding) if gold_energy is None else gold_energy
|
| 241 |
+
|
| 242 |
+
# Because of the way we process the data, the last few elements of a sequence always receive an unnaturally low energy value. To fix this, we overwrite the last two values with the third-to-last one; the first value is likewise replaced by the second.
|
| 243 |
+
energy_predictions[0][0][0] = energy_predictions[0][0][1]
|
| 244 |
+
energy_predictions[0][0][-1] = energy_predictions[0][0][-3]
|
| 245 |
+
energy_predictions[0][0][-2] = energy_predictions[0][0][-3]
|
| 246 |
energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
|
| 247 |
embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2)
|
| 248 |
|
| 249 |
+
reduced_duration_space = self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve).transpose(1, 2)
|
| 250 |
predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space,
|
| 251 |
mask=text_masks.float(),
|
| 252 |
+
n_timesteps=20,
|
| 253 |
temperature=prosody_creativity,
|
| 254 |
+
c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations.squeeze(1)
|
| 255 |
|
| 256 |
# modifying the predictions with control parameters
|
| 257 |
+
predicted_durations[0][0] = 1 # if the initial pause is too long, we get artifacts. This is once more a dirty hack.
|
| 258 |
for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
|
| 259 |
if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
|
| 260 |
predicted_durations[0][phoneme_index] = 0
|
|
|
|
| 277 |
|
| 278 |
refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2),
|
| 279 |
mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2),
|
| 280 |
+
n_timesteps=30,
|
| 281 |
+
temperature=0.2, # low temperature, so the model follows the specified prosody curves better.
|
| 282 |
c=None).transpose(1, 2)
|
| 283 |
|
| 284 |
return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
|
|
|
|
| 336 |
lang_id = lang_id.to(text.device)
|
| 337 |
|
| 338 |
outs, \
|
| 339 |
+
predicted_durations, \
|
| 340 |
+
pitch_predictions, \
|
| 341 |
+
energy_predictions = self._forward(text.unsqueeze(0),
|
| 342 |
+
text_length,
|
| 343 |
+
gold_durations=durations,
|
| 344 |
+
gold_pitch=pitch,
|
| 345 |
+
gold_energy=energy,
|
| 346 |
+
utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id,
|
| 347 |
+
duration_scaling_factor=duration_scaling_factor,
|
| 348 |
+
pitch_variance_scale=pitch_variance_scale,
|
| 349 |
+
energy_variance_scale=energy_variance_scale,
|
| 350 |
+
pause_duration_scaling_factor=pause_duration_scaling_factor,
|
| 351 |
+
prosody_creativity=prosody_creativity)
|
| 352 |
|
| 353 |
if return_duration_pitch_energy:
|
| 354 |
return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
|