Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -99,10 +99,8 @@ async def tts(ws: WebSocket):
|
|
| 99 |
try:
|
| 100 |
req = json.loads(await ws.receive_text())
|
| 101 |
ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
|
| 102 |
-
prompt_len = ids.size(1)
|
| 103 |
-
|
| 104 |
-
past = None
|
| 105 |
-
buf = []
|
| 106 |
|
| 107 |
while True:
|
| 108 |
gen = model.generate(
|
|
@@ -111,44 +109,39 @@ async def tts(ws: WebSocket):
|
|
| 111 |
past_key_values=past,
|
| 112 |
max_new_tokens=CHUNK_TOKENS,
|
| 113 |
logits_processor=[MASKER],
|
| 114 |
-
do_sample=True,
|
| 115 |
return_dict_in_generate=True,
|
| 116 |
-
use_cache=True,
|
| 117 |
-
return_legacy_cache=True, # wichtig <4.49
|
| 118 |
)
|
| 119 |
|
| 120 |
-
# Cache fΓΌr den nΓ€chsten Loop
|
| 121 |
past = gen.past_key_values if not isinstance(gen.past_key_values, Cache) else gen.past_key_values.to_legacy()
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
prompt_len = len(seq) # nΓ€chstes Delta
|
| 126 |
-
|
| 127 |
-
if not new_tok: # (selten) nichts erzeugt β weiter
|
| 128 |
-
continue
|
| 129 |
|
| 130 |
for t in new_tok:
|
| 131 |
if t == EOS_TOKEN:
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
if t == NEW_BLOCK_TOKEN:
|
| 134 |
-
buf.clear()
|
| 135 |
-
continue
|
| 136 |
buf.append(t - AUDIO_BASE)
|
| 137 |
if len(buf) == 7:
|
| 138 |
await ws.send_bytes(decode_block(buf))
|
| 139 |
buf.clear()
|
| 140 |
|
| 141 |
-
ids =
|
| 142 |
|
| 143 |
-
except
|
| 144 |
-
pass
|
| 145 |
except Exception as e:
|
| 146 |
print("WSβError:", e)
|
| 147 |
if ws.client_state.name == "CONNECTED":
|
| 148 |
-
await ws.close(code=1011)
|
| 149 |
-
finally:
|
| 150 |
-
if ws.client_state.name == "CONNECTED":
|
| 151 |
-
await ws.close()
|
| 152 |
|
| 153 |
# ββ 6. Local run ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 154 |
if __name__ == "__main__":
|
|
|
|
| 99 |
try:
|
| 100 |
req = json.loads(await ws.receive_text())
|
| 101 |
ids, attn = build_inputs(req.get("text", ""), req.get("voice", "Jakob"))
|
| 102 |
+
prompt_len = ids.size(1)
|
| 103 |
+
past, buf = None, []
|
|
|
|
|
|
|
| 104 |
|
| 105 |
while True:
|
| 106 |
gen = model.generate(
|
|
|
|
| 109 |
past_key_values=past,
|
| 110 |
max_new_tokens=CHUNK_TOKENS,
|
| 111 |
logits_processor=[MASKER],
|
| 112 |
+
do_sample=True, temperature=0.7, top_p=0.95,
|
| 113 |
return_dict_in_generate=True,
|
| 114 |
+
use_cache=True, return_legacy_cache=True,
|
|
|
|
| 115 |
)
|
| 116 |
|
|
|
|
| 117 |
past = gen.past_key_values if not isinstance(gen.past_key_values, Cache) else gen.past_key_values.to_legacy()
|
| 118 |
+
seq = gen.sequences[0].tolist()
|
| 119 |
+
new_tok = seq[prompt_len:]
|
| 120 |
+
prompt_len = len(seq)
|
| 121 |
|
| 122 |
+
if not new_tok:
|
| 123 |
+
continue # selten, aber mΓΆglich
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
for t in new_tok:
|
| 126 |
if t == EOS_TOKEN:
|
| 127 |
+
# ein einziges CloseβFrame genΓΌgt
|
| 128 |
+
await ws.close() # <ββ einziges explizites close
|
| 129 |
+
return
|
| 130 |
if t == NEW_BLOCK_TOKEN:
|
| 131 |
+
buf.clear(); continue
|
|
|
|
| 132 |
buf.append(t - AUDIO_BASE)
|
| 133 |
if len(buf) == 7:
|
| 134 |
await ws.send_bytes(decode_block(buf))
|
| 135 |
buf.clear()
|
| 136 |
|
| 137 |
+
ids = attn = None # nur noch Cache
|
| 138 |
|
| 139 |
+
except WebSocketDisconnect:
|
| 140 |
+
pass # Client ging von selbst
|
| 141 |
except Exception as e:
|
| 142 |
print("WSβError:", e)
|
| 143 |
if ws.client_state.name == "CONNECTED":
|
| 144 |
+
await ws.close(code=1011) # Fehler melden
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# ββ 6. Local run ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
if __name__ == "__main__":
|