import base64
import json
import os
import re
import time
import uuid
from functools import lru_cache
from typing import Optional

import numpy as np
from loguru import logger

# Import HiggsAudio components
from higgs_audio.data_types import AudioContent, ChatMLSample, Message
from higgs_audio.serve.serve_engine import HiggsAudioServeEngine


def process_text_output(text_output: str):
    # Collapse runs of consecutive <|AUDIO_OUT|> tokens into a single <|AUDIO_OUT|>
    text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
    return text_output


def check_return_audio(audio_wv: np.ndarray):
    # Warn if the returned audio is entirely silent
    if np.all(audio_wv == 0):
        logger.warning("Audio is silent, returning None")


def load_voice_presets():
    """Load the voice presets from the voice_examples directory."""
    try:
        with open(
            os.path.join(os.path.dirname(__file__), "examples", "audios", "config.json"),
            "r",
        ) as f:
            voice_dict = json.load(f)
        voice_presets = dict(voice_dict)
        voice_presets["EMPTY"] = "No reference voice"
        logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
        return voice_presets
    except FileNotFoundError:
        logger.warning("Voice examples config file not found. Using empty voice presets.")
        return {"EMPTY": "No reference voice"}
    except Exception as e:
        logger.error(f"Error loading voice presets: {e}")
        return {"EMPTY": "No reference voice"}


SAMPLE_RATE = 24000
DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
VOICE_PRESETS = load_voice_presets()


def initialize_engine(model_path, audio_tokenizer_path) -> HiggsAudioServeEngine:
    engine = HiggsAudioServeEngine(
        model_name_or_path=model_path,
        audio_tokenizer_name_or_path=audio_tokenizer_path,
        device="cuda",
    )
    return engine


def get_voice_preset(voice_preset):
    """Get the audio path and transcript for a given voice preset."""
    preset_dir = os.path.join(os.path.dirname(__file__), "examples", "audios")
    voice_path = os.path.join(preset_dir, VOICE_PRESETS[voice_preset]["audio_file"])
    if not os.path.exists(voice_path):
        logger.warning(f"Voice preset file not found: {voice_path}")
        return None, "Voice preset not found"
    text = VOICE_PRESETS[voice_preset]["transcript"]
    return voice_path, text


def normalize_chinese_punctuation(text):
    """Convert Chinese (full-width) punctuation marks to English (half-width) equivalents."""
    # Mapping of Chinese punctuation to English punctuation
    chinese_to_english_punct = {
        ",": ", ",  # comma
        "。": ".",  # period
        ":": ":",  # colon
        ";": ";",  # semicolon
        "?": "?",  # question mark
        "!": "!",  # exclamation mark
        "(": "(",  # left parenthesis
        ")": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }
    # Replace each Chinese punctuation mark with its English counterpart
    for zh_punct, en_punct in chinese_to_english_punct.items():
        text = text.replace(zh_punct, en_punct)
    return text
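
# Illustrative sketch (not used by the app): a quick sanity check of the mapping
# above. The sample string is a hypothetical input chosen for this example.
def _example_normalize_chinese_punctuation():
    sample = "你好,世界。"  # full-width comma and period
    # "," becomes ", " and "。" becomes "."
    assert normalize_chinese_punctuation(sample) == "你好, 世界."
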
def normalize_text(transcript: str):
    transcript = normalize_chinese_punctuation(transcript)
    # Other normalizations (e.g., parentheses and other symbols); these will be
    # improved in the future.
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")

    # Map informal event tags to the canonical tags the model expects
    for tag, replacement in [
        ("[laugh]", "[Laughter]"),
        ("[humming start]", "[Humming]"),
        ("[humming end]", "[Humming]"),
        ("[music start]", "[Music]"),
        ("[music end]", "[Music]"),
        ("[music]", "[Music]"),
        ("[sing start]", "[Singing]"),
        ("[sing end]", "[Singing]"),
        ("[applause]", "[Applause]"),
        ("[cheering]", "[Cheering]"),
        ("[cough]", "[Cough]"),
    ]:
        transcript = transcript.replace(tag, replacement)

    # Collapse repeated whitespace and drop empty lines
    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()

    # Append a period if the transcript does not already end with punctuation.
    # Note: the list must not contain the empty string, since str.endswith("")
    # is always True and would prevent the period from ever being added.
    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'"]]):
        transcript += "."

    return transcript


@lru_cache(maxsize=20)
def encode_audio_file(file_path):
    """Encode an audio file to base64."""
    with open(file_path, "rb") as audio_file:
        return base64.b64encode(audio_file.read()).decode("utf-8")


def prepare_chatml_sample(
    voice_preset: str,
    text: str,
    reference_audio: Optional[str] = None,
    reference_text: Optional[str] = None,
    system_prompt: str = "",
):
    """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
    messages = []

    # Add system message if provided
    if len(system_prompt) > 0:
        messages.append(Message(role="system", content=system_prompt))

    # Resolve the reference audio, either custom or from a preset
    audio_base64 = None
    ref_text = ""
    if reference_audio:
        # Custom reference audio
        audio_base64 = encode_audio_file(reference_audio)
        ref_text = reference_text or ""
    elif voice_preset != "EMPTY":
        # Voice preset
        voice_path, ref_text = get_voice_preset(voice_preset)
        if voice_path is None:
            logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
        else:
            audio_base64 = encode_audio_file(voice_path)

    # Only add the reference turns if we actually have reference audio: the user
    # turn carries the reference transcript and the assistant turn carries the
    # reference audio, forming a completed voice-cloning exchange.
    if audio_base64 is not None:
        messages.append(Message(role="user", content=ref_text))
        audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
        messages.append(Message(role="assistant", content=[audio_content]))

    # Add the main user message
    text = normalize_text(text)
    messages.append(Message(role="user", content=text))

    return ChatMLSample(messages=messages)
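
# Illustrative sketch (not wired into the app): with voice_preset="EMPTY" and no
# reference audio, the resulting sample holds only the system turn and the
# normalized user turn. Assumes ChatMLSample and Message expose their fields as
# attributes, as the constructor calls above suggest.
def _example_prepare_chatml_sample():
    sample = prepare_chatml_sample(
        voice_preset="EMPTY",
        text="Hello world",
        system_prompt="Generate audio following instruction.",
    )
    assert len(sample.messages) == 2  # system turn + main user turn
    assert sample.messages[-1].content == "Hello world."  # normalize_text added the period
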
def text_to_speech(
    engine,
    text,
    system_prompt="",
    voice_preset="EMPTY",
    reference_audio=None,
    reference_text=None,
    max_completion_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    stop_strings=None,
    ras_win_len=7,
    ras_win_max_num_repeat=2,
):
    """
    Convert text to speech using HiggsAudioServeEngine.

    Args:
        engine: An initialized HiggsAudioServeEngine
        text: The text to convert to speech
        system_prompt: System prompt to guide the model
        voice_preset: The voice preset to use (or "EMPTY" for no preset)
        reference_audio: Optional path to a reference audio file
        reference_text: Optional transcript of the reference audio
        max_completion_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature for generation
        top_p: Top-p sampling parameter
        top_k: Top-k sampling parameter
        stop_strings: Table with a "stops" column of stop strings (e.g., a Gradio
            Dataframe value); defaults to DEFAULT_STOP_STRINGS when None
        ras_win_len: Window length for repetition-avoidance sampling
        ras_win_max_num_repeat: Maximum number of repetitions allowed in the window

    Returns:
        Tuple of (generated_text, (sample_rate, audio_data)) where audio_data is
        an int16 numpy array, or (generated_text, None) if no audio was produced.
    """
    try:
        # Prepare the ChatML sample
        chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)

        # Convert the stop strings format
        if stop_strings is None:
            stop_list = DEFAULT_STOP_STRINGS
        else:
            stop_list = [s for s in stop_strings["stops"] if s.strip()]

        request_id = f"tts-playground-{str(uuid.uuid4())}"
        start_time = time.time()

        # Generate using the engine; non-positive top_k / ras_win_len disable
        # those features, and the repeat cap is floored at the window length.
        response = engine.generate(
            chat_ml_sample=chatml_sample,
            max_new_tokens=max_completion_tokens,
            temperature=temperature,
            top_k=top_k if top_k > 0 else None,
            top_p=top_p,
            stop_strings=stop_list,
            ras_win_len=ras_win_len if ras_win_len > 0 else None,
            ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
        )
        generation_time = time.time() - start_time
        logger.info(f"{request_id} finished in {generation_time:.2f}s")

        # Process the response
        text_output = process_text_output(response.generated_text)

        if response.audio is not None:
            # Convert the float waveform in [-1, 1] to int16 for Gradio
            audio_data = (response.audio * 32767).astype(np.int16)
            check_return_audio(audio_data)
            return text_output, (response.sampling_rate, audio_data)
        else:
            logger.warning("No audio generated")
            return text_output, None

    except Exception as e:
        error_msg = f"Error generating speech: {e}"
        logger.error(error_msg)
        return f"❌ {error_msg}", None
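

# Usage sketch. The checkpoint ids below are assumptions for illustration;
# substitute the model and audio tokenizer you actually serve. Requires a CUDA
# device, since initialize_engine pins the engine to "cuda".
if __name__ == "__main__":
    demo_engine = initialize_engine(
        model_path="bosonai/higgs-audio-v2-generation-3B-base",  # assumed checkpoint id
        audio_tokenizer_path="bosonai/higgs-audio-v2-tokenizer",  # assumed tokenizer id
    )
    generated_text, audio = text_to_speech(
        demo_engine,
        text="The sun rises in the east and sets in the west.",
        voice_preset="EMPTY",
    )
    if audio is not None:
        sample_rate, audio_data = audio
        logger.info(f"Generated {audio_data.shape[0] / sample_rate:.2f}s of audio at {sample_rate} Hz")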