thecollabagepatch commited on
Commit
c985b41
·
1 Parent(s): dd42331

manual reset for /generate. something's still accumulating...

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py CHANGED
@@ -372,6 +372,67 @@ _MRT_LOCK = threading.Lock()
372
  _PROGRESS = {}
373
  _PROGRESS_LOCK = threading.Lock()
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
376
  if not req_id:
377
  return
@@ -741,6 +802,26 @@ def model_select(req: ModelSelect):
741
  # one-shot generation
742
  # ----------------------------
743
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
744
  @app.get("/progress")
745
  def progress(request_id: str):
746
  return _progress_get(request_id)
@@ -762,11 +843,26 @@ def generate(
762
  target_sample_rate: int | None = Form(None),
763
  intro_bars_to_drop: int = Form(0),
764
  request_id: str = Form(None),
 
765
  ):
 
 
766
  req_id = request_id or str(uuid.uuid4())
767
  tmp_path = None
768
 
769
  try:
 
 
 
 
 
 
 
 
 
 
 
 
770
  # 0) Read file -> tmp wav
771
  data = loop_audio.file.read()
772
  if not data:
 
372
  _PROGRESS = {}
373
  _PROGRESS_LOCK = threading.Lock()
374
 
375
+ _GENERATE_COUNTER = 0
376
+ _GENERATE_COUNTER_LOCK = threading.Lock()
377
+
378
+ # In app.py, near the top with other globals
379
+ _GENERATE_COUNTER = 0
380
+ _GENERATE_COUNTER_LOCK = threading.Lock()
381
+
382
+ def _light_reset_mrt(mrt):
383
+ """
384
+ Lightweight reset that clears accumulated state without full recompilation.
385
+ Should take <1 second instead of 30 seconds.
386
+ """
387
+ import logging
388
+ logging.info("Performing light MRT reset after prolonged use...")
389
+
390
+ try:
391
+ # 1. Clear JAX device arrays (but not compiled functions)
392
+ import jax
393
+ for device in jax.devices():
394
+ # Force garbage collection on device
395
+ try:
396
+ device.clear_memory() # If available in your JAX version
397
+ except AttributeError:
398
+ pass
399
+
400
+ # 2. Clear any MRT-level caches
401
+ attrs_to_clear = ['_last_state', '_generation_cache', '_style_cache']
402
+ for attr in attrs_to_clear:
403
+ if hasattr(mrt, attr):
404
+ setattr(mrt, attr, None)
405
+
406
+ # 3. Clear codec internal state
407
+ codec_attrs = [
408
+ '_encode_state', '_decode_state',
409
+ '_encoder_cache', '_decoder_cache',
410
+ '_buffer', '_frame_buffer'
411
+ ]
412
+ for attr in codec_attrs:
413
+ if hasattr(mrt.codec, attr):
414
+ setattr(mrt.codec, attr, None)
415
+
416
+ # 4. Force Python garbage collection
417
+ import gc
418
+ gc.collect()
419
+
420
+ # 5. If style model has cache, clear it
421
+ if hasattr(mrt, 'style_model'):
422
+ if hasattr(mrt.style_model, 'clear_cache'):
423
+ mrt.style_model.clear_cache()
424
+ # Clear any embedding caches
425
+ for attr in ['_embed_cache', '_text_cache']:
426
+ if hasattr(mrt.style_model, attr):
427
+ setattr(mrt.style_model, attr, None)
428
+
429
+ logging.info("Light reset complete")
430
+ return True
431
+
432
+ except Exception as e:
433
+ logging.warning(f"Light reset partially failed (non-fatal): {e}")
434
+ return False
435
+
436
  def _progress_update(req_id: str, n: int, total: int, stage: str = "generating"):
437
  if not req_id:
438
  return
 
802
  # one-shot generation
803
  # ----------------------------
804
 
805
+ @app.post("/generate/reset")
806
+ def generate_reset():
807
+ """
808
+ Manually trigger a light reset of the generation system.
809
+ Useful if user notices quality degradation without full restart.
810
+ """
811
+ global _GENERATE_COUNTER
812
+
813
+ with _GENERATE_COUNTER_LOCK:
814
+ _GENERATE_COUNTER = 0 # Reset counter
815
+
816
+ mrt = get_mrt()
817
+ success = _light_reset_mrt(mrt)
818
+
819
+ return {
820
+ "reset": success,
821
+ "message": "Light reset complete" if success else "Reset partially completed",
822
+ "counter_reset": True
823
+ }
824
+
825
  @app.get("/progress")
826
  def progress(request_id: str):
827
  return _progress_get(request_id)
 
843
  target_sample_rate: int | None = Form(None),
844
  intro_bars_to_drop: int = Form(0),
845
  request_id: str = Form(None),
846
+ force_reset: bool = Form(False), # NEW: Manual reset trigger
847
  ):
848
+ global _GENERATE_COUNTER
849
+
850
  req_id = request_id or str(uuid.uuid4())
851
  tmp_path = None
852
 
853
  try:
854
+ # Check if we need a periodic reset
855
+ with _GENERATE_COUNTER_LOCK:
856
+ _GENERATE_COUNTER += 1
857
+ gen_count = _GENERATE_COUNTER
858
+
859
+ # Every 5 generations, do a light reset
860
+ # (Or if user explicitly requests it)
861
+ if gen_count % 5 == 0 or force_reset:
862
+ logging.info(f"[Generate {req_id}] Triggering light reset (generation #{gen_count})")
863
+ mrt = get_mrt()
864
+ _light_reset_mrt(mrt)
865
+
866
  # 0) Read file -> tmp wav
867
  data = loop_audio.file.read()
868
  if not data: