import gradio as gr
from pathlib import Path
import soundfile as sf
import spaces
# Patch torch.load so every checkpoint is mapped onto the selected device (CPU or CUDA)
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
_old_load = torch.load
def safe_torch_load(*args, **kwargs):
    args = list(args)
    if len(args) >= 2:
        # the second positional argument of torch.load is map_location
        args[1] = device
    else:
        kwargs['map_location'] = device
    return _old_load(*args, **kwargs)
torch.load = safe_torch_load
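
# The imports and checkpoint loads below go through the patched torch.load, so
# checkpoints saved with CUDA tensors are remapped onto `device` even when this
# Space runs on CPU-only hardware.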
import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers
from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase
from huggingface_hub import hf_hub_download
import fairseq
register_omegaconf_resolvers()
config = OmegaConf.load("configs/infer.yaml")
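
# Fetch the STAR model weights from the Hugging Face Hub; hf_hub_download caches
# the file locally, so restarts do not re-download it.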
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)
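
# Instantiate the STAR model from the Hydra experiment config, pointing its
# `pretrained_ckpt` entry (if present) at the weights downloaded above.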
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
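
# Load the HuBERT-large (LL-60k) speech encoder via fairseq; its extracted
# features are what infer() passes to STAR as the `content` input.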
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)
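
# Build the diffusion noise scheduler class named in configs/infer.yaml and load
# its configuration from the pretrained pipeline given by `noise_scheduler.name`.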
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)
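
# One inference call: speech in -> HuBERT features -> STAR -> waveform out.
# The @spaces.GPU decorator requests a ZeroGPU worker for up to 60 seconds per call.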
@torch.no_grad()
@spaces.GPU(duration=60)
def infer(audio_path: str) -> str:
    # Load the input speech and convert it to 16 kHz mono, the rate HuBERT expects
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)
    # Extract HuBERT speech representations, passed to the model as `content`
    with torch.no_grad():
        features, _ = hubert_model.extract_features(waveform_tts.to(device))
    # Assemble the inference arguments from configs/infer.yaml
    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs['content'] = [features]
    kwargs['condition'] = None
    kwargs['task'] = ["speech_to_audio"]
    model.eval()
    # Run the sampler to synthesize the output waveform
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )
    # Write the result to disk and return the path for the Gradio Audio component
    output_file = "output_audio.wav"
    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])
    return output_file
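
# Gradio UI: intro text, example input/output clips, the inference widget, and
# two tabs of clickable examples (VITS-synthesized speech vs. real human speech).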
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")
gr.Markdown("""
<div style="text-align: left; padding: 10px;">
## 📚️ Introduction
STAR is the first end-to-end speech-to-audio generation framework, designed to enhance efficiency and address error propagation inherent in cascaded systems.
Within this space, you have the opportunity to directly control our model through voice input, thereby generating the corresponding audio output.
## 🗣️ Input
A brief input speech utterance for the overall audio scene.
> Example:A cat meowing and young female speaking
### 🎙️ Input Speech Example
""")
    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")
gr.Markdown("""
<div style="text-align: left; padding: 10px;">
## 🎧️ Output
Capture both auditory events and scene cues and generate corresponding audio
### 🔊 Output Audio Example
""")
    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")
gr.Markdown("""
<div style="text-align: left; padding: 10px;">
</div>
---
</div>
## 🛠️ Online Inference
You can upload your own samples, or try the quick examples provided below.
""")
    with gr.Column():
        input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
    btn.click(fn=infer, inputs=input_audio, outputs=output_audio)
gr.Markdown("""
<div style="text-align: left; padding: 10px;">
## 🎯 Quick Examples
""")
    display_caption = gr.Textbox(label="📝 Caption", visible=False)
    with gr.Tabs():
        with gr.Tab("VITS Generated Speech"):
            gr.Examples(
                examples=[
                    ["wav/vits/1.wav", "A cat meowing and young female speaking"],
                    ["wav/vits/2.wav", "Sustained industrial engine noise"],
                    ["wav/vits/3.wav", "A woman talks and a baby whispers"],
                    ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/vits/6.wav", "A man speaking as birds are chirping"],
                    ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/vits/9.wav", "Several church bells ringing"],
                    ["wav/vits/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )
with gr.Tab("Real Human Speech"):
gr.Examples(
examples=[
["wav/human/1.wav", "A cat meowing and young female speaking"],
["wav/human/2.wav", "Sustained industrial engine noise"],
["wav/human/3.wav", "A woman talks and a baby whispers"],
["wav/human/4.wav", "A man speaks followed by a toilet flush"],
["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
["wav/human/6.wav", "A man speaking as birds are chirping"],
["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
["wav/human/8.wav", "Birds chirping and a horse neighing"],
["wav/human/9.wav", "Several church bells ringing"],
["wav/human/10.wav", "A telephone rings with bell sounds"]
],
inputs=[input_audio, display_caption],
label="Click examples below to try!",
cache_examples = False,
examples_per_page = 10,
)
demo.launch()