Update app.py
app.py
CHANGED
@@ -1,6 +1,7 @@
 import spaces
 from snac import SNAC
 import torch
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
 import gradio as gr
 import os
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -235,7 +236,7 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
         inputs=[text_input, voice, temperature, top_p, repetition_penalty, max_new_tokens],
         outputs=audio_output,
         fn=generate_speech,
-        cache_examples=
+        cache_examples=False,
     )

     # Set up event handlers
@@ -251,6 +252,34 @@ with gr.Blocks(title="Orpheus Text-to-Speech") as demo:
         outputs=[text_input, audio_output]
     )

-#
+# Create FastAPI app and mount Gradio
+app = FastAPI()
+app.mount("/", demo)
+
+# WebSocket TTS endpoint
+@app.websocket("/ws/tts")
+async def websocket_tts(websocket: WebSocket):
+    await websocket.accept()
+    try:
+        while True:
+            msg = await websocket.receive_text()
+            data = json.loads(msg)
+            text = data.get("text", "")
+            voice = data.get("voice", VOICES[0])
+            # Generate audio for the chunk
+            _, audio = generate_speech(text, voice, 0.7, 0.95, 1.1, 1200)
+            # Stream audio in 0.1s chunks
+            chunk_size = 2400  # 24000 Hz -> 2400 samples = 0.1s
+            for i in range(0, len(audio), chunk_size):
+                chunk = audio[i:i+chunk_size]
+                await websocket.send_bytes(chunk.tobytes())
+            await websocket.send_text("__END__")
+    except WebSocketDisconnect:
+        print("Client disconnected from /ws/tts")
+
+# Launch if run directly
+def main():
+    import uvicorn
+    uvicorn.run("app:app", host="0.0.0.0", port=7860)
+
 if __name__ == "__main__":
-
+    main()
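For reference, the protocol implied by the new /ws/tts route is: the client sends one JSON text frame with a "text" field and an optional "voice" field, then reads binary audio chunks until it receives the text frame "__END__". Below is a minimal client sketch exercising that protocol. It is not part of the commit; the local host and port, the int16 sample format, and the third-party websockets package are assumptions.

# Hypothetical client for the /ws/tts endpoint added in this commit.
# Assumes the server runs at ws://localhost:7860 and streams int16 PCM at 24 kHz.
# Requires: pip install websockets numpy
import asyncio
import json

import numpy as np
import websockets


async def synthesize(text, voice=None):
    async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
        request = {"text": text}
        if voice is not None:
            request["voice"] = voice  # otherwise the server falls back to VOICES[0]
        await ws.send(json.dumps(request))

        chunks = []
        while True:
            msg = await ws.recv()
            if isinstance(msg, str):
                if msg == "__END__":  # server signals the end of the utterance
                    break
            else:
                chunks.append(np.frombuffer(msg, dtype=np.int16))  # sample dtype is an assumption
        return np.concatenate(chunks) if chunks else np.zeros(0, dtype=np.int16)


if __name__ == "__main__":
    audio = asyncio.run(synthesize("Hello from the streaming endpoint."))
    print(f"received {len(audio)} samples (~{len(audio) / 24000:.2f}s at 24 kHz)")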