#!/usr/bin/env python3
"""
ZipVoice Gradio Web Interface for HuggingFace Spaces
"""
import json
import os
import tempfile
import gradio as gr
import torch
# Import ZipVoice components
from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import EmiliaTokenizer
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.feature import VocosFbank
from zipvoice.bin.infer_zipvoice import generate_sentence
from lhotse.utils import fix_random_seed
# Global variables for caching models
_models_cache = {}
_tokenizer_cache = None
_vocoder_cache = None
_feature_extractor_cache = None
_sampling_rate_cache = {}
def load_models_and_components(model_name: str):
"""Load and cache models, tokenizer, vocoder, and feature extractor."""
    global _models_cache, _tokenizer_cache, _vocoder_cache, _feature_extractor_cache, _sampling_rate_cache
# Set device (CPU for Spaces, but could be adapted for GPU)
device = torch.device("cpu")
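    # A minimal sketch if this Space is ever moved to GPU hardware (assumes a
    # CUDA-enabled torch build; untested here):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")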
if model_name not in _models_cache:
print(f"Loading {model_name} model...")
# Model directory mapping
model_dir_map = {
"zipvoice": "zipvoice",
"zipvoice_distill": "zipvoice_distill",
}
huggingface_repo = "k2-fsa/ZipVoice"
# Download model files from HuggingFace
from huggingface_hub import hf_hub_download
model_ckpt = hf_hub_download(
huggingface_repo, filename=f"{model_dir_map[model_name]}/model.pt"
)
model_config_path = hf_hub_download(
huggingface_repo, filename=f"{model_dir_map[model_name]}/model.json"
)
token_file = hf_hub_download(
huggingface_repo, filename=f"{model_dir_map[model_name]}/tokens.txt"
)
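        # hf_hub_download caches files in the local Hugging Face cache
        # (~/.cache/huggingface by default), so repeated loads within the
        # same environment do not re-download the checkpoint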
# Load tokenizer (cache it)
if _tokenizer_cache is None:
_tokenizer_cache = EmiliaTokenizer(token_file=token_file)
tokenizer = _tokenizer_cache
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
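        # The tokenizer's vocab size and pad id are model hyperparameters, so
        # they are passed to the constructor alongside the model.json config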
        # Load model configuration
        with open(model_config_path, "r") as f:
            model_config = json.load(f)
# Create model
if model_name == "zipvoice":
model = ZipVoice(**model_config["model"], **tokenizer_config)
else:
model = ZipVoiceDistill(**model_config["model"], **tokenizer_config)
# Load model weights
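        # strict=True fails fast if the checkpoint keys do not match the
        # instantiated architecture, instead of silently skipping weights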
load_checkpoint(filename=model_ckpt, model=model, strict=True)
model = model.to(device)
model.eval()
        _models_cache[model_name] = model
        # Also cache the sampling rate: on cache hits model.json is not
        # re-read, so model_config would be undefined at return time
        _sampling_rate_cache[model_name] = model_config["feature"]["sampling_rate"]
# Load vocoder (cache it)
if _vocoder_cache is None:
from vocos import Vocos
_vocoder_cache = Vocos.from_pretrained("charactr/vocos-mel-24khz")
_vocoder_cache = _vocoder_cache.to(device)
_vocoder_cache.eval()
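    # Vocos turns the model's predicted mel features back into a waveform;
    # this checkpoint targets 24 kHz audio, matching the sampling rate in
    # the ZipVoice model config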
# Load feature extractor (cache it)
if _feature_extractor_cache is None:
_feature_extractor_cache = VocosFbank()
    return (_models_cache[model_name], _tokenizer_cache,
            _vocoder_cache, _feature_extractor_cache,
            _sampling_rate_cache[model_name])
def synthesize_speech_gradio(
text: str,
prompt_audio_file,
prompt_text: str,
model_name: str,
speed: float
):
"""Synthesize speech using ZipVoice for Gradio interface."""
if not text.strip():
return None, "Error: Please enter text to synthesize."
if prompt_audio_file is None:
return None, "Error: Please upload a prompt audio file."
if not prompt_text.strip():
return None, "Error: Please enter the transcription of the prompt audio."
try:
# Set random seed for reproducibility
fix_random_seed(666)
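        # A fixed seed makes the stochastic flow-matching sampler
        # deterministic, so identical inputs reproduce the same audio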
# Load models and components
model, tokenizer, vocoder, feature_extractor, sampling_rate = load_models_and_components(model_name)
device = torch.device("cpu")
        # Persist the uploaded prompt audio (raw bytes from gr.File) to a
        # temporary wav file, since generate_sentence expects a file path
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_audio.write(prompt_audio_file)
            temp_audio_path = temp_audio.name
# Create temporary output file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_output:
output_path = temp_output.name
print(f"Synthesizing: '{text}' using {model_name}")
print(f"Prompt: {prompt_text}")
print(f"Speed: {speed}")
# Generate speech
with torch.inference_mode():
metrics = generate_sentence(
save_path=output_path,
prompt_text=prompt_text,
prompt_wav=temp_audio_path,
text=text,
model=model,
vocoder=vocoder,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
device=device,
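                # The distilled model is trained for few-step inference, so
                # it runs fewer ODE steps with a stronger guidance scale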
num_step=16 if model_name == "zipvoice" else 8,
guidance_scale=1.0 if model_name == "zipvoice" else 3.0,
speed=speed,
t_shift=0.5,
target_rms=0.1,
feat_scale=0.1,
sampling_rate=sampling_rate,
max_duration=100,
remove_long_sil=False,
)
        # Clean up the temporary prompt file. Keep output_path: the gr.Audio
        # output uses type="filepath", so it needs a readable path on disk,
        # not raw bytes
        os.unlink(temp_audio_path)
        success_msg = f"Synthesis completed! Duration: {metrics['wav_seconds']:.2f}s, RTF: {metrics['rtf']:.2f}"
        return output_path, success_msg
except Exception as e:
error_msg = f"Error during synthesis: {str(e)}"
print(error_msg)
return None, error_msg
def create_gradio_interface():
"""Create the Gradio web interface."""
# Custom CSS for better styling
css = """
.gradio-container {
max-width: 1200px;
margin: auto;
}
.title {
text-align: center;
color: #2563eb;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 1em;
}
.subtitle {
text-align: center;
color: #64748b;
font-size: 1.2em;
margin-bottom: 2em;
}
"""
with gr.Blocks(title="ZipVoice - Zero-Shot Text-to-Speech", css=css) as interface:
gr.HTML("""
<div class="title">🎵 ZipVoice</div>
<div class="subtitle">Fast and High-Quality Zero-Shot Text-to-Speech with Flow Matching</div>
""")
with gr.Row():
with gr.Column(scale=2):
text_input = gr.Textbox(
label="Text to Synthesize",
placeholder="Enter the text you want to convert to speech...",
lines=3,
value="這是一則語音測試"
)
with gr.Row():
model_dropdown = gr.Dropdown(
choices=["zipvoice", "zipvoice_distill"],
value="zipvoice",
label="Model",
info="zipvoice_distill is faster but slightly less accurate"
)
speed_slider = gr.Slider(
minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="Speed",
info="1.0 = normal speed, >1.0 = faster, <1.0 = slower"
)
                # gr.File does not accept an info= kwarg, so the upload hint
                # lives in the label instead
                prompt_audio = gr.File(
                    label="Prompt Audio (short clip, 1-3 seconds recommended, in the voice style to mimic)",
                    file_types=["audio"],
                    type="binary"
                )
prompt_text = gr.Textbox(
label="Prompt Transcription",
placeholder="Enter the exact transcription of the prompt audio...",
lines=2,
info="This should match what is spoken in the audio file"
)
generate_btn = gr.Button(
"🎵 Generate Speech",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
output_audio = gr.Audio(
label="Generated Speech",
type="filepath"
)
status_text = gr.Textbox(
label="Status",
interactive=False,
lines=3
)
gr.Examples(
examples=[
["Hello world! This is a test of ZipVoice.", None, "Hello world! This is a test.", "zipvoice", 1.0],
["今天天氣真好,我們去公園散步吧!", None, "今天天氣真好", "zipvoice", 1.0],
["The quick brown fox jumps over the lazy dog.", None, "The quick brown fox", "zipvoice_distill", 1.2],
],
inputs=[text_input, prompt_audio, prompt_text, model_dropdown, speed_slider],
label="Quick Examples"
)
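        # Note: the examples ship without prompt audio (None placeholders),
        # so a clip still has to be uploaded before generation can succeed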
# Event handling
generate_btn.click(
fn=synthesize_speech_gradio,
inputs=[text_input, prompt_audio, prompt_text, model_dropdown, speed_slider],
outputs=[output_audio, status_text]
)
# Footer
gr.HTML("""
<div style="text-align: center; margin-top: 2em; color: #64748b; font-size: 0.9em;">
<p>Powered by <a href="https://github.com/k2-fsa/ZipVoice" target="_blank">ZipVoice</a> |
Built with <a href="https://gradio.app" target="_blank">Gradio</a></p>
<p>Upload a short audio clip as prompt, and ZipVoice will synthesize speech in that voice style!</p>
</div>
""")
return interface
if __name__ == "__main__":
# Create and launch the interface
interface = create_gradio_interface()
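    # Bind to 0.0.0.0 so the app is reachable from outside the Spaces
    # container; 7860 is Gradio's conventional default port on Spaces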
interface.launch(
server_name="0.0.0.0",
server_port=int(os.environ.get("PORT", 7860)),
show_error=True
)