Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ model = None
|
|
| 18 |
snac = None
|
| 19 |
masker = None
|
| 20 |
stopping_criteria = None
|
| 21 |
-
actual_eos_token_id = None #
|
| 22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
|
| 24 |
# 0) Login + Device ---------------------------------------------------
|
|
@@ -31,10 +31,11 @@ if HF_TOKEN:
|
|
| 31 |
|
| 32 |
# 1) Konstanten -------------------------------------------------------
|
| 33 |
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
|
| 34 |
-
# CHUNK_TOKENS = 50 # Not directly used by us with the streamer approach
|
| 35 |
START_TOKEN = 128259
|
| 36 |
NEW_BLOCK = 128257
|
| 37 |
-
#
|
|
|
|
|
|
|
| 38 |
AUDIO_BASE = 128266
|
| 39 |
AUDIO_SPAN = 4096 * 7 # 28672 Codes
|
| 40 |
CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
|
|
@@ -42,61 +43,51 @@ CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
|
|
| 42 |
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
|
| 43 |
|
| 44 |
# 2) Logit‑Mask -------------------------------------------------------
|
| 45 |
-
# Uses the
|
| 46 |
class AudioMask(LogitsProcessor):
|
| 47 |
def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
|
| 48 |
super().__init__()
|
| 49 |
-
# Ensure input tensors are Long type for concatenation if needed, although indices are usually int
|
| 50 |
new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
|
| 51 |
eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
|
| 52 |
-
|
| 53 |
-
# Allow NEW_BLOCK and all valid audio tokens initially
|
| 54 |
self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
|
| 55 |
-
self.eos = eos_tensor
|
| 56 |
-
self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
|
| 57 |
-
self.sent_blocks = 0
|
| 58 |
|
| 59 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
| 60 |
-
# Determine which tokens are allowed based on whether blocks have been sent
|
| 61 |
current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
|
| 62 |
-
|
| 63 |
-
# Create a mask initialized to negative infinity
|
| 64 |
mask = torch.full_like(scores, float("-inf"))
|
| 65 |
-
# Set allowed token scores to 0 (effectively allowing them)
|
| 66 |
mask[:, current_allow] = 0
|
| 67 |
-
# Apply the mask to the scores
|
| 68 |
return scores + mask
|
| 69 |
|
| 70 |
def reset(self):
|
| 71 |
-
"""Resets the state for a new generation request."""
|
| 72 |
self.sent_blocks = 0
|
| 73 |
|
| 74 |
# 3) StoppingCriteria für EOS ---------------------------------------
|
| 75 |
-
# Uses the
|
| 76 |
class EosStoppingCriteria(StoppingCriteria):
|
| 77 |
def __init__(self, eos_token_id: int):
|
| 78 |
self.eos_token_id = eos_token_id
|
| 79 |
-
|
| 80 |
-
print("⚠️ EosStoppingCriteria initialized with eos_token_id=None!")
|
| 81 |
|
| 82 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 83 |
if self.eos_token_id is None:
|
| 84 |
-
return False
|
| 85 |
-
# Check if the *last* generated token is the EOS token
|
| 86 |
if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
|
| 87 |
-
|
| 88 |
return True
|
| 89 |
return False
|
| 90 |
|
| 91 |
# 4) Benutzerdefinierter AudioStreamer -------------------------------
|
| 92 |
class AudioStreamer(BaseStreamer):
|
|
|
|
| 93 |
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
|
| 94 |
self.ws = ws
|
| 95 |
self.snac = snac_decoder
|
| 96 |
self.masker = audio_mask
|
| 97 |
self.loop = loop
|
| 98 |
self.device = target_device
|
| 99 |
-
self.eos_token_id = eos_token_id # Store EOS ID
|
| 100 |
self.buf: list[int] = []
|
| 101 |
self.tasks = set()
|
| 102 |
|
|
@@ -108,8 +99,8 @@ class AudioStreamer(BaseStreamer):
|
|
| 108 |
Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
|
| 109 |
"""
|
| 110 |
if len(block7) != 7:
|
| 111 |
-
print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
|
| 112 |
-
return b""
|
| 113 |
|
| 114 |
try:
|
| 115 |
# --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
|
|
@@ -129,7 +120,7 @@ class AudioStreamer(BaseStreamer):
|
|
| 129 |
except IndexError:
|
| 130 |
print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
|
| 131 |
return b""
|
| 132 |
-
except Exception as e_map:
|
| 133 |
print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
|
| 134 |
return b""
|
| 135 |
|
|
@@ -149,10 +140,7 @@ class AudioStreamer(BaseStreamer):
|
|
| 149 |
audio = self.snac.decode(codes)[0]
|
| 150 |
except Exception as e_decode:
|
| 151 |
print(f"Streamer Error: Exception during snac.decode: {e_decode}")
|
| 152 |
-
|
| 153 |
-
print(f"Input codes dtypes: {[c.dtype for c in codes]}")
|
| 154 |
-
print(f"Input codes devices: {[c.device for c in codes]}")
|
| 155 |
-
print(f"Input code values (min/max): L1({min(l1)}/{max(l1)}) L2({min(l2)}/{max(l2)}) L3({min(l3)}/{max(l3)})")
|
| 156 |
return b""
|
| 157 |
|
| 158 |
# --- Post-processing ---
|
|
@@ -171,10 +159,12 @@ class AudioStreamer(BaseStreamer):
|
|
| 171 |
try:
|
| 172 |
await self.ws.send_bytes(data)
|
| 173 |
except WebSocketDisconnect:
|
| 174 |
-
|
|
|
|
|
|
|
| 175 |
except Exception as e:
|
| 176 |
-
|
| 177 |
-
|
| 178 |
# This is expected if client disconnects during generation, suppress repetitive logs
|
| 179 |
pass
|
| 180 |
else:
|
|
@@ -187,7 +177,6 @@ class AudioStreamer(BaseStreamer):
|
|
| 187 |
"""
|
| 188 |
if value.numel() == 0:
|
| 189 |
return
|
| 190 |
-
# Ensure value is on CPU and flatten to a list of ints
|
| 191 |
new_token_ids = value.squeeze().cpu().tolist()
|
| 192 |
if isinstance(new_token_ids, int):
|
| 193 |
new_token_ids = [new_token_ids]
|
|
@@ -198,23 +187,22 @@ class AudioStreamer(BaseStreamer):
|
|
| 198 |
self.buf.clear()
|
| 199 |
continue
|
| 200 |
|
|
|
|
| 201 |
if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
|
| 202 |
self.buf.append(t - AUDIO_BASE) # Store value relative to base
|
| 203 |
-
# else: # Optionally log ignored tokens
|
| 204 |
-
|
| 205 |
-
|
| 206 |
|
| 207 |
if len(self.buf) == 7:
|
| 208 |
audio_bytes = self._decode_block(self.buf)
|
| 209 |
self.buf.clear()
|
| 210 |
|
| 211 |
if audio_bytes:
|
| 212 |
-
# Schedule the async send function to run on the main event loop
|
| 213 |
future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
|
| 214 |
self.tasks.add(future)
|
| 215 |
future.add_done_callback(self.tasks.discard)
|
| 216 |
|
| 217 |
-
# Allow EOS only after the first full block has been processed
|
| 218 |
if self.masker.sent_blocks == 0:
|
| 219 |
self.masker.sent_blocks = 1
|
| 220 |
|
|
@@ -230,7 +218,8 @@ app = FastAPI()
|
|
| 230 |
|
| 231 |
@app.on_event("startup")
|
| 232 |
async def load_models_startup():
|
| 233 |
-
global
|
|
|
|
| 234 |
|
| 235 |
print(f"🚀 Starting up on device: {device}")
|
| 236 |
print("⏳ Lade Modelle …", flush=True)
|
|
@@ -259,34 +248,28 @@ async def load_models_startup():
|
|
| 259 |
print(f"Model loaded to {model.device} with dtype {model.dtype}.")
|
| 260 |
model.eval()
|
| 261 |
|
| 262 |
-
# ---
|
| 263 |
conf_eos = model.config.eos_token_id
|
| 264 |
tok_eos = tok.eos_token_id
|
| 265 |
print(f"Model Config EOS ID: {conf_eos}")
|
| 266 |
print(f"Tokenizer EOS ID: {tok_eos}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
if
|
| 269 |
-
actual_eos_token_id = conf_eos
|
| 270 |
-
elif tok_eos is not None:
|
| 271 |
-
actual_eos_token_id = tok_eos
|
| 272 |
-
print(f"⚠️ Model config EOS ID is None, using Tokenizer EOS ID: {actual_eos_token_id}")
|
| 273 |
-
else:
|
| 274 |
-
raise ValueError("Could not determine EOS token ID from model config or tokenizer.")
|
| 275 |
-
|
| 276 |
-
print(f"Using EOS Token ID: {actual_eos_token_id}")
|
| 277 |
-
# Set pad_token_id to eos_token_id if not already set (common practice for generation)
|
| 278 |
if model.config.pad_token_id is None:
|
| 279 |
-
print(f"Setting model.config.pad_token_id to EOS token ID ({
|
| 280 |
-
model.config.pad_token_id =
|
| 281 |
-
# --- End EOS Token ID determination ---
|
| 282 |
|
| 283 |
audio_ids_device = AUDIO_IDS_CPU.to(device)
|
| 284 |
-
# Pass the
|
| 285 |
-
masker = AudioMask(audio_ids_device, NEW_BLOCK,
|
| 286 |
print("AudioMask initialized.")
|
| 287 |
|
| 288 |
-
# Pass the
|
| 289 |
-
stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(
|
| 290 |
print("StoppingCriteria initialized.")
|
| 291 |
|
| 292 |
print("✅ Modelle geladen und bereit!", flush=True)
|
|
@@ -313,7 +296,7 @@ def build_prompt(text: str, voice: str) -> tuple[torch.Tensor, torch.Tensor]:
|
|
| 313 |
# 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
|
| 314 |
@app.websocket("/ws/tts")
|
| 315 |
async def tts(ws: WebSocket):
|
| 316 |
-
|
| 317 |
await ws.accept()
|
| 318 |
print("🔌 Client connected")
|
| 319 |
streamer = None
|
|
@@ -334,27 +317,28 @@ async def tts(ws: WebSocket):
|
|
| 334 |
print(f"Generating audio for: '{text}' with voice '{voice}'")
|
| 335 |
ids, attn = build_prompt(text, voice)
|
| 336 |
masker.reset()
|
| 337 |
-
# Pass the
|
| 338 |
-
streamer = AudioStreamer(ws, snac, masker, main_loop, device,
|
| 339 |
|
| 340 |
print("Starting generation in background thread...")
|
| 341 |
-
# Use sampling parameters
|
| 342 |
await asyncio.to_thread(
|
| 343 |
model.generate,
|
| 344 |
input_ids=ids,
|
| 345 |
attention_mask=attn,
|
| 346 |
-
max_new_tokens=2500, #
|
| 347 |
logits_processor=[masker],
|
| 348 |
stopping_criteria=stopping_criteria,
|
| 349 |
-
# --- Sampling Parameters ---
|
| 350 |
do_sample=True,
|
| 351 |
-
temperature=0.6,
|
| 352 |
-
top_p=0.9,
|
| 353 |
-
repetition_penalty=1.
|
|
|
|
| 354 |
# --- End Sampling Parameters ---
|
| 355 |
use_cache=True,
|
| 356 |
streamer=streamer,
|
| 357 |
-
eos_token_id=
|
| 358 |
)
|
| 359 |
print("Generation thread finished.")
|
| 360 |
|
|
@@ -387,8 +371,7 @@ async def tts(ws: WebSocket):
|
|
| 387 |
try:
|
| 388 |
await ws.close(code=1000)
|
| 389 |
except RuntimeError as e_close:
|
| 390 |
-
|
| 391 |
-
if "Cannot call \"send\"" not in str(e_close):
|
| 392 |
print(f"Runtime error closing websocket: {e_close}")
|
| 393 |
except Exception as e_close_final:
|
| 394 |
print(f"Error closing websocket: {e_close_final}")
|
|
|
|
| 18 |
snac = None
|
| 19 |
masker = None
|
| 20 |
stopping_criteria = None
|
| 21 |
+
# actual_eos_token_id = None # Reverted to constant below
|
| 22 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
|
| 24 |
# 0) Login + Device ---------------------------------------------------
|
|
|
|
| 31 |
|
| 32 |
# 1) Konstanten -------------------------------------------------------
|
| 33 |
REPO = "SebastianBodza/Kartoffel_Orpheus-3B_german_natural-v0.1"
|
|
|
|
| 34 |
START_TOKEN = 128259
|
| 35 |
NEW_BLOCK = 128257
|
| 36 |
+
# --- Reverted to using the hardcoded EOS token based on user belief ---
|
| 37 |
+
EOS_TOKEN = 128258
|
| 38 |
+
# --- End Reverted EOS Token ---
|
| 39 |
AUDIO_BASE = 128266
|
| 40 |
AUDIO_SPAN = 4096 * 7 # 28672 Codes
|
| 41 |
CODEBOOK_SIZE = 4096 # Explicitly define the codebook size
|
|
|
|
| 43 |
AUDIO_IDS_CPU = torch.arange(AUDIO_BASE, AUDIO_BASE + AUDIO_SPAN)
|
| 44 |
|
| 45 |
# 2) Logit‑Mask -------------------------------------------------------
|
| 46 |
+
# Uses the constant EOS_TOKEN
|
| 47 |
class AudioMask(LogitsProcessor):
|
| 48 |
def __init__(self, audio_ids: torch.Tensor, new_block_token_id: int, eos_token_id: int):
|
| 49 |
super().__init__()
|
|
|
|
| 50 |
new_block_tensor = torch.tensor([new_block_token_id], device=audio_ids.device, dtype=torch.long)
|
| 51 |
eos_tensor = torch.tensor([eos_token_id], device=audio_ids.device, dtype=torch.long)
|
|
|
|
|
|
|
| 52 |
self.allow = torch.cat([new_block_tensor, audio_ids], dim=0)
|
| 53 |
+
self.eos = eos_tensor
|
| 54 |
+
self.allow_with_eos = torch.cat([self.allow, self.eos], dim=0)
|
| 55 |
+
self.sent_blocks = 0
|
| 56 |
|
| 57 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
|
|
| 58 |
current_allow = self.allow_with_eos if self.sent_blocks > 0 else self.allow
|
|
|
|
|
|
|
| 59 |
mask = torch.full_like(scores, float("-inf"))
|
|
|
|
| 60 |
mask[:, current_allow] = 0
|
|
|
|
| 61 |
return scores + mask
|
| 62 |
|
| 63 |
def reset(self):
|
|
|
|
| 64 |
self.sent_blocks = 0
|
| 65 |
|
| 66 |
# 3) StoppingCriteria für EOS ---------------------------------------
|
| 67 |
+
# Uses the constant EOS_TOKEN
|
| 68 |
class EosStoppingCriteria(StoppingCriteria):
|
| 69 |
def __init__(self, eos_token_id: int):
|
| 70 |
self.eos_token_id = eos_token_id
|
| 71 |
+
# No warning needed here as we are intentionally using the constant
|
|
|
|
| 72 |
|
| 73 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 74 |
if self.eos_token_id is None:
|
| 75 |
+
return False
|
|
|
|
| 76 |
if input_ids.shape[1] > 0 and input_ids[:, -1] == self.eos_token_id:
|
| 77 |
+
print(f"StoppingCriteria: EOS detected (ID: {self.eos_token_id}).") # Add log
|
| 78 |
return True
|
| 79 |
return False
|
| 80 |
|
| 81 |
# 4) Benutzerdefinierter AudioStreamer -------------------------------
|
| 82 |
class AudioStreamer(BaseStreamer):
|
| 83 |
+
# Pass the constant EOS_TOKEN here too
|
| 84 |
def __init__(self, ws: WebSocket, snac_decoder: SNAC, audio_mask: AudioMask, loop: asyncio.AbstractEventLoop, target_device: str, eos_token_id: int):
|
| 85 |
self.ws = ws
|
| 86 |
self.snac = snac_decoder
|
| 87 |
self.masker = audio_mask
|
| 88 |
self.loop = loop
|
| 89 |
self.device = target_device
|
| 90 |
+
self.eos_token_id = eos_token_id # Store constant EOS ID
|
| 91 |
self.buf: list[int] = []
|
| 92 |
self.tasks = set()
|
| 93 |
|
|
|
|
| 99 |
Maps extracted values using the structure potentially correct for Kartoffel_Orpheus.
|
| 100 |
"""
|
| 101 |
if len(block7) != 7:
|
| 102 |
+
# print(f"Streamer Warning: _decode_block received {len(block7)} tokens, expected 7. Skipping.")
|
| 103 |
+
return b"" # Less verbose logging
|
| 104 |
|
| 105 |
try:
|
| 106 |
# --- Extract base code value (0 to CODEBOOK_SIZE-1) for each slot using modulo ---
|
|
|
|
| 120 |
except IndexError:
|
| 121 |
print(f"Streamer Error: Index out of bounds during token mapping. Block: {block7}")
|
| 122 |
return b""
|
| 123 |
+
except Exception as e_map:
|
| 124 |
print(f"Streamer Error: Exception during code value extraction/mapping: {e_map}. Block: {block7}")
|
| 125 |
return b""
|
| 126 |
|
|
|
|
| 140 |
audio = self.snac.decode(codes)[0]
|
| 141 |
except Exception as e_decode:
|
| 142 |
print(f"Streamer Error: Exception during snac.decode: {e_decode}")
|
| 143 |
+
# Add more details if needed, e.g., shapes: {[c.shape for c in codes]}
|
|
|
|
|
|
|
|
|
|
| 144 |
return b""
|
| 145 |
|
| 146 |
# --- Post-processing ---
|
|
|
|
| 159 |
try:
|
| 160 |
await self.ws.send_bytes(data)
|
| 161 |
except WebSocketDisconnect:
|
| 162 |
+
# This is expected if client disconnects first, don't log error
|
| 163 |
+
# print("Streamer: WebSocket disconnected during send.")
|
| 164 |
+
pass
|
| 165 |
except Exception as e:
|
| 166 |
+
if "Cannot call \"send\" once a close message has been sent" in str(e) or \
|
| 167 |
+
"Connection is closed" in str(e):
|
| 168 |
# This is expected if client disconnects during generation, suppress repetitive logs
|
| 169 |
pass
|
| 170 |
else:
|
|
|
|
| 177 |
"""
|
| 178 |
if value.numel() == 0:
|
| 179 |
return
|
|
|
|
| 180 |
new_token_ids = value.squeeze().cpu().tolist()
|
| 181 |
if isinstance(new_token_ids, int):
|
| 182 |
new_token_ids = [new_token_ids]
|
|
|
|
| 187 |
self.buf.clear()
|
| 188 |
continue
|
| 189 |
|
| 190 |
+
# Use the constant EOS_TOKEN for comparison if needed (e.g. for logging)
|
| 191 |
if AUDIO_BASE <= t < AUDIO_BASE + AUDIO_SPAN:
|
| 192 |
self.buf.append(t - AUDIO_BASE) # Store value relative to base
|
| 193 |
+
# else: # Optionally log ignored tokens
|
| 194 |
+
# if t != self.eos_token_id: # Don't warn about the EOS token itself
|
| 195 |
+
# print(f"Streamer Warning: Ignoring unexpected token {t}")
|
| 196 |
|
| 197 |
if len(self.buf) == 7:
|
| 198 |
audio_bytes = self._decode_block(self.buf)
|
| 199 |
self.buf.clear()
|
| 200 |
|
| 201 |
if audio_bytes:
|
|
|
|
| 202 |
future = asyncio.run_coroutine_threadsafe(self._send_audio_bytes(audio_bytes), self.loop)
|
| 203 |
self.tasks.add(future)
|
| 204 |
future.add_done_callback(self.tasks.discard)
|
| 205 |
|
|
|
|
| 206 |
if self.masker.sent_blocks == 0:
|
| 207 |
self.masker.sent_blocks = 1
|
| 208 |
|
|
|
|
| 218 |
|
| 219 |
@app.on_event("startup")
|
| 220 |
async def load_models_startup():
|
| 221 |
+
# Keep global references, but EOS_TOKEN is now a constant again
|
| 222 |
+
global tok, model, snac, masker, stopping_criteria, device, AUDIO_IDS_CPU
|
| 223 |
|
| 224 |
print(f"🚀 Starting up on device: {device}")
|
| 225 |
print("⏳ Lade Modelle …", flush=True)
|
|
|
|
| 248 |
print(f"Model loaded to {model.device} with dtype {model.dtype}.")
|
| 249 |
model.eval()
|
| 250 |
|
| 251 |
+
# --- Print comparison for EOS token IDs but use the constant ---
|
| 252 |
conf_eos = model.config.eos_token_id
|
| 253 |
tok_eos = tok.eos_token_id
|
| 254 |
print(f"Model Config EOS ID: {conf_eos}")
|
| 255 |
print(f"Tokenizer EOS ID: {tok_eos}")
|
| 256 |
+
print(f"Using Constant EOS_TOKEN: {EOS_TOKEN}") # State the used constant
|
| 257 |
+
if conf_eos != EOS_TOKEN or tok_eos != EOS_TOKEN:
|
| 258 |
+
print(f"⚠️ WARNING: Constant EOS_TOKEN {EOS_TOKEN} differs from model/tokenizer IDs ({conf_eos}/{tok_eos}).")
|
| 259 |
+
# --- End EOS comparison ---
|
| 260 |
|
| 261 |
+
# Set pad_token_id if None (use the constant EOS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
if model.config.pad_token_id is None:
|
| 263 |
+
print(f"Setting model.config.pad_token_id to Constant EOS token ID ({EOS_TOKEN})")
|
| 264 |
+
model.config.pad_token_id = EOS_TOKEN
|
|
|
|
| 265 |
|
| 266 |
audio_ids_device = AUDIO_IDS_CPU.to(device)
|
| 267 |
+
# Pass the constant EOS_TOKEN to the mask
|
| 268 |
+
masker = AudioMask(audio_ids_device, NEW_BLOCK, EOS_TOKEN)
|
| 269 |
print("AudioMask initialized.")
|
| 270 |
|
| 271 |
+
# Pass the constant EOS_TOKEN to the stopping criteria
|
| 272 |
+
stopping_criteria = StoppingCriteriaList([EosStoppingCriteria(EOS_TOKEN)])
|
| 273 |
print("StoppingCriteria initialized.")
|
| 274 |
|
| 275 |
print("✅ Modelle geladen und bereit!", flush=True)
|
|
|
|
| 296 |
# 7) WebSocket‑Endpoint (vereinfacht mit Streamer) ---------------------
|
| 297 |
@app.websocket("/ws/tts")
|
| 298 |
async def tts(ws: WebSocket):
|
| 299 |
+
# No need for global actual_eos_token_id
|
| 300 |
await ws.accept()
|
| 301 |
print("🔌 Client connected")
|
| 302 |
streamer = None
|
|
|
|
| 317 |
print(f"Generating audio for: '{text}' with voice '{voice}'")
|
| 318 |
ids, attn = build_prompt(text, voice)
|
| 319 |
masker.reset()
|
| 320 |
+
# Pass the constant EOS_TOKEN to streamer
|
| 321 |
+
streamer = AudioStreamer(ws, snac, masker, main_loop, device, EOS_TOKEN)
|
| 322 |
|
| 323 |
print("Starting generation in background thread...")
|
| 324 |
+
# Use sampling parameters with anti-repetition measures
|
| 325 |
await asyncio.to_thread(
|
| 326 |
model.generate,
|
| 327 |
input_ids=ids,
|
| 328 |
attention_mask=attn,
|
| 329 |
+
max_new_tokens=2500, # Or adjust as needed
|
| 330 |
logits_processor=[masker],
|
| 331 |
stopping_criteria=stopping_criteria,
|
| 332 |
+
# --- Sampling Parameters with Anti-Repetition ---
|
| 333 |
do_sample=True,
|
| 334 |
+
temperature=0.6, # Adjust if needed
|
| 335 |
+
top_p=0.9, # Adjust if needed
|
| 336 |
+
repetition_penalty=1.2, # Increased (experiment!)
|
| 337 |
+
no_repeat_ngram_size=4, # Added (experiment!)
|
| 338 |
# --- End Sampling Parameters ---
|
| 339 |
use_cache=True,
|
| 340 |
streamer=streamer,
|
| 341 |
+
eos_token_id=EOS_TOKEN # Explicitly pass constant EOS ID
|
| 342 |
)
|
| 343 |
print("Generation thread finished.")
|
| 344 |
|
|
|
|
| 371 |
try:
|
| 372 |
await ws.close(code=1000)
|
| 373 |
except RuntimeError as e_close:
|
| 374 |
+
if "Cannot call \"send\"" not in str(e_close) and "Connection is closed" not in str(e_close):
|
|
|
|
| 375 |
print(f"Runtime error closing websocket: {e_close}")
|
| 376 |
except Exception as e_close_final:
|
| 377 |
print(f"Error closing websocket: {e_close_final}")
|