# app.py -----------------------------------------------------------------
import os
import json
import torch
import asyncio
import traceback  # Import traceback for better error logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, StoppingCriteria, StoppingCriteriaList
# Import BaseStreamer for the interface
from transformers.generation.streamers import BaseStreamer
from snac import SNAC  # Ensure you have 'pip install snac'
# --- Globals (populated in load_models) ---
tok = None
model = None
snac = None
masker = None
stopping_criteria = None
actual_eos_token_id = None  # Will be determined during startup
device = "cuda" if torch.cuda.is_available() else "cpu"

# 0) Login + Device --------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    print("Logging in to Hugging Face Hub...")
    login(HF_TOKEN)
# torch.backends.cuda.enable_flash_sdp(False)  # Uncomment if needed to work around a PyTorch 2.2 bug

# 1) Constants --------------------------------------------------------------
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
# CHUNK_TOKENS = 50  # Not directly used with the streamer approach
START_TOKEN = 128259
NEW_BLOCK = 128257
# EOS_TOKEN = 128258  # REMOVED - determined from the model/tokenizer config at startup
AUDIO_BASE = 128266
AUDIO_SPAN = 4096 * 7  # 28672 codes
CODEBOOK_SIZE = 4096  # Explicitly define the codebook size
# Created on CPU here; moved to the target device later in load_models
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)

# 2) Logit mask --------------------------------------------------------------
# Uses the dynamically determined EOS token ID
class AudioMask(LogitsProcessor):
    def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
        super().__init__()
        # Build index tensors of allowed token IDs (long dtype, on the same device as audio_ids)
        new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
        eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
        # Allow NEW_BLOCK and all valid audio tokens initially
        self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
        self.eos = eos_tensor  # Store the EOS token ID as a tensor
        self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)  # Precompute the combined tensor
        self.sent_blocks = 0  # State: number of audio blocks sent

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Allow EOS only once at least one audio block has been sent
        current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
        # Create a mask initialized to negative infinity
        mask = torch.full_like(scores, float("-inf"))
        # Set allowed token scores to 0 (effectively allowing them)
        mask[:, current_allow] = 0
        # Apply the mask to the scores
        return scores + mask

    def reset(self):
        """Resets the state for a new generation request."""
        self.sent_blocks = 0

# 3) StoppingCriteria for EOS -------------------------------------------------
# Uses the dynamically determined EOS token ID
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id
        if self.eos_token_id is None:
            print("Warning: EosStoppingCriteria initialized with eos_token_id=None!")

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.eos_token_id is None:
            return False  # Cannot stop if the EOS ID is unknown
        # Check whether the *last* generated token is the EOS token
        if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
            # print("StoppingCriteria: EOS detected.")
            return True
        return False

# 4) Custom AudioStreamer -----------------------------------------------------
class AudioStreamer(BaseStreamer):
    def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
        self.ws = ws
        self.snac = snac_decoder
        self.masker = audio_mask
        self.loop = loop
        self.device = target_device
        self.eos_token_id = eos_token_id  # Store the EOS ID for optional use in put()
        self.buf: list[int] = []  # Audio token values (relative to AUDIO_BASE) of the current block
        self.tasks = set()  # Pending send futures scheduled on the event loop

    def _decode_block(self, block7: list[int]) -> bytes:
        """
        Decodes a block of 7 audio token values (AUDIO_BASE already subtracted) into audio bytes.
        NOTE: Extracts the base code value (0-4095) using modulo, assuming each
        input value represents (slot_offset + code_value).
        Maps the extracted values using the layout believed to be correct for Kartoffel_Orpheus.
        """
        if len(block7) != 7:
            print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
            return b""
        try:
            # --- Extract the base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
            code_val_0 = block7[0] % CODEBOOK_SIZE
            code_val_1 = block7[1] % CODEBOOK_SIZE
            code_val_2 = block7[2] % CODEBOOK_SIZE
            code_val_3 = block7[3] % CODEBOOK_SIZE
            code_val_4 = block7[4] % CODEBOOK_SIZE
            code_val_5 = block7[5] % CODEBOOK_SIZE
            code_val_6 = block7[6] % CODEBOOK_SIZE
            # --- Map the extracted code values to the three SNAC codebooks (l1, l2, l3) ---
            l1 = [code_val_0]
            l2 = [code_val_1, code_val_4]
            l3 = [code_val_2, code_val_3, code_val_5, code_val_6]
        except IndexError:
            print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
            return b""
        except Exception as e_map:  # Catch potential issues with modulo/mapping
            print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
            return b""
        # --- Convert lists to tensors on the correct device ---
        try:
            codes_l1 = torch.tensor(l1, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l2 = torch.tensor(l2, dtype=torch.long, device=self.device).unsqueeze(0)
            codes_l3 = torch.tensor(l3, dtype=torch.long, device=self.device).unsqueeze(0)
            codes = [codes_l1, codes_l2, codes_l3]
        except Exception as e_tensor:
            print(f"Streamer Error: Exception during tensor conversion: {e_tensor}. l1={l1}, l2={l2}, l3={l3}")
            return b""
        # --- Decode using SNAC ---
        try:
            with torch.no_grad():
                audio = self.snac.decode(codes)[0]
        except Exception as e_decode:
            print(f"Streamer Error: Exception during snac.decode: {e_decode}")
            print(f"Input codes shapes: {[c.shape for c in codes]}")
            print(f"Input codes dtypes: {[c.dtype for c in codes]}")
            print(f"Input codes devices: {[c.device for c in codes]}")
            print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
            return b""
        # --- Post-processing: scale float audio to 16-bit PCM bytes ---
        try:
            audio_np = audio.squeeze().detach().cpu().numpy()
            audio_bytes = (audio_np * 32767).astype("int16").tobytes()
            return audio_bytes
        except Exception as e_post:
            print(f"Streamer Error: Exception during post-processing: {e_post}. Audio tensor shape: {audio.shape}")
            return b""

    async def _send_audio_bytes(self, data: bytes):
        """Coroutine to send bytes over the WebSocket."""
        if not data:
            return
        try:
            await self.ws.send_bytes(data)
        except WebSocketDisconnect:
            print("Streamer: WebSocket disconnected during send.")
        except Exception as e:
            # Handle cases where sending fails after the connection has closed
            if "Cannot call \"send\" once a close message has been sent" in str(e):
                # Expected if the client disconnects during generation; suppress repetitive logs
                pass
            else:
                print(f"Streamer: Error sending bytes: {e}")

    def put(self, value: torch.LongTensor):
        """
        Receives new token IDs (tensor) from generate().
        Processes tokens, decodes full blocks, and schedules sending.
        """
        if value.numel() == 0:
            return
        # Ensure the value is on CPU and flattened to a list of ints
        new_token_ids = value.squeeze().cpu().tolist()
        if isinstance(new_token_ids, int):
            new_token_ids = [new_token_ids]
        for t in new_token_ids:
            # No need to check for EOS here; the StoppingCriteria handles it
            if t == NEW_BLOCK:
                self.buf.clear()
                continue
            if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
                self.buf.append(t - AUDIO_BASE)  # Store the value relative to AUDIO_BASE
            # else:  # Optionally log ignored tokens outside the audio range
            #     if t != self.eos_token_id:  # Don't warn about the EOS token itself
            #         print(f"Streamer Warning: Ignoring unexpected token {t}")
            if len(self.buf) == 7:
                audio_bytes = self._decode_block(self.buf)
                self.buf.clear()
                if audio_bytes:
                    # Schedule the async send coroutine on the main event loop
                    future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
                    self.tasks.add(future)
                    future.add_done_callback(self.tasks.discard)
                # Allow EOS only after the first full block has been processed
                if self.masker.sent_blocks == 0:
                    self.masker.sent_blocks = 1

    def end(self):
        """Called by generate() when generation finishes."""
        if len(self.buf) > 0:
            print(f"Streamer: End of generation with an incomplete block ({len(self.buf)} tokens). Discarding.")
            self.buf.clear()
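
# Worked example of the frame layout assumed by _decode_block (illustrative):
# each 7-token block is expected to hold one code per slot i with absolute ID
# AUDIO_BASE + i * CODEBOOK_SIZE + code, so after subtracting AUDIO_BASE the
# modulo by CODEBOOK_SIZE recovers the raw code (0..4095). For instance:
#   block7 = [i * CODEBOOK_SIZE + c for i, c in enumerate([10, 20, 30, 40, 50, 60, 70])]
#   -> l1 = [10], l2 = [20, 50], l3 = [30, 40, 60, 70]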

# 5) FastAPI App --------------------------------------------------------------
app = FastAPI()


@app.on_event("startup")  # Startup hook (assumed) so the models are loaded before the first request
async def load_models_startup():
    global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU, actual_eos_token_id
    print(f"Starting up on device: {device}")
    print("Loading models ...", flush=True)
    tok = AutoTokenizer.from_pretrained(REPO)
    print("Tokenizer loaded.")
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device)
    print(f"SNAC loaded to {device}.")
    model_dtype = torch.float32
    if device == "cuda":
        if torch.cuda.is_bf16_supported():
            model_dtype = torch.bfloat16
            print("Using bfloat16 for model.")
        else:
            model_dtype = torch.float16
            print("Using float16 for model.")
    model = AutoModelForCausalLM.from_pretrained(
        REPO,
        device_map={"": 0} if device == "cuda" else None,
        torch_dtype=model_dtype,
        low_cpu_mem_usage=True,
    )
    print(f"Model loaded to {model.device} with dtype {model.dtype}.")
    model.eval()
    # --- Determine and set the correct EOS token ID ---
    conf_eos = model.config.eos_token_id
    tok_eos = tok.eos_token_id
    print(f"Model Config EOS ID: {conf_eos}")
    print(f"Tokenizer EOS ID: {tok_eos}")
    if conf_eos is not None:
        actual_eos_token_id = conf_eos
    elif tok_eos is not None:
        actual_eos_token_id = tok_eos
        print(f"Warning: Model config EOS ID is None, using tokenizer EOS ID: {actual_eos_token_id}")
    else:
        raise ValueError("Could not determine EOS token ID from model config or tokenizer.")
    print(f"Using EOS Token ID: {actual_eos_token_id}")
    # Set pad_token_id to eos_token_id if not already set (common practice for generation)
    if model.config.pad_token_id is None:
        print(f"Setting model.config.pad_token_id to EOS token ID ({actual_eos_token_id})")
        model.config.pad_token_id = actual_eos_token_id
    # --- End EOS token ID determination ---
    audio_ids_device = AUDIO_IDS_CPU.to(device)
    # Pass the determined EOS ID to the mask
    masker = AudioMask(audio_ids_device, NEW_BLOCK, actual_eos_token_id)
    print("AudioMask initialized.")
    # Pass the determined EOS ID to the stopping criteria
    stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(actual_eos_token_id)])
    print("StoppingCriteria initialized.")
    print("Models loaded and ready!", flush=True)

@app.get("/")  # Simple health-check endpoint (route path assumed)
def hello():
    return {"status": "ok", "message": "TTS Service is running"}

# 6) Helper for building the prompt --------------------------------------------
def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Builds the input_ids and attention_mask for the model."""
    prompt_text = f"{voice}: {text}"
    prompt_ids = tok(prompt_text, return_tensors="pt").input_ids.to(device)
    # Wrap the prompt with the START_TOKEN and a trailing NEW_BLOCK marker
    input_ids = torch.cat([
        torch.tensor([[START_TOKEN]], device=device, dtype=torch.long),
        prompt_ids,
        torch.tensor([[NEW_BLOCK]], device=device, dtype=torch.long)
    ], dim=1)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
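
# Illustrative result (not from the original source): build_prompt("Hallo Welt", "Jakob")
# returns input_ids with the layout
#   [START_TOKEN, <token IDs of "Jakob: Hallo Welt">, NEW_BLOCK]
# and an all-ones attention_mask of the same length.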

# 7) WebSocket endpoint (simplified with the streamer) --------------------------
@app.websocket("/ws/tts")  # WebSocket route (path assumed; adjust to match your client)
async def tts(ws: WebSocket):
    global actual_eos_token_id  # Ensure we can access the determined EOS ID
    await ws.accept()
    print("Client connected")
    streamer = None
    main_loop = asyncio.get_running_loop()
    try:
        req_text = await ws.receive_text()
        print(f"Received request: {req_text}")
        req = json.loads(req_text)
        text = req.get("text", "Hallo Welt, wie geht es dir heute?")
        voice = req.get("voice", "Jakob")
        if not text:
            print("Warning: Request text is empty.")
            await ws.close(code=1003, reason="Text cannot be empty")
            return
        print(f"Generating audio for: '{text}' with voice '{voice}'")
        ids, attn = build_prompt(text, voice)
        masker.reset()
        # Pass the determined EOS ID to the streamer as well (optional, for logging/checks)
        streamer = AudioStreamer(ws, snac, masker, main_loop, device, actual_eos_token_id)
        print("Starting generation in background thread...")
        # Use sampling parameters to avoid repetition
        await asyncio.to_thread(
            model.generate,
            input_ids=ids,
            attention_mask=attn,
            max_new_tokens=2500,  # Increased slightly; adjust as needed
            logits_processor=[masker],
            stopping_criteria=stopping_criteria,
            # --- Sampling parameters ---
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            repetition_penalty=1.15,
            # --- End sampling parameters ---
            use_cache=True,
            streamer=streamer,
            eos_token_id=actual_eos_token_id  # Explicitly pass the correct EOS ID here too
        )
        print("Generation thread finished.")
    except WebSocketDisconnect:
        print("Client disconnected.")
    except json.JSONDecodeError:
        print("Invalid JSON received.")
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1003, reason="Invalid JSON format")
    except Exception as e:
        error_details = traceback.format_exc()
        print(f"WebSocket error: {e}\n{error_details}", flush=True)
        error_payload = json.dumps({"error": str(e)})
        try:
            if ws.client_state.name == "CONNECTED":
                await ws.send_text(error_payload)
        except Exception:
            pass
        if ws.client_state.name == "CONNECTED":
            await ws.close(code=1011)
    finally:
        if streamer:
            try:
                streamer.end()
            except Exception as e_end:
                print(f"Error during streamer.end(): {e_end}")
        print("Closing connection.")
        if ws.client_state.name == "CONNECTED":
            try:
                await ws.close(code=1000)
            except RuntimeError as e_close:
                # Suppress "Cannot call 'send'..." errors during the final close if already disconnected
                if "Cannot call \"send\"" not in str(e_close):
                    print(f"Runtime error closing websocket: {e_close}")
            except Exception as e_close_final:
                print(f"Error closing websocket: {e_close_final}")
        elif ws.client_state.name != "DISCONNECTED":
            print(f"WebSocket final state: {ws.client_state.name}")
        print("Connection closed.")

# 8) Dev start -------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    print("Starting Uvicorn server...")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")
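
# --- Example client (illustrative sketch, not part of the original app) ----------
# Assumes the "/ws/tts" route above and 24 kHz mono 16-bit PCM output
# (SNAC "snac_24khz", int16 bytes from the streamer). Requires `pip install websockets`.
#
#   import asyncio, json, wave
#   import websockets
#
#   async def demo_client():
#       pcm = bytearray()
#       async with websockets.connect("ws://localhost:7860/ws/tts") as ws:
#           await ws.send(json.dumps({"text": "Hallo Welt!", "voice": "Jakob"}))
#           try:
#               while True:
#                   msg = await ws.recv()              # binary frames carry raw PCM chunks
#                   if isinstance(msg, bytes):
#                       pcm.extend(msg)
#                   else:
#                       print("Server message:", msg)  # e.g. a JSON error payload
#           except websockets.exceptions.ConnectionClosed:
#               pass                                   # server closes the socket when done
#       with wave.open("tts_output.wav", "wb") as wav:
#           wav.setnchannels(1)   # mono
#           wav.setsampwidth(2)   # 16-bit samples
#           wav.setframerate(24000)
#           wav.writeframes(bytes(pcm))
#
#   asyncio.run(demo_client())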