Commit
·
c985b41
1
Parent(s):
dd42331
manual reset for /generate. something's still accumulating...
Browse files
app.py
CHANGED
|
@@ -372,6 +372,67 @@ _MRT_LOCK = threading.Lock()
|
|
| 372 |
_PROGRESS = {}
|
| 373 |
_PROGRESS_LOCK = threading.Lock()
|
| 374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
|
| 376 |
if not req_id:
|
| 377 |
return
|
|
@@ -741,6 +802,26 @@ def model_select(req: ModelSelect):
|
|
| 741 |
# one-shot generation
|
| 742 |
# ----------------------------
|
| 743 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
@app.get("/progress")
|
| 745 |
def progress(request_id: str):
|
| 746 |
return _progress_get(request_id)
|
|
@@ -762,11 +843,26 @@ def generate(
|
|
| 762 |
target_sample_rate: int | None = Form(None),
|
| 763 |
intro_bars_to_drop: int = Form(0),
|
| 764 |
request_id: str = Form(None),
|
|
|
|
| 765 |
):
|
|
|
|
|
|
|
| 766 |
req_id = request_id or str(uuid.uuid4())
|
| 767 |
tmp_path = None
|
| 768 |
|
| 769 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
# 0) Read file -> tmp wav
|
| 771 |
data = loop_audio.file.read()
|
| 772 |
if not data:
|
|
|
|
| 372 |
_PROGRESS = {}
|
| 373 |
_PROGRESS_LOCK = threading.Lock()
|
| 374 |
|
| 375 |
+
_GENERATE_COUNTER = 0
|
| 376 |
+
_GENERATE_COUNTER_LOCK = threading.Lock()
|
| 377 |
+
|
| 378 |
+
# In app.py, near the top with other globals
|
| 379 |
+
_GENERATE_COUNTER = 0
|
| 380 |
+
_GENERATE_COUNTER_LOCK = threading.Lock()
|
| 381 |
+
|
| 382 |
+
def _light_reset_mrt(mrt):
|
| 383 |
+
"""
|
| 384 |
+
Lightweight reset that clears accumulated state without full recompilation.
|
| 385 |
+
Should take <1 second instead of 30 seconds.
|
| 386 |
+
"""
|
| 387 |
+
import logging
|
| 388 |
+
logging.info("Performing light MRT reset after prolonged use...")
|
| 389 |
+
|
| 390 |
+
try:
|
| 391 |
+
# 1. Clear JAX device arrays (but not compiled functions)
|
| 392 |
+
import jax
|
| 393 |
+
for device in jax.devices():
|
| 394 |
+
# Force garbage collection on device
|
| 395 |
+
try:
|
| 396 |
+
device.clear_memory() # If available in your JAX version
|
| 397 |
+
except AttributeError:
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
+
# 2. Clear any MRT-level caches
|
| 401 |
+
attrs_to_clear = ['_last_state', '_generation_cache', '_style_cache']
|
| 402 |
+
for attr in attrs_to_clear:
|
| 403 |
+
if hasattr(mrt, attr):
|
| 404 |
+
setattr(mrt, attr, None)
|
| 405 |
+
|
| 406 |
+
# 3. Clear codec internal state
|
| 407 |
+
codec_attrs = [
|
| 408 |
+
'_encode_state', '_decode_state',
|
| 409 |
+
'_encoder_cache', '_decoder_cache',
|
| 410 |
+
'_buffer', '_frame_buffer'
|
| 411 |
+
]
|
| 412 |
+
for attr in codec_attrs:
|
| 413 |
+
if hasattr(mrt.codec, attr):
|
| 414 |
+
setattr(mrt.codec, attr, None)
|
| 415 |
+
|
| 416 |
+
# 4. Force Python garbage collection
|
| 417 |
+
import gc
|
| 418 |
+
gc.collect()
|
| 419 |
+
|
| 420 |
+
# 5. If style model has cache, clear it
|
| 421 |
+
if hasattr(mrt, 'style_model'):
|
| 422 |
+
if hasattr(mrt.style_model, 'clear_cache'):
|
| 423 |
+
mrt.style_model.clear_cache()
|
| 424 |
+
# Clear any embedding caches
|
| 425 |
+
for attr in ['_embed_cache', '_text_cache']:
|
| 426 |
+
if hasattr(mrt.style_model, attr):
|
| 427 |
+
setattr(mrt.style_model, attr, None)
|
| 428 |
+
|
| 429 |
+
logging.info("Light reset complete")
|
| 430 |
+
return True
|
| 431 |
+
|
| 432 |
+
except Exception as e:
|
| 433 |
+
logging.warning(f"Light reset partially failed (non-fatal): {e}")
|
| 434 |
+
return False
|
| 435 |
+
|
| 436 |
def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
|
| 437 |
if not req_id:
|
| 438 |
return
|
|
|
|
| 802 |
# one-shot generation
|
| 803 |
# ----------------------------
|
| 804 |
|
| 805 |
+
@app.post("/generate/reset")
|
| 806 |
+
def generate_reset():
|
| 807 |
+
"""
|
| 808 |
+
Manually trigger a light reset of the generation system.
|
| 809 |
+
Useful if user notices quality degradation without full restart.
|
| 810 |
+
"""
|
| 811 |
+
global _GENERATE_COUNTER
|
| 812 |
+
|
| 813 |
+
with _GENERATE_COUNTER_LOCK:
|
| 814 |
+
_GENERATE_COUNTER = 0 # Reset counter
|
| 815 |
+
|
| 816 |
+
mrt = get_mrt()
|
| 817 |
+
success = _light_reset_mrt(mrt)
|
| 818 |
+
|
| 819 |
+
return {
|
| 820 |
+
"reset": success,
|
| 821 |
+
"message": "Light reset complete" if success else "Reset partially completed",
|
| 822 |
+
"counter_reset": True
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
@app.get("/progress")
|
| 826 |
def progress(request_id: str):
|
| 827 |
return _progress_get(request_id)
|
|
|
|
| 843 |
target_sample_rate: int | None = Form(None),
|
| 844 |
intro_bars_to_drop: int = Form(0),
|
| 845 |
request_id: str = Form(None),
|
| 846 |
+
force_reset: bool = Form(False), # NEW: Manual reset trigger
|
| 847 |
):
|
| 848 |
+
global _GENERATE_COUNTER
|
| 849 |
+
|
| 850 |
req_id = request_id or str(uuid.uuid4())
|
| 851 |
tmp_path = None
|
| 852 |
|
| 853 |
try:
|
| 854 |
+
# Check if we need a periodic reset
|
| 855 |
+
with _GENERATE_COUNTER_LOCK:
|
| 856 |
+
_GENERATE_COUNTER += 1
|
| 857 |
+
gen_count = _GENERATE_COUNTER
|
| 858 |
+
|
| 859 |
+
# Every 5 generations, do a light reset
|
| 860 |
+
# (Or if user explicitly requests it)
|
| 861 |
+
if gen_count % 5 == 0 or force_reset:
|
| 862 |
+
logging.info(f"[Generate {req_id}] Triggering light reset (generation #{gen_count})")
|
| 863 |
+
mrt = get_mrt()
|
| 864 |
+
_light_reset_mrt(mrt)
|
| 865 |
+
|
| 866 |
# 0) Read file -> tmp wav
|
| 867 |
data = loop_audio.file.read()
|
| 868 |
if not data:
|