Spaces:
Paused
Paused
| import os | |
| import json | |
| import asyncio | |
| import torch | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from huggingface_hub import login | |
| from snac import SNAC | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# -- Hugging Face token & login --
# Authenticate against the Hub only when a token is configured in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

# -- Select the compute device --
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# -- Instantiate the FastAPI application --
app = FastAPI()
# -- Hello route, so that GET / no longer returns a 404 --
@app.get("/")
async def read_root():
    """Answer ``GET /`` with a small static JSON greeting.

    The handler is registered on ``app`` for the root path; without that
    registration a request to ``/`` is not routed to this function.

    Returns:
        A dict with a single ``"message"`` key.
    """
    return {"message": "Hello, world!"}
# -- Load the models at startup --
# NOTE: newer FastAPI releases prefer a lifespan handler; ``on_event`` is used
# here because it does not require changing how ``app`` is constructed.
@app.on_event("startup")
async def load_models():
    """Load the SNAC codec, the tokenizer and the TTS model into module globals.

    Runs once when the application starts and populates the globals
    ``tokenizer``, ``model`` and ``snac`` that the helper functions and the
    WebSocket endpoint read.
    """
    global tokenizer, model, snac

    # SNAC audio codec, moved to the selected device.
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)

    # TTS language model and its tokenizer.
    model_name = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
    on_gpu = device == "cuda"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # On CUDA the whole model is placed on GPU 0 in bfloat16; on CPU the
        # library defaults are used.
        device_map={"": 0} if on_gpu else None,
        torch_dtype=torch.bfloat16 if on_gpu else None,
        low_cpu_mem_usage=True,
    )

    # Use the EOS id as the padding id.
    model.config.pad_token_id = model.config.eos_token_id
# -- Helper functions --
def prepare_inputs(text: str, voice: str):
    """Build the model input ids and attention mask for one request.

    The prompt ``"<voice>: <text>"`` is tokenized with the global
    ``tokenizer`` and wrapped in the start marker (128259) and the end
    markers (128009, 128260).

    Args:
        text: The text to synthesize.
        voice: The voice name placed in front of the text.

    Returns:
        A tuple ``(ids, mask)`` of tensors on ``device``; ``mask`` is all ones
        and has the same shape as ``ids``.
    """
    prompt_ids = tokenizer(f"{voice}: {text}", return_tensors="pt").input_ids.to(device)

    # Start / end markers surrounding the tokenized prompt.
    prefix = torch.tensor([[128259]], dtype=torch.int64, device=device)
    suffix = torch.tensor([[128009, 128260]], dtype=torch.int64, device=device)

    full_ids = torch.cat((prefix, prompt_ids, suffix), dim=1)
    return full_ids, torch.ones_like(full_ids)
def decode_block(block_tokens: list[int]):
    """Decode one 7-token frame into 16-bit PCM bytes using the global SNAC codec.

    Each slot ``i`` of the frame carries a code shifted by ``i * 4096``. After
    removing that shift the codes are distributed over the three SNAC layers
    in the order ``[L1, L2, L3, L3, L2, L3, L3]``.

    Args:
        block_tokens: Exactly seven already offset-corrected audio codes, as
            collected by the streaming endpoint.

    Returns:
        The decoded audio as little-endian-on-most-hosts int16 PCM bytes
        (native byte order of the machine).

    Raises:
        ValueError: If the frame does not contain exactly seven entries or a
            code falls outside ``[0, 4096)`` after removing its slot shift.
            Rejecting such frames up front avoids an out-of-range codebook
            lookup inside the decoder.
    """
    codebook_size = 4096
    # Frame slot -> index of the SNAC layer that receives the code.
    layer_of_slot = (0, 1, 2, 2, 1, 2, 2)

    if len(block_tokens) != len(layer_of_slot):
        raise ValueError(
            f"expected {len(layer_of_slot)} tokens per frame, got {len(block_tokens)}"
        )

    layers = ([], [], [])
    for slot, (token, layer) in enumerate(zip(block_tokens, layer_of_slot)):
        code = token - slot * codebook_size
        if not 0 <= code < codebook_size:
            raise ValueError(f"code {code} at frame slot {slot} is out of range")
        layers[layer].append(code)

    codes = [torch.tensor(layer, device=device).unsqueeze(0) for layer in layers]

    # Decode without autograd: no graph is built, and the result can be
    # converted to NumPy without an explicit detach().
    with torch.inference_mode():
        audio = snac.decode(codes).squeeze()
        # Clamp to [-1, 1] so the int16 conversion below cannot wrap around.
        audio = audio.clamp(-1.0, 1.0).cpu().numpy()

    # Convert to PCM16.
    return (audio * 32767).astype("int16").tobytes()
def _sample_next_token(step_ids, step_mask, cache):
    """Run one forward pass of the global model and sample the next token.

    Returns the sampled token as a ``(1, 1)`` tensor together with the updated
    key/value cache. Autograd is disabled inside this function because it may
    run in a worker thread and the grad mode is per-thread.
    """
    with torch.inference_mode():
        out = model(
            input_ids=step_ids,
            attention_mask=step_mask,
            past_key_values=cache,
            use_cache=True,
        )
        probs = torch.softmax(out.logits[:, -1, :].float(), dim=-1)
        token = torch.multinomial(probs, num_samples=1)
    return token, out.past_key_values


# -- WebSocket endpoint for TTS streaming --
@app.websocket("/ws/tts")
async def tts_ws(ws: WebSocket):
    """Stream synthesized speech for one request over a WebSocket.

    Protocol: the client sends a single JSON text message with the keys
    ``"text"`` and optionally ``"voice"`` (default ``"Jakob"``). The server
    answers with binary messages, each holding the PCM16 audio of one decoded
    7-token frame, and closes the socket when generation ends. On an
    unexpected error the socket is closed with code 1011.

    Args:
        ws: The WebSocket connection supplied by FastAPI.
    """
    audio_start_token = 128257   # resets the current frame
    audio_token_offset = 128266  # subtracted from a token to get its code
    frame_size = 7
    # Safety bound so a generation that never emits EOS cannot loop forever.
    # NOTE(review): the value is a guess; tune it to the longest expected utterance.
    max_new_tokens = 4096

    await ws.accept()
    try:
        # The request arrives as one JSON message.
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")

        # The configured EOS id may be a single int or a list of ints.
        eos = model.config.eos_token_id
        eos_ids = set(eos) if isinstance(eos, (list, tuple)) else {eos}
        # NOTE(review): the model may use a dedicated end-of-speech token that
        # differs from config.eos_token_id -- confirm and add it to eos_ids.

        step_ids, step_mask = prepare_inputs(text, voice)
        cache = None
        frame = []

        # Token-by-token generation loop.
        for _ in range(max_new_tokens):
            # The forward pass is blocking; run it in a worker thread so the
            # event loop stays responsive for other connections.
            token, cache = await asyncio.to_thread(
                _sample_next_token, step_ids, step_mask, cache
            )
            nxt = token.item()

            # With a populated cache only the newly sampled token is fed back.
            step_ids, step_mask = token, None

            # Stop on EOS.
            if nxt in eos_ids:
                break

            # A new start marker discards the partially collected frame.
            if nxt == audio_start_token:
                frame = []
                continue

            code = nxt - audio_token_offset
            if code < 0:
                # Not an audio token; it carries no code to decode.
                continue
            frame.append(code)

            # As soon as a frame is complete, decode and send it.
            if len(frame) == frame_size:
                pcm = await asyncio.to_thread(decode_block, frame)
                frame = []
                await ws.send_bytes(pcm)

        # Close cleanly once generation has finished.
        await ws.close()
    except WebSocketDisconnect:
        # The client disconnected; nothing left to do.
        pass
    except Exception as e:
        # Report the failure to the client with close code 1011.
        print("Error in /ws/tts:", e)
        try:
            await ws.close(code=1011)
        except RuntimeError:
            # The socket was already closed; closing is best-effort here.
            pass