Spaces:
Paused
Paused
| # app.py ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import os, json, asyncio, torch, logging | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from huggingface_hub import login | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor | |
| from transformers.generation.utils import Cache | |
| from snac import SNAC | |
| # ββ 0. Auth & Device ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| if HF_TOKEN: | |
| login(HF_TOKEN) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| torch.backends.cuda.enable_flash_sdp(False) # FlashβBug umgehen | |
| logging.getLogger("transformers.generation.utils").setLevel("ERROR") | |
| # ββ 1. Konstanten βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| MODEL_REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1" | |
| CHUNK_TOKENS = 50 | |
| START_TOKEN = 128259 # <π > | |
| NEW_BLOCK_TOKEN = 128257 # πβStart | |
| EOS_TOKEN = 128258 # <eos> | |
| PROMPT_END = [128009, 128260] | |
| AUDIO_BASE = 128266 | |
| VALID_AUDIO_IDS = torch.arange(AUDIO_BASE, AUDIO_BASE + 4096) | |
| # ββ 2. LogitβMasker βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class AudioMask(LogitsProcessor): | |
| def __init__(self, allowed: torch.Tensor): | |
| super().__init__() | |
| self.allowed = allowed | |
| def __call__(self, input_ids, scores): | |
| mask = torch.full_like(scores, float("-inf")) | |
| mask[:, self.allowed] = 0 | |
| return scores + mask | |
| ALLOWED_IDS = torch.cat([ | |
| VALID_AUDIO_IDS, | |
| torch.tensor([START_TOKEN, NEW_BLOCK_TOKEN, EOS_TOKEN]) | |
| ]).to(device) | |
| MASKER = AudioMask(ALLOWED_IDS) | |
| # ββ 3. FastAPI GrundgerΓΌst ββββββββββββββββββββββββββββββββββββββββββ | |
| app = FastAPI() | |
| async def ping(): | |
| return {"message": "OrpheusβTTSΒ ready"} | |
| async def load_models(): | |
| global tok, model, snac | |
| tok = AutoTokenizer.from_pretrained(MODEL_REPO) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_REPO, | |
| low_cpu_mem_usage=True, | |
| device_map={"": 0} if device == "cuda" else None, | |
| torch_dtype=torch.bfloat16 if device == "cuda" else None, | |
| ) | |
| model.config.pad_token_id = model.config.eos_token_id | |
| snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device) | |
| # ββ 4. Hilfsfunktionen ββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_inputs(text: str, voice: str): | |
| prompt = f"{voice}: {text}" if voice and voice != "in_prompt" else text | |
| ids = tok(prompt, return_tensors="pt").input_ids.to(device) | |
| ids = torch.cat([ | |
| torch.tensor([[START_TOKEN]], device=device), | |
| ids, | |
| torch.tensor([PROMPT_END], device=device) | |
| ], 1) | |
| mask = torch.ones_like(ids) | |
| return ids, mask # shape (1,Β L) | |
| def decode_block(block7: list[int]) -> bytes: | |
| l1, l2, l3 = [], [], [] | |
| b = block7 | |
| l1.append(b[0]) | |
| l2.append(b[1] - 4096) | |
| l3 += [b[2]-8192, b[3]-12288] | |
| l2.append(b[4] - 16384) | |
| l3 += [b[5]-20480, b[6]-24576] | |
| codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)] | |
| audio = snac.decode(codes).squeeze().cpu().numpy() | |
| return (audio * 32767).astype("int16").tobytes() | |
| # ββ 5. WebSocketβEndpoint βββββββββββββββββββββββββββββββββββββββββββ | |
| async def tts(ws: WebSocket): | |
| await ws.accept() | |
| try: | |
| req = json.loads(await ws.receive_text()) | |
| ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob")) | |
| prompt_len = ids.size(1) | |
| past, buf = None, [] | |
| while True: | |
| out = model.generate( | |
| input_ids=ids if past is None else None, | |
| attention_mask=attn if past is None else None, | |
| past_key_values=past, | |
| max_new_tokens=CHUNK_TOKENS, | |
| logits_processor=[MASKER], | |
| do_sample=True, top_p=0.95, temperature=0.7, | |
| return_dict_in_generate=True, | |
| use_cache=True, | |
| return_legacy_cache=True, # β Warnung verschwindet | |
| ) | |
| past = out.past_key_values # unverΓ€ndert weiterreichen | |
| seq = out.sequences[0].tolist() | |
| new = seq[prompt_len:]; prompt_len = len(seq) | |
| if not new: # selten, aber mΓΆglich | |
| continue | |
| for t in new: | |
| if t == EOS_TOKEN: | |
| await ws.close() | |
| return | |
| if t == NEW_BLOCK_TOKEN: | |
| buf.clear(); continue | |
| if t < AUDIO_BASE: # sollte durch Maske nie passieren | |
| continue | |
| buf.append(t - AUDIO_BASE) | |
| if len(buf) == 7: | |
| await ws.send_bytes(decode_block(buf)) | |
| buf.clear() | |
| # Ab jetzt nur noch Cache β IDs & Mask nicht mehr nΓΆtig | |
| ids = attn = None | |
| except WebSocketDisconnect: | |
| pass | |
| except Exception as e: | |
| print("WSβError:", e) | |
| if ws.client_state.name == "CONNECTED": | |
| await ws.close(code=1011) | |
| # ββ 6. Lokaler Start ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| import uvicorn, sys | |
| port = int(sys.argv[1]) if len(sys.argv) > 1 else 7860 | |
| uvicorn.run("app:app", host="0.0.0.0", port=port) | |