# NOTE: "Spaces: Running" header below is Hugging Face page residue from the
# scrape that produced this file; it is not part of the program.
# Third-party and standard-library imports for the TTS client helpers.
import os
from dotenv import load_dotenv
import random
from gradio_client import Client, handle_file, file
from huggingface_hub.constants import HF_TOKEN_PATH  # NOTE(review): appears unused in this file — confirm before removing
from pydub import AudioSegment
import os.path

# Load variables from a local .env file into the process environment
# before any os.getenv() call below.
load_dotenv()

# Comma-separated pool of Hugging Face tokens used to spread ZeroGPU quota
# across requests. When the variable is unset this yields [""], which
# get_zerogpu_token() treats as an empty pool.
ZEROGPU_TOKENS = os.getenv("ZEROGPU_TOKENS", "").split(",")
def get_zerogpu_token():
    """Return a random token from the ZEROGPU_TOKENS pool.

    Falls back to the HF_TOKEN environment variable when the pool is
    empty (unset ZEROGPU_TOKENS parses to [""]).
    """
    pool_is_empty = not ZEROGPU_TOKENS or ZEROGPU_TOKENS == [""]
    return os.getenv("HF_TOKEN") if pool_is_empty else random.choice(ZEROGPU_TOKENS)
# Registry of selectable TTS models: UI/model name -> router provider/model ids.
model_mapping = {
    "spark-tts": {
        "provider": "spark",
        "model": "spark-tts",
    },
    "cosyvoice-2.0": {
        "provider": "cosyvoice",
        "model": "cosyvoice_2_0",
    },
    "index-tts": {
        "provider": "bilibili",
        "model": "index-tts",
    },
    "maskgct": {
        "provider": "amphion",
        "model": "maskgct",
    },
    "gpt-sovits-v2-pro-plus": {
        "provider": "gpt-sovits",
        "model": "gpt-sovits-v2-pro-plus",
    },
}

# TTS router endpoint and the headers sent with every router request.
url = "https://tts-agi-tts-router-v2.hf.space/tts"
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": f'Bearer {os.getenv("HF_TOKEN")}',
}
# Placeholder payload illustrating the router's expected request shape.
data = {"text": "string", "provider": "string", "model": "string"}
def set_client_for_session(space: str, user_token=None):
    """Build a gradio_client.Client for *space*.

    When a user token is supplied it is forwarded via the X-IP-Token
    header (so the space charges the caller's quota); otherwise the
    client authenticates with a pooled ZeroGPU token.
    """
    if user_token is not None:
        return Client(space, headers={"X-IP-Token": user_token})
    return Client(space, hf_token=get_zerogpu_token())
def predict_index_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with the IndexTTS space, cloning the voice from
    *reference_audio_path*.

    Args:
        text: Text to synthesize.
        user_token: Optional per-user token forwarded as X-IP-Token.
        reference_audio_path: Path to the voice-prompt audio (required).

    Returns:
        Path to the generated audio file as reported by the space.

    Raises:
        ValueError: if reference_audio_path is missing.
    """
    # Fail fast before building a client: IndexTTS is a voice-cloning
    # model and cannot run without a reference prompt.
    if not reference_audio_path:
        raise ValueError("index-tts requires reference_audio_path")
    client = set_client_for_session("kemuriririn/IndexTTS", user_token=user_token)
    prompt = handle_file(reference_audio_path)
    result = client.predict(
        prompt=prompt,
        text=text,
        api_name="/gen_single"
    )
    # The space returns either a plain path string or a dict-like payload
    # whose "value" field holds the path; isinstance is the idiomatic check.
    if not isinstance(result, str):
        result = result.get("value")
    print("index-tts result:", result)
    return result
def predict_spark_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with the SparkTTS space.

    A reference audio file, when given, is passed to both the upload and
    record prompt slots to enable voice cloning; otherwise no prompt is sent.
    """
    client = set_client_for_session("thunnai/SparkTTS", user_token=user_token)
    prompt_wav = handle_file(reference_audio_path) if reference_audio_path else None
    result = client.predict(
        text=text,
        prompt_text=text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        api_name="/voice_clone"
    )
    print("spark-tts result:", result)
    return result
def predict_cosyvoice_tts(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with CosyVoice 2.0, cloning the reference voice.

    The reference audio is first transcribed via the space's recognition
    endpoint; that transcript is then supplied as prompt_text alongside
    the prompt waveform for generation.

    Raises:
        ValueError: if reference_audio_path is missing.
    """
    if not reference_audio_path:
        raise ValueError("cosyvoice-2.0 requires reference_audio_path")
    client = set_client_for_session("kemuriririn/CosyVoice2-0.5B", user_token=user_token)
    prompt_wav = handle_file(reference_audio_path)
    # Transcribe the reference audio first — CosyVoice expects the prompt
    # text that matches the prompt waveform. Reuse the handle_file()
    # wrapper instead of the deprecated file() alias used previously.
    recog_result = client.predict(
        prompt_wav=prompt_wav,
        api_name="/prompt_wav_recognition"
    )
    print("cosyvoice-2.0 prompt_wav_recognition result:", recog_result)
    prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
    result = client.predict(
        tts_text=text,
        prompt_text=prompt_text,
        prompt_wav_upload=prompt_wav,
        prompt_wav_record=prompt_wav,
        seed=0,
        stream=False,
        api_name="/generate_audio"
    )
    print("cosyvoice-2.0 result:", result)
    return result
def predict_maskgct(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with the amphion/maskgct space, cloning the voice
    from *reference_audio_path*.

    Raises:
        ValueError: if reference_audio_path is missing.
    """
    # Validate before constructing a client so a bad call does no network work.
    if not reference_audio_path:
        raise ValueError("maskgct requires reference_audio_path")
    client = set_client_for_session("amphion/maskgct", user_token=user_token)
    prompt_wav = handle_file(reference_audio_path)
    result = client.predict(
        prompt_wav=prompt_wav,
        target_text=text,
        target_len=-1,    # -1: let the space infer the output length
        n_timesteps=25,   # generation steps used by the space's demo defaults
        api_name="/predict"
    )
    print("maskgct result:", result)
    return result
def predict_gpt_sovits_v2(text, user_token=None, reference_audio_path=None):
    """Synthesize *text* with the GPT-SoVITS-v2 space, cloning the voice
    from *reference_audio_path*.

    Raises:
        ValueError: if reference_audio_path is missing.
    """
    if not reference_audio_path:
        raise ValueError("GPT-SoVITS-v2 requires reference_audio_path")
    client = set_client_for_session("kemuriririn/GPT-SoVITS-v2", user_token=user_token)
    result = client.predict(
        # handle_file() replaces the deprecated file() alias, matching the
        # other predictors in this module.
        ref_wav_path=handle_file(reference_audio_path),
        prompt_text="",
        prompt_language="English",
        text=text,
        text_language="English",
        how_to_cut="Slice once every 4 sentences",
        top_k=15,
        top_p=1,
        temperature=1,
        ref_free=False,
        speed=1,
        if_freeze=False,
        inp_refs=[],
        api_name="/get_tts_wav"
    )
    print("gpt-sovits-v2 result:", result)
    return result
def normalize_audio_volume(audio_path):
    """Peak-normalize an audio file's volume.

    Writes the normalized audio next to the original as
    "<name>_normalized<ext>" and returns that path.

    Args:
        audio_path: Path to the source audio file.

    Returns:
        Path of the newly written, normalized audio file.
    """
    stem, ext = os.path.splitext(audio_path)
    # Fall back to wav when the input has no extension so pydub's export
    # format is never an empty string (which would raise).
    export_format = ext.replace('.', '') or "wav"
    normalized_path = f"{stem}_normalized{ext or '.wav'}"
    sound = AudioSegment.from_file(audio_path)
    # pydub's normalize() scales the signal so its peak reaches full scale.
    normalized_sound = sound.normalize()
    normalized_sound.export(normalized_path, format=export_format)
    return normalized_path
def predict_tts(text, model, user_token=None, reference_audio_path=None):
    """Route a TTS request to the model-specific predictor, then
    peak-normalize the resulting audio file and return its path.

    Raises:
        ValueError: if *model* has no dedicated predictor.
    """
    print(f"Predicting TTS for {model}, user_token: {user_token}, reference_audio_path: {reference_audio_path}")
    # These models bypass the shared router and call their spaces directly.
    dispatch = {
        "index-tts": predict_index_tts,
        "spark-tts": predict_spark_tts,
        "cosyvoice-2.0": predict_cosyvoice_tts,
        "maskgct": predict_maskgct,
        "gpt-sovits-v2-pro-plus": predict_gpt_sovits_v2,
    }
    predictor = dispatch.get(model)
    if predictor is None:
        raise ValueError(f"Model {model} not found")
    result = predictor(text, user_token, reference_audio_path)
    # Normalize loudness so every model's output plays at a comparable volume.
    return normalize_audio_volume(result)
if __name__ == "__main__":
    # Intentionally empty entry point: this module is meant to be imported
    # and driven by another component.
    pass