# app.py ──────────────────────────────────────────────────────────────
import os, json, torch, asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor
from transformers.cache_utils import Cache
from snac import SNAC
# ── 0 · Login & Device ───────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(HF_TOKEN)

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.enable_flash_sdp(False)  # workaround for a CUDA assert in flash SDP attention
# ── 1 · Constants ────────────────────────────────────────────────────
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
CHUNK_TOKENS = 50        # tokens produced per generate() call
START_TOKEN = 128259     # start-of-sequence marker prepended to the prompt
NEW_BLOCK = 128257       # emitted by the model before each 7-token audio block
EOS_TOKEN = 128258       # ends generation
AUDIO_BASE = 128266      # first audio-code token id
# 7 codebook slots × 4096 codes each (decode_block subtracts offsets up to 6·4096)
VALID_AUDIO = torch.arange(AUDIO_BASE, AUDIO_BASE + 7 * 4096)
# ── 2 · Logit masker ─────────────────────────────────────────────────
class DynamicAudioMask(LogitsProcessor):
    """Restricts sampling to audio codes plus NEW_BLOCK; EOS only becomes
    available after at least `min_blocks` complete audio blocks."""
    def __init__(self, audio_ids: torch.Tensor, min_blocks: int = 1):
        super().__init__()
        self.audio_ids = audio_ids
        self.ctrl_ids = torch.tensor([NEW_BLOCK], device=audio_ids.device)
        self.min_blocks = min_blocks
        self.blocks = 0

    def __call__(self, inp, scores):
        allow = torch.cat([self.audio_ids, self.ctrl_ids])
        if self.blocks >= self.min_blocks:   # allow EOS only after enough audio
            allow = torch.cat([allow,
                               torch.tensor([EOS_TOKEN], device=scores.device)])
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allow] = 0
        return scores + mask
# ── 3 · FastAPI app ──────────────────────────────────────────────────
app = FastAPI()

@app.get("/")
async def root():                    # simple health check
    return {"msg": "Orpheus-TTS alive"}
@app.on_event("startup")
async def load():
    global tok, model, snac, masker
    print("⏳ Loading models …")
    tok = AutoTokenizer.from_pretrained(REPO)
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        low_cpu_mem_usage=True,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=torch.bfloat16 if device == "cuda" else None)
    model.config.pad_token_id = model.config.eos_token_id
    model.config.use_cache = True
    masker = DynamicAudioMask(VALID_AUDIO.to(device))
    print("✅ Models loaded")
# ── 4 · Helper functions ─────────────────────────────────────────────
def build_inputs(text: str, voice: str):
    prompt = f"{voice}: {text}"
    ids = tok(prompt, return_tensors="pt").input_ids.to(device)
    # Frame the prompt: START_TOKEN in front, end-of-text (128009, the
    # Llama-3 <|eot_id|>) plus the 128260 control token appended.
    ids = torch.cat([torch.tensor([[START_TOKEN]], device=device),
                     ids,
                     torch.tensor([[128009, 128260]], device=device)], 1)
    return ids, torch.ones_like(ids)
def decode_block(block):
    # A 7-token block carries SNAC codes for three levels (1 coarse,
    # 2 medium, 4 fine codes); each slot is offset by a multiple of 4096.
    l1, l2, l3 = [], [], []
    l1.append(block[0])
    l2.append(block[1] - 4096)
    l3.extend([block[2] - 8192, block[3] - 12288])
    l2.append(block[4] - 16384)
    l3.extend([block[5] - 20480, block[6] - 24576])
    codes = [torch.tensor(x, device=device).unsqueeze(0) for x in (l1, l2, l3)]
    audio = snac.decode(codes).squeeze().cpu().numpy()
    return (audio * 32767).astype("int16").tobytes()   # raw 16-bit PCM
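# A quick worked example with made-up code values: a buffered block
# [12, 4100, 8200, 12300, 16400, 20500, 24600] (already AUDIO_BASE-shifted)
# decodes to l1 = [12], l2 = [4, 16], l3 = [8, 12, 20, 24] before being
# handed to snac.decode().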
# ── 5 · WebSocket TTS ────────────────────────────────────────────────
@app.websocket("/ws/tts")   # NOTE: the route path "/ws/tts" is an assumption
async def tts(ws: WebSocket):
    await ws.accept()
    masker.blocks = 0        # reset the per-request block counter
    try:
        req = json.loads(await ws.receive_text())
        text = req.get("text", "")
        voice = req.get("voice", "Jakob")
        ids, attn = build_inputs(text, voice)
        total_len = ids.shape[1]   # length of the prompt
        past = None
        last_tok = None
        buf = []
        while True:
            out = model.generate(
                input_ids       = ids if past is None else torch.tensor([[last_tok]], device=device),
                attention_mask  = attn if past is None else None,
                past_key_values = past,
                max_new_tokens  = CHUNK_TOKENS,
                logits_processor= [masker],
                do_sample=True, temperature=0.7, top_p=0.95,
                use_cache=True, return_dict_in_generate=True,
                return_legacy_cache=True)
            pkv = out.past_key_values
            if isinstance(pkv, Cache):
                pkv = pkv.to_legacy_cache()   # convert to the legacy tuple format
            past = pkv
            seq = out.sequences[0].tolist()
            new = seq[total_len:]             # everything *after* the prompt
            total_len = len(seq)              # for the next iteration
            print("new tokens:", new[:32])
            if not new:                       # nothing generated
                raise StopIteration
            for t in new:
                last_tok = t
                if t == EOS_TOKEN:
                    raise StopIteration
                if t == NEW_BLOCK:            # model starts a fresh 7-token block
                    buf.clear()
                    continue
                buf.append(t - AUDIO_BASE)
                if len(buf) == 7:             # complete block: decode and stream
                    await ws.send_bytes(decode_block(buf))
                    buf.clear()
                    masker.blocks += 1
            ids, attn = None, None            # from now on: single-token steps
    except (StopIteration, WebSocketDisconnect):
        pass
    except Exception as e:
        print("❌ WS error:", e)
        if ws.client_state.name != "DISCONNECTED":
            await ws.close(code=1011)
    finally:
        if ws.client_state.name != "DISCONNECTED":
            try:
                await ws.close()
            except RuntimeError:
                pass
# ── 6 · Local run ────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
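# ── 7 · Example client (sketch) ──────────────────────────────────────
# A minimal client sketch, assuming the `websockets` package and the
# "/ws/tts" route path used above (an assumption, not confirmed by the
# original). The server streams raw 16-bit PCM at 24 kHz, one decoded
# block per message:
#
#   import asyncio, json, websockets
#
#   async def main():
#       uri = "ws://localhost:7860/ws/tts"
#       async with websockets.connect(uri) as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt", "voice": "Jakob"}))
#           pcm = bytearray()
#           try:
#               while True:
#                   pcm += await ws.recv()        # one audio block per message
#           except websockets.ConnectionClosed:
#               pass                              # server closes after EOS
#           with open("out.pcm", "wb") as f:      # play back as s16le @ 24 kHz
#               f.write(pcm)
#
#   asyncio.run(main())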