kemuriririn committed on
Commit
f2186e3
·
1 Parent(s): f6fd8f5

(wip)replace space

Browse files
Files changed (1) hide show
  1. tts.py +7 -225
tts.py CHANGED
@@ -1,18 +1,5 @@
1
- # TODO: V2 of TTS Router
2
- # Currently just use current TTS router.
3
  import os
4
- import json
5
  from dotenv import load_dotenv
6
- import fal_client
7
- import requests
8
- import time
9
- import io
10
-
11
- from gradio_client import handle_file
12
- from pyht import Client as PyhtClient
13
- from pyht.client import TTSOptions
14
- import base64
15
- import tempfile
16
  import random
17
 
18
  load_dotenv()
@@ -25,54 +12,14 @@ def get_zerogpu_token():
25
 
26
 
27
  model_mapping = {
28
- # "eleven-multilingual-v2": {
29
- # "provider": "elevenlabs",
30
- # "model": "eleven_multilingual_v2",
31
- # },
32
- # "eleven-turbo-v2.5": {
33
- # "provider": "elevenlabs",
34
- # "model": "eleven_turbo_v2_5",
35
- # },
36
- # "eleven-flash-v2.5": {
37
- # "provider": "elevenlabs",
38
- # "model": "eleven_flash_v2_5",
39
- # },
40
  "spark-tts": {
41
  "provider": "spark",
42
  "model": "spark-tts",
43
  },
44
- # "playht-2.0": {
45
- # "provider": "playht",
46
- # "model": "PlayHT2.0",
47
- # },
48
- # "styletts2": {
49
- # "provider": "styletts",
50
- # "model": "styletts2",
51
- # },
52
  "cosyvoice-2.0": {
53
  "provider": "cosyvoice",
54
  "model": "cosyvoice_2_0",
55
  },
56
- # "papla-p1": {
57
- # "provider": "papla",
58
- # "model": "papla_p1",
59
- # },
60
- # "hume-octave": {
61
- # "provider": "hume",
62
- # "model": "octave",
63
- # },
64
- # "minimax-02-hd": {
65
- # "provider": "minimax",
66
- # "model": "speech-02-hd",
67
- # },
68
- # "minimax-02-turbo": {
69
- # "provider": "minimax",
70
- # "model": "speech-02-turbo",
71
- # },
72
- # "lanternfish-1": {
73
- # "provider": "lanternfish",
74
- # "model": "lanternfish-1",
75
- # },
76
  "index-tts": {
77
  "provider": "bilibili",
78
  "model": "index-tts",
@@ -95,110 +42,9 @@ headers = {
95
  data = {"text": "string", "provider": "string", "model": "string"}
96
 
97
 
98
- def predict_csm(script):
99
- result = fal_client.subscribe(
100
- "fal-ai/csm-1b",
101
- arguments={
102
- # "scene": [{
103
- # "text": "Hey how are you doing.",
104
- # "speaker_id": 0
105
- # }, {
106
- # "text": "Pretty good, pretty good.",
107
- # "speaker_id": 1
108
- # }, {
109
- # "text": "I'm great, so happy to be speaking to you.",
110
- # "speaker_id": 0
111
- # }]
112
- "scene": script
113
- },
114
- with_logs=True,
115
- )
116
- return requests.get(result["audio"]["url"]).content
117
-
118
-
119
- def predict_playdialog(script):
120
- # Initialize the PyHT client
121
- pyht_client = PyhtClient(
122
- user_id=os.getenv("PLAY_USERID"),
123
- api_key=os.getenv("PLAY_SECRETKEY"),
124
- )
125
-
126
- # Define the voices
127
- voice_1 = "s3://voice-cloning-zero-shot/baf1ef41-36b6-428c-9bdf-50ba54682bd8/original/manifest.json"
128
- voice_2 = "s3://voice-cloning-zero-shot/e040bd1b-f190-4bdb-83f0-75ef85b18f84/original/manifest.json"
129
-
130
- # Convert script format from CSM to PlayDialog format
131
- if isinstance(script, list):
132
- # Process script in CSM format (list of dictionaries)
133
- text = ""
134
- for turn in script:
135
- speaker_id = turn.get("speaker_id", 0)
136
- prefix = "Host 1:" if speaker_id == 0 else "Host 2:"
137
- text += f"{prefix} {turn['text']}\n"
138
- else:
139
- # If it's already a string, use as is
140
- text = script
141
-
142
- # Set up TTSOptions
143
- options = TTSOptions(
144
- voice=voice_1, voice_2=voice_2, turn_prefix="Host 1:", turn_prefix_2="Host 2:"
145
- )
146
-
147
- # Generate audio using PlayDialog
148
- audio_chunks = []
149
- for chunk in pyht_client.tts(text, options, voice_engine="PlayDialog"):
150
- audio_chunks.append(chunk)
151
-
152
- # Combine all chunks into a single audio file
153
- return b"".join(audio_chunks)
154
-
155
-
156
- def predict_dia(script):
157
- # Convert script to the required format for Dia
158
- if isinstance(script, list):
159
- # Convert from list of dictionaries to formatted string
160
- formatted_text = ""
161
- for turn in script:
162
- speaker_id = turn.get("speaker_id", 0)
163
- speaker_tag = "[S1]" if speaker_id == 0 else "[S2]"
164
- text = turn.get("text", "").strip().replace("[S1]", "").replace("[S2]", "")
165
- formatted_text += f"{speaker_tag} {text} "
166
- text = formatted_text.strip()
167
- else:
168
- # If it's already a string, use as is
169
- text = script
170
- print(text)
171
- # Make a POST request to initiate the dialogue generation
172
- headers = {
173
- # "Content-Type": "application/json",
174
- "Authorization": f"Bearer {get_zerogpu_token()}"
175
- }
176
-
177
- response = requests.post(
178
- "https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue",
179
- headers=headers,
180
- json={"data": [text]},
181
- )
182
-
183
- # Extract the event ID from the response
184
- event_id = response.json()["event_id"]
185
-
186
- # Make a streaming request to get the generated dialogue
187
- stream_url = f"https://mrfakename-dia-1-6b.hf.space/gradio_api/call/generate_dialogue/{event_id}"
188
-
189
- # Use a streaming request to get the audio data
190
- with requests.get(stream_url, headers=headers, stream=True) as stream_response:
191
- # Process the streaming response
192
- for line in stream_response.iter_lines():
193
- if line:
194
- if line.startswith(b"data: ") and not line.startswith(b"data: null"):
195
- audio_data = line[6:]
196
- return requests.get(json.loads(audio_data)[0]["url"]).content
197
-
198
-
199
  def predict_index_tts(text, reference_audio_path=None):
200
  from gradio_client import Client, handle_file
201
- client = Client("IndexTeam/IndexTTS")
202
  if reference_audio_path:
203
  prompt = handle_file(reference_audio_path)
204
  else:
@@ -216,7 +62,7 @@ def predict_index_tts(text, reference_audio_path=None):
216
 
217
  def predict_spark_tts(text, reference_audio_path=None):
218
  from gradio_client import Client, handle_file
219
- client = Client("thunnai/SparkTTS")
220
  prompt_wav = None
221
  if reference_audio_path:
222
  prompt_wav = handle_file(reference_audio_path)
@@ -233,7 +79,7 @@ def predict_spark_tts(text, reference_audio_path=None):
233
 
234
  def predict_cosyvoice_tts(text, reference_audio_path=None):
235
  from gradio_client import Client, file, handle_file
236
- client = Client("https://iic-cosyvoice2-0-5b.ms.show/")
237
  if not reference_audio_path:
238
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
239
  prompt_wav = handle_file(reference_audio_path)
@@ -246,7 +92,7 @@ def predict_cosyvoice_tts(text, reference_audio_path=None):
246
  prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
247
  result = client.predict(
248
  tts_text=text,
249
- mode_checkbox_group="3s极速复刻",
250
  prompt_text=prompt_text,
251
  prompt_wav_upload=prompt_wav,
252
  prompt_wav_record=prompt_wav,
@@ -304,13 +150,7 @@ def predict_tts(text, model, reference_audio_path=None):
304
  global client
305
  print(f"Predicting TTS for {model}")
306
  # Exceptions: special models that shouldn't be passed to the router
307
- if model == "csm-1b":
308
- return predict_csm(text)
309
- elif model == "playdialog-1.0":
310
- return predict_playdialog(text)
311
- elif model == "dia-1.6b":
312
- return predict_dia(text)
313
- elif model == "index-tts":
314
  return predict_index_tts(text, reference_audio_path)
315
  elif model == "spark-tts":
316
  return predict_spark_tts(text, reference_audio_path)
@@ -321,66 +161,8 @@ def predict_tts(text, model, reference_audio_path=None):
321
  elif model == "gpt-sovits-v2":
322
  return predict_gpt_sovits_v2(text, reference_audio_path)
323
 
324
- if not model in model_mapping:
325
- raise ValueError(f"Model {model} not found")
326
-
327
- # 构建请求体
328
- payload = {
329
- "text": text,
330
- "provider": model_mapping[model]["provider"],
331
- "model": model_mapping[model]["model"],
332
- }
333
- # 仅支持音色克隆的模型传递参考音色
334
- supports_reference = model in [
335
- "styletts2", "eleven-multilingual-v2", "eleven-turbo-v2.5", "eleven-flash-v2.5"
336
- ]
337
- if reference_audio_path and supports_reference:
338
- with open(reference_audio_path, "rb") as f:
339
- audio_bytes = f.read()
340
- audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
341
- # 不同模型参考音色字段不同
342
- if model == "styletts2":
343
- payload["reference_speaker"] = audio_b64
344
- else: # elevenlabs 系列
345
- payload["reference_audio"] = audio_b64
346
-
347
- result = requests.post(
348
- url,
349
- headers=headers,
350
- data=json.dumps(payload),
351
- )
352
-
353
- response_json = result.json()
354
-
355
- audio_data = response_json["audio_data"] # base64 encoded audio data
356
- extension = response_json["extension"]
357
- # Decode the base64 audio data
358
- audio_bytes = base64.b64decode(audio_data)
359
-
360
- # Create a temporary file to store the audio data
361
- with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as temp_file:
362
- temp_file.write(audio_bytes)
363
- temp_path = temp_file.name
364
-
365
- return temp_path
366
 
367
 
368
  if __name__ == "__main__":
369
- print(
370
- predict_dia(
371
- [
372
- {"text": "Hello, how are you?", "speaker_id": 0},
373
- {"text": "I'm great, thank you!", "speaker_id": 1},
374
- ]
375
- )
376
- )
377
- # print("Predicting PlayDialog")
378
- # print(
379
- # predict_playdialog(
380
- # [
381
- # {"text": "Hey how are you doing.", "speaker_id": 0},
382
- # {"text": "Pretty good, pretty good.", "speaker_id": 1},
383
- # {"text": "I'm great, so happy to be speaking to you.", "speaker_id": 0},
384
- # ]
385
- # )
386
- # )
 
 
 
1
  import os
 
2
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
3
  import random
4
 
5
  load_dotenv()
 
12
 
13
 
14
  model_mapping = {
 
 
 
 
 
 
 
 
 
 
 
 
15
  "spark-tts": {
16
  "provider": "spark",
17
  "model": "spark-tts",
18
  },
 
 
 
 
 
 
 
 
19
  "cosyvoice-2.0": {
20
  "provider": "cosyvoice",
21
  "model": "cosyvoice_2_0",
22
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  "index-tts": {
24
  "provider": "bilibili",
25
  "model": "index-tts",
 
42
  data = {"text": "string", "provider": "string", "model": "string"}
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def predict_index_tts(text, reference_audio_path=None):
46
  from gradio_client import Client, handle_file
47
+ client = Client("IndexTeam/IndexTTS",hf_token=os.getenv("HF_TOKEN"))
48
  if reference_audio_path:
49
  prompt = handle_file(reference_audio_path)
50
  else:
 
62
 
63
  def predict_spark_tts(text, reference_audio_path=None):
64
  from gradio_client import Client, handle_file
65
+ client = Client("kemuriririn/SparkTTS",hf_token=os.getenv("HF_TOKEN"))
66
  prompt_wav = None
67
  if reference_audio_path:
68
  prompt_wav = handle_file(reference_audio_path)
 
79
 
80
  def predict_cosyvoice_tts(text, reference_audio_path=None):
81
  from gradio_client import Client, file, handle_file
82
+ client = Client("kemuriririn/CosyVoice2-0.5B",hf_token=os.getenv("HF_TOKEN"))
83
  if not reference_audio_path:
84
  raise ValueError("cosyvoice-2.0 需要 reference_audio_path")
85
  prompt_wav = handle_file(reference_audio_path)
 
92
  prompt_text = recog_result if isinstance(recog_result, str) else str(recog_result)
93
  result = client.predict(
94
  tts_text=text,
95
+ mode_checkbox_group="3s Voice Clone",
96
  prompt_text=prompt_text,
97
  prompt_wav_upload=prompt_wav,
98
  prompt_wav_record=prompt_wav,
 
150
  global client
151
  print(f"Predicting TTS for {model}")
152
  # Exceptions: special models that shouldn't be passed to the router
153
+ if model == "index-tts":
 
 
 
 
 
 
154
  return predict_index_tts(text, reference_audio_path)
155
  elif model == "spark-tts":
156
  return predict_spark_tts(text, reference_audio_path)
 
161
  elif model == "gpt-sovits-v2":
162
  return predict_gpt_sovits_v2(text, reference_audio_path)
163
 
164
+ raise ValueError(f"Model {model} not found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  # Import-only module after this commit: the previous demo driver
  # (predict_dia sample call) was removed, so running the file directly
  # is a deliberate no-op.
  if __name__ == "__main__":
      pass