import os
import json
import asyncio
import torch
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# — HF token & login —
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)
# — Select device —
device = "cuda" if torch.cuda.is_available() else "cpu"
# — Model parameters —
MODEL_NAME = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
START_MARKER = 128259        # <|startoftranscript|>
RESTART_MARKER = 128257      # <|startoftranscript_again|>
EOS_TOKEN = 128258           # <|endoftranscript|>
AUDIO_TOKEN_OFFSET = 128266  # offset for mapping token IDs back to audio codes
BLOCK_TOKENS = 7             # SNAC expects 7 audio tokens per block
CHUNK_TOKENS = 50            # number of new tokens per generate round
# — Instantiate FastAPI —
app = FastAPI()

# — So that GET / does not return a 404 —
@app.get("/")
async def read_root():
    return {"message": "Orpheus TTS server is live 🎙️"}
# — Load models at startup —
@app.on_event("startup")
async def load_models():
    global tokenizer, model, snac
    # Load the SNAC vocoder
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # TTS language model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16 if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    model.config.pad_token_id = EOS_TOKEN
# — Prepare model inputs —
def prepare_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    start = torch.tensor([[START_MARKER]], device=device)
    # 128009 is the Llama-3 <|eot_id|> token that closes the text segment
    end = torch.tensor([[128009, EOS_TOKEN]], device=device)
    ids = torch.cat([start, input_ids, end], dim=1)
    attn_mask = torch.ones_like(ids)
    return ids, attn_mask
# — Decode 7 audio tokens into one PCM block —
def decode_block(block: list[int]) -> bytes:
    # De-interleave the 7 tokens into SNAC's three codebook levels,
    # subtracting each position's codebook offset (multiples of 4096)
    l1 = [block[0]]
    l2 = [block[1] - 4096, block[4] - 4 * 4096]
    l3 = [
        block[2] - 2 * 4096,
        block[3] - 3 * 4096,
        block[5] - 5 * 4096,
        block[6] - 6 * 4096,
    ]
    codes = [
        torch.tensor(l1, device=device).unsqueeze(0),
        torch.tensor(l2, device=device).unsqueeze(0),
        torch.tensor(l3, device=device).unsqueeze(0),
    ]
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze().cpu().numpy()
    # Clamp to [-1, 1] and convert to 16-bit little-endian PCM
    pcm16 = (audio.clip(-1.0, 1.0) * 32767).astype("int16").tobytes()
    return pcm16
# — Generator: produce tokens in small chunks and decode block by block —
async def generate_and_stream(ws: WebSocket, ids, attn_mask):
    buffer: list[int] = []
    past_kvs = None
    while True:
        # Call model.generate in small chunks, carrying the KV cache across
        # rounds so the already-processed prefix is not recomputed
        outputs = model.generate(
            input_ids=ids,
            attention_mask=attn_mask,
            past_key_values=past_kvs,
            use_cache=True,
            max_new_tokens=CHUNK_TOKENS,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            eos_token_id=EOS_TOKEN,
            pad_token_id=EOS_TOKEN,
            return_dict_in_generate=True,
            output_scores=False,
        )
        past_kvs = outputs.past_key_values
        # Keep only the tokens generated this round (may be fewer than
        # CHUNK_TOKENS if generation stopped early)
        seq = outputs.sequences[0]
        new_tokens = seq[ids.shape[-1]:].tolist()
        # Feed the full sequence back in next round; generate() reuses the
        # cache for the prefix it has already seen
        ids = outputs.sequences
        attn_mask = torch.ones_like(ids)
        for tok in new_tokens:
            # Restart on a repeated START marker
            if tok == RESTART_MARKER:
                buffer = []
                continue
            # End of generation
            if tok == EOS_TOKEN:
                return
            # Map the token ID back to an audio code
            buffer.append(tok - AUDIO_TOKEN_OFFSET)
            # As soon as 7 audio tokens are buffered, decode and stream
            if len(buffer) >= BLOCK_TOKENS:
                block = buffer[:BLOCK_TOKENS]
                buffer = buffer[BLOCK_TOKENS:]
                pcm = decode_block(block)
                await ws.send_bytes(pcm)
        # Yield to the event loop between rounds; stop if nothing new came
        await asyncio.sleep(0)
        if not new_tokens:
            return
# — WebSocket endpoint for TTS streaming —
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    await ws.accept()
    try:
        data = await ws.receive_text()
        req = json.loads(data)
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        ids, attn_mask = prepare_inputs(text, voice)
        await generate_and_stream(ws, ids, attn_mask)
        await ws.close()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print("Error in /ws/tts:", e)
        await ws.close(code=1011)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
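
# — Example client (a minimal sketch, not part of the server) —
# Assumes the third-party `websockets` package and a server running locally
# on port 7860; the output filename and sample text are illustrative only.
# The server streams raw 16-bit mono PCM at SNAC's 24 kHz sample rate.
#
#   import asyncio, json, websockets
#
#   async def main():
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           with open("out.pcm", "wb") as f:
#               async for chunk in ws:   # each message is one decoded block
#                   f.write(chunk)
#
#   asyncio.run(main())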