import gradio as gr
from pathlib import Path
import soundfile as sf
import spaces
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
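
# Patch torch.load so every checkpoint is mapped onto the selected device,
# so that CUDA-saved weights also load on CPU-only hardware.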
_old_load = torch.load


def safe_torch_load(*args, **kwargs):
    args = list(args)
    if len(args) >= 2:
        # torch.load(f, map_location, ...): override the positional map_location
        args[1] = device
    else:
        kwargs['map_location'] = device
    return _old_load(*args, **kwargs)


torch.load = safe_torch_load

import torchaudio
import hydra
from omegaconf import OmegaConf
import diffusers.schedulers as noise_schedulers

from utils.config import register_omegaconf_resolvers
from models.common import LoadPretrainedBase

from huggingface_hub import hf_hub_download
import fairseq

register_omegaconf_resolvers()
config = OmegaConf.load("configs/infer.yaml")
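
# Download the STAR checkpoint from the Hugging Face Hub.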
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="model.safetensors",
    repo_type="model",
    force_download=False,
)

# Instantiate the model from the experiment config and point it at the checkpoint.
exp_config = OmegaConf.load("configs/config.yaml")
if "pretrained_ckpt" in exp_config["model"]:
    exp_config["model"]["pretrained_ckpt"] = ckpt_path
model: LoadPretrainedBase = hydra.utils.instantiate(exp_config["model"])
model = model.to(device)
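
# HuBERT is used to extract speech representations from the input utterance.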
ckpt_path = hf_hub_download(
    repo_id="assasinatee/STAR",
    filename="hubert_large_ll60k.pt",
    repo_type="model",
    force_download=False,
)
hubert_models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
hubert_model = hubert_models[0].eval().to(device)
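
# Diffusion noise scheduler, resolved dynamically from the inference config.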
scheduler = getattr(
    noise_schedulers,
    config["noise_scheduler"]["type"],
).from_pretrained(
    config["noise_scheduler"]["name"],
    subfolder="scheduler",
)
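

# On ZeroGPU Spaces, @spaces.GPU allocates a GPU per call for at most `duration` seconds.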
@torch.no_grad()
@spaces.GPU(duration=60)
def infer(audio_path: str) -> str:
    """Generate an audio clip from the input speech and return the output file path."""
    # Load the speech and convert it to 16 kHz mono, as expected by HuBERT.
    waveform_tts, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 16000:
        waveform_tts = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_tts)
    if waveform_tts.shape[0] > 1:
        waveform_tts = torch.mean(waveform_tts, dim=0, keepdim=True)

    # Extract speech representations with HuBERT.
    with torch.no_grad():
        features, _ = hubert_model.extract_features(waveform_tts.to(device))

    # Assemble the arguments for the speech-to-audio task.
    kwargs = OmegaConf.to_container(config["infer_args"].copy(), resolve=True)
    kwargs['content'] = [features]
    kwargs['condition'] = None
    kwargs['task'] = ["speech_to_audio"]

    # Run diffusion sampling and write the generated waveform to disk.
    model.eval()
    waveform = model.inference(
        scheduler=scheduler,
        **kwargs,
    )

    output_file = "output_audio.wav"
    sf.write(output_file, waveform.squeeze().cpu().numpy(), samplerate=exp_config["sample_rate"])

    return output_file
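

# Gradio UI: description, example clips, and the online inference widget.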
with gr.Blocks(title="STAR Online Inference", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# STAR: Speech-to-Audio Generation via Representation Learning")

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 📚️ Introduction

STAR is the first end-to-end speech-to-audio generation framework, designed to enhance efficiency and address the error propagation inherent in cascaded systems.

In this Space, you can drive our model directly with voice input and generate the corresponding audio output.

## 🗣️ Input

A brief speech utterance describing the overall audio scene.

> Example: A cat meowing and young female speaking

### 🎙️ Input Speech Example

</div>
    """)

    speech = gr.Audio(value="wav/speech.wav", label="Input Speech Example", type="filepath")

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 🎧️ Output

The model captures both the auditory events and the scene cues in the speech and generates the corresponding audio.

### 🔊 Output Audio Example

</div>
    """)

    audio = gr.Audio(value="wav/audio.wav", label="Generated Audio Example", type="filepath")

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

---

## 🛠️ Online Inference

You can upload your own samples, or try the quick examples provided below.

</div>
    """)

    with gr.Column():
        input_audio = gr.Audio(label="🗣️ Speech Input", type="filepath")
        btn = gr.Button("🎵 Generate Audio!", variant="primary")
        output_audio = gr.Audio(label="🎧️ Generated Audio", type="filepath")
        btn.click(fn=infer, inputs=input_audio, outputs=output_audio)

    gr.Markdown("""
<div style="text-align: left; padding: 10px;">

## 🎯 Quick Examples

</div>
    """)

    display_caption = gr.Textbox(label="📝 Caption", visible=False)

    with gr.Tabs():
        with gr.Tab("VITS Generated Speech"):
            gr.Examples(
                examples=[
                    ["wav/vits/1.wav", "A cat meowing and young female speaking"],
                    ["wav/vits/2.wav", "Sustained industrial engine noise"],
                    ["wav/vits/3.wav", "A woman talks and a baby whispers"],
                    ["wav/vits/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/vits/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/vits/6.wav", "A man speaking as birds are chirping"],
                    ["wav/vits/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/vits/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/vits/9.wav", "Several church bells ringing"],
                    ["wav/vits/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )

        with gr.Tab("Real Human Speech"):
            gr.Examples(
                examples=[
                    ["wav/human/1.wav", "A cat meowing and young female speaking"],
                    ["wav/human/2.wav", "Sustained industrial engine noise"],
                    ["wav/human/3.wav", "A woman talks and a baby whispers"],
                    ["wav/human/4.wav", "A man speaks followed by a toilet flush"],
                    ["wav/human/5.wav", "It is raining and thundering, and then a man speaks"],
                    ["wav/human/6.wav", "A man speaking as birds are chirping"],
                    ["wav/human/7.wav", "A muffled man talking as a goat baas before and after two goats baaing in the distance while wind blows into a microphone"],
                    ["wav/human/8.wav", "Birds chirping and a horse neighing"],
                    ["wav/human/9.wav", "Several church bells ringing"],
                    ["wav/human/10.wav", "A telephone rings with bell sounds"],
                ],
                inputs=[input_audio, display_caption],
                label="Click examples below to try!",
                cache_examples=False,
                examples_per_page=10,
            )


demo.launch()