Commit
·
169ed8c
1
Parent(s):
e0bae41
reverted
Browse files- jam_worker.py +227 -181
jam_worker.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# jam_worker.py -
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import os
|
|
@@ -20,6 +20,7 @@ from utils import (
|
|
| 20 |
)
|
| 21 |
|
| 22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
|
|
| 23 |
if x.ndim == 2:
|
| 24 |
x = x.mean(axis=1)
|
| 25 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
@@ -27,6 +28,7 @@ def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
|
| 27 |
|
| 28 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
| 29 |
# x is model-rate, shape [S,C] or [S]
|
|
|
|
| 30 |
if x.ndim == 2:
|
| 31 |
x = x.mean(axis=1)
|
| 32 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
@@ -35,19 +37,6 @@ def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
|
| 35 |
def _dbg_shape(x):
|
| 36 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
| 37 |
|
| 38 |
-
def _is_silent(audio: np.ndarray, threshold_db: float = -60.0) -> bool:
|
| 39 |
-
"""Check if audio is effectively silent."""
|
| 40 |
-
if audio.size == 0:
|
| 41 |
-
return True
|
| 42 |
-
if audio.ndim == 2:
|
| 43 |
-
audio = audio.mean(axis=1)
|
| 44 |
-
rms = float(np.sqrt(np.mean(audio**2)))
|
| 45 |
-
return 20.0 * np.log10(max(rms, 1e-12)) < threshold_db
|
| 46 |
-
|
| 47 |
-
def _has_energy(audio: np.ndarray, threshold_db: float = -40.0) -> bool:
|
| 48 |
-
"""Check if audio has significant energy (stricter than just non-silent)."""
|
| 49 |
-
return not _is_silent(audio, threshold_db)
|
| 50 |
-
|
| 51 |
# -----------------------------
|
| 52 |
# Data classes
|
| 53 |
# -----------------------------
|
|
@@ -66,7 +55,7 @@ class JamParams:
|
|
| 66 |
guidance_weight: float = 1.1
|
| 67 |
temperature: float = 1.1
|
| 68 |
topk: int = 40
|
| 69 |
-
style_ramp_seconds: float = 8.0
|
| 70 |
|
| 71 |
|
| 72 |
@dataclass
|
|
@@ -121,6 +110,8 @@ class JamWorker(threading.Thread):
|
|
| 121 |
self.mrt.temperature = float(self.params.temperature)
|
| 122 |
self.mrt.topk = int(self.params.topk)
|
| 123 |
|
|
|
|
|
|
|
| 124 |
# codec/setup
|
| 125 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
| 126 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
|
@@ -146,9 +137,8 @@ class JamWorker(threading.Thread):
|
|
| 146 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
| 147 |
self._spool_written = 0 # absolute frames written into spool
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
self.
|
| 151 |
-
self._last_good_context_tokens = None # backup of last known good context
|
| 152 |
|
| 153 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
| 154 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
|
@@ -173,47 +163,6 @@ class JamWorker(threading.Thread):
|
|
| 173 |
# Prepare initial context from combined loop (best musical alignment)
|
| 174 |
if self.params.combined_loop is not None:
|
| 175 |
self._install_context_from_loop(self.params.combined_loop)
|
| 176 |
-
# Save this as our "good" context backup
|
| 177 |
-
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
| 178 |
-
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
| 179 |
-
|
| 180 |
-
# ---------- NEW: Health monitoring methods ----------
|
| 181 |
-
|
| 182 |
-
def _check_model_health(self, new_chunk: np.ndarray) -> bool:
|
| 183 |
-
"""Check if the model output looks healthy."""
|
| 184 |
-
if _is_silent(new_chunk, threshold_db=-80.0):
|
| 185 |
-
self._silence_streak += 1
|
| 186 |
-
print(f"⚠️ Silent chunk detected (streak: {self._silence_streak})")
|
| 187 |
-
return False
|
| 188 |
-
else:
|
| 189 |
-
if self._silence_streak > 0:
|
| 190 |
-
print(f"✅ Audio resumed after {self._silence_streak} silent chunks")
|
| 191 |
-
self._silence_streak = 0
|
| 192 |
-
return True
|
| 193 |
-
|
| 194 |
-
def _recover_from_silence(self):
|
| 195 |
-
"""Attempt to recover from silence by restoring last good context."""
|
| 196 |
-
print("🔧 Attempting recovery from silence...")
|
| 197 |
-
|
| 198 |
-
if self._last_good_context_tokens is not None:
|
| 199 |
-
# Restore last known good context
|
| 200 |
-
try:
|
| 201 |
-
new_state = self.mrt.init_state()
|
| 202 |
-
new_state.context_tokens = np.copy(self._last_good_context_tokens)
|
| 203 |
-
self.state = new_state
|
| 204 |
-
self._model_stream = None # Reset stream to start fresh
|
| 205 |
-
print(" Restored last good context")
|
| 206 |
-
except Exception as e:
|
| 207 |
-
print(f" Context restoration failed: {e}")
|
| 208 |
-
|
| 209 |
-
# If we have the original loop, rebuild context from it
|
| 210 |
-
elif self.params.combined_loop is not None:
|
| 211 |
-
try:
|
| 212 |
-
self._install_context_from_loop(self.params.combined_loop)
|
| 213 |
-
self._model_stream = None
|
| 214 |
-
print(" Rebuilt context from original loop")
|
| 215 |
-
except Exception as e:
|
| 216 |
-
print(f" Context rebuild failed: {e}")
|
| 217 |
|
| 218 |
# ---------- lifecycle ----------
|
| 219 |
|
|
@@ -299,7 +248,13 @@ class JamWorker(threading.Thread):
|
|
| 299 |
return toks
|
| 300 |
|
| 301 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 302 |
-
"""Build *exactly* context_length_frames worth of tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
wav = loop.as_stereo().resample(self._model_sr)
|
| 304 |
data = wav.samples.astype(np.float32, copy=False)
|
| 305 |
if data.ndim == 1:
|
|
@@ -334,14 +289,8 @@ class JamWorker(threading.Thread):
|
|
| 334 |
|
| 335 |
# final snap to *exact* ctx samples
|
| 336 |
if ctx.shape[0] < ctx_samps:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
if ctx.shape[0] > 0:
|
| 340 |
-
fill = np.tile(ctx, (int(np.ceil(shortfall / ctx.shape[0])) + 1, 1))[:shortfall]
|
| 341 |
-
ctx = np.concatenate([fill, ctx], axis=0)
|
| 342 |
-
else:
|
| 343 |
-
print("⚠️ Zero-length context, using fallback")
|
| 344 |
-
ctx = np.zeros((ctx_samps, 2), dtype=np.float32)
|
| 345 |
elif ctx.shape[0] > ctx_samps:
|
| 346 |
ctx = ctx[-ctx_samps:]
|
| 347 |
|
|
@@ -352,20 +301,79 @@ class JamWorker(threading.Thread):
|
|
| 352 |
|
| 353 |
# Force expected (F,D) at *return time*
|
| 354 |
tokens = self._coerce_tokens(tokens)
|
| 355 |
-
|
| 356 |
-
# Validate that we don't have a silent context
|
| 357 |
-
if _is_silent(ctx, threshold_db=-80.0):
|
| 358 |
-
print("⚠️ Generated silent context - this may cause issues")
|
| 359 |
-
|
| 360 |
return tokens
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
def _install_context_from_loop(self, loop: au.Waveform):
|
| 363 |
# Build exact-length, bar-locked context tokens
|
| 364 |
context_tokens = self._encode_exact_context_tokens(loop)
|
| 365 |
s = self.mrt.init_state()
|
| 366 |
s.context_tokens = context_tokens
|
| 367 |
self.state = s
|
| 368 |
-
self.
|
| 369 |
|
| 370 |
def reseed_from_waveform(self, wav: au.Waveform):
|
| 371 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
|
@@ -375,11 +383,14 @@ class JamWorker(threading.Thread):
|
|
| 375 |
s.context_tokens = context_tokens
|
| 376 |
self.state = s
|
| 377 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
| 378 |
-
self.
|
| 379 |
-
self._silence_streak = 0 # Reset health monitoring
|
| 380 |
|
| 381 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 382 |
-
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
| 384 |
F, D = self._expected_token_shape()
|
| 385 |
|
|
@@ -408,20 +419,44 @@ class JamWorker(threading.Thread):
|
|
| 408 |
"tokens": spliced,
|
| 409 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
| 410 |
}
|
|
|
|
| 411 |
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
| 415 |
"""
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
| 423 |
"""
|
| 424 |
-
|
|
|
|
| 425 |
s = wav.samples.astype(np.float32, copy=False)
|
| 426 |
if s.ndim == 1:
|
| 427 |
s = s[:, None]
|
|
@@ -429,103 +464,119 @@ class JamWorker(threading.Thread):
|
|
| 429 |
if n_samps == 0:
|
| 430 |
return
|
| 431 |
|
| 432 |
-
#
|
| 433 |
-
is_healthy = self._check_model_health(s)
|
| 434 |
-
is_very_quiet = _is_silent(s, threshold_db=-50.0) # stricter than default -60
|
| 435 |
-
|
| 436 |
-
# Get crossfade params
|
| 437 |
try:
|
| 438 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 439 |
except Exception:
|
| 440 |
xfade_s = 0.0
|
| 441 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
| 442 |
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
# --- REJECT PROBLEMATIC CHUNKS ---
|
| 446 |
-
if not is_healthy or is_very_quiet:
|
| 447 |
-
print(f"[REJECT] Discarding unhealthy/quiet chunk - not adding to spool or model stream")
|
| 448 |
-
|
| 449 |
-
# Trigger recovery immediately on first bad chunk
|
| 450 |
-
if self._silence_streak >= 1:
|
| 451 |
-
self._recover_from_silence()
|
| 452 |
-
|
| 453 |
-
# Don't process this chunk at all - return early
|
| 454 |
-
return
|
| 455 |
-
|
| 456 |
-
# Reset silence streak on good chunk
|
| 457 |
-
if self._silence_streak > 0:
|
| 458 |
-
print(f"✅ Audio resumed after {self._silence_streak} rejected chunks")
|
| 459 |
-
self._silence_streak = 0
|
| 460 |
-
|
| 461 |
-
# Helper: resample to target SR
|
| 462 |
def to_target(y: np.ndarray) -> np.ndarray:
|
| 463 |
return y if self._rs is None else self._rs.process(y, final=False)
|
| 464 |
|
| 465 |
-
#
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
else:
|
| 482 |
-
#
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
if xfade_n > 0 and n_samps >= xfade_n:
|
| 507 |
-
|
| 508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
else:
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
target_audio = to_target(new_audio)
|
| 515 |
-
if target_audio.shape[0] > 0:
|
| 516 |
-
print(f"[append] body len={target_audio.shape[0]} rms={_dbg_rms_dbfs(target_audio):+.1f} dBFS")
|
| 517 |
-
self._spool = np.concatenate([self._spool, target_audio], axis=0) if self._spool.size else target_audio
|
| 518 |
-
self._spool_written += target_audio.shape[0]
|
| 519 |
-
|
| 520 |
-
# --- SAVE GOOD CONTEXT ---
|
| 521 |
-
# Only save context from healthy chunks
|
| 522 |
-
if hasattr(self.state, 'context_tokens') and self.state.context_tokens is not None:
|
| 523 |
-
self._last_good_context_tokens = np.copy(self.state.context_tokens)
|
| 524 |
-
|
| 525 |
-
# Trim model stream to reasonable length (keep ~30 seconds)
|
| 526 |
-
max_model_samples = int(30.0 * self._model_sr)
|
| 527 |
-
if self._model_stream.shape[0] > max_model_samples:
|
| 528 |
-
self._model_stream = self._model_stream[-max_model_samples:]
|
| 529 |
|
| 530 |
def _should_generate_next_chunk(self) -> bool:
|
| 531 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
|
@@ -562,7 +613,6 @@ class JamWorker(threading.Thread):
|
|
| 562 |
"guidance_weight": float(self.params.guidance_weight),
|
| 563 |
"temperature": float(self.params.temperature),
|
| 564 |
"topk": int(self.params.topk),
|
| 565 |
-
"silence_streak": self._silence_streak, # Add health info
|
| 566 |
}
|
| 567 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 568 |
|
|
@@ -587,7 +637,6 @@ class JamWorker(threading.Thread):
|
|
| 587 |
# inplace update (no reset)
|
| 588 |
self.state.context_tokens = spliced
|
| 589 |
self._pending_token_splice = None
|
| 590 |
-
print("[reseed] Token splice applied")
|
| 591 |
except Exception:
|
| 592 |
# fallback: full reseed using spliced tokens
|
| 593 |
new_state = self.mrt.init_state()
|
|
@@ -595,7 +644,6 @@ class JamWorker(threading.Thread):
|
|
| 595 |
self.state = new_state
|
| 596 |
self._model_stream = None
|
| 597 |
self._pending_token_splice = None
|
| 598 |
-
print("[reseed] Token splice fallback to full reset")
|
| 599 |
elif self._pending_reseed is not None:
|
| 600 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 601 |
new_state = self.mrt.init_state()
|
|
@@ -603,7 +651,6 @@ class JamWorker(threading.Thread):
|
|
| 603 |
self.state = new_state
|
| 604 |
self._model_stream = None
|
| 605 |
self._pending_reseed = None
|
| 606 |
-
print("[reseed] Full reseed applied")
|
| 607 |
|
| 608 |
# ---------- main loop ----------
|
| 609 |
|
|
@@ -640,10 +687,9 @@ class JamWorker(threading.Thread):
|
|
| 640 |
self._emit_ready()
|
| 641 |
|
| 642 |
# finalize resampler (flush) — not strictly necessary here
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
self._spool_written += tail.shape[0]
|
| 648 |
# one last emit attempt
|
| 649 |
-
self._emit_ready()
|
|
|
|
| 1 |
+
# jam_worker.py - Bar-locked spool rewrite
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import os
|
|
|
|
| 20 |
)
|
| 21 |
|
| 22 |
def _dbg_rms_dbfs(x: np.ndarray) -> float:
|
| 23 |
+
|
| 24 |
if x.ndim == 2:
|
| 25 |
x = x.mean(axis=1)
|
| 26 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
|
| 28 |
|
| 29 |
def _dbg_rms_dbfs_model(x: np.ndarray) -> float:
|
| 30 |
# x is model-rate, shape [S,C] or [S]
|
| 31 |
+
|
| 32 |
if x.ndim == 2:
|
| 33 |
x = x.mean(axis=1)
|
| 34 |
r = float(np.sqrt(np.mean(x * x) + 1e-12))
|
|
|
|
| 37 |
def _dbg_shape(x):
|
| 38 |
return tuple(x.shape) if hasattr(x, "shape") else ("-",)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# -----------------------------
|
| 41 |
# Data classes
|
| 42 |
# -----------------------------
|
|
|
|
| 55 |
guidance_weight: float = 1.1
|
| 56 |
temperature: float = 1.1
|
| 57 |
topk: int = 40
|
| 58 |
+
style_ramp_seconds: float = 8.0 # 0 => instant (current behavior), try 6.0–10.0 for gentle glides
|
| 59 |
|
| 60 |
|
| 61 |
@dataclass
|
|
|
|
| 110 |
self.mrt.temperature = float(self.params.temperature)
|
| 111 |
self.mrt.topk = int(self.params.topk)
|
| 112 |
|
| 113 |
+
|
| 114 |
+
|
| 115 |
# codec/setup
|
| 116 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
| 117 |
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
|
|
|
| 137 |
self._spool = np.zeros((0, 2), dtype=np.float32) # (S,2) target SR
|
| 138 |
self._spool_written = 0 # absolute frames written into spool
|
| 139 |
|
| 140 |
+
self._pending_tail_model = None # type: Optional[np.ndarray] # last tail at model SR
|
| 141 |
+
self._pending_tail_target_len = 0 # number of target-SR samples last tail contributed
|
|
|
|
| 142 |
|
| 143 |
# bar clock: start with offset 0; if you have a downbeat estimator, set base later
|
| 144 |
self._bar_clock = BarClock(self.params.target_sr, self.params.bpm, self.params.beats_per_bar, base_offset_samples=0)
|
|
|
|
| 163 |
# Prepare initial context from combined loop (best musical alignment)
|
| 164 |
if self.params.combined_loop is not None:
|
| 165 |
self._install_context_from_loop(self.params.combined_loop)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# ---------- lifecycle ----------
|
| 168 |
|
|
|
|
| 248 |
return toks
|
| 249 |
|
| 250 |
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 251 |
+
"""Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
|
| 252 |
+
while ensuring the *end* of the audio lands on a bar boundary.
|
| 253 |
+
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
| 254 |
+
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
| 255 |
+
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
| 256 |
+
tokens to the expected frame count.
|
| 257 |
+
"""
|
| 258 |
wav = loop.as_stereo().resample(self._model_sr)
|
| 259 |
data = wav.samples.astype(np.float32, copy=False)
|
| 260 |
if data.ndim == 1:
|
|
|
|
| 289 |
|
| 290 |
# final snap to *exact* ctx samples
|
| 291 |
if ctx.shape[0] < ctx_samps:
|
| 292 |
+
pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
|
| 293 |
+
ctx = np.concatenate([pad, ctx], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
elif ctx.shape[0] > ctx_samps:
|
| 295 |
ctx = ctx[-ctx_samps:]
|
| 296 |
|
|
|
|
| 301 |
|
| 302 |
# Force expected (F,D) at *return time*
|
| 303 |
tokens = self._coerce_tokens(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
return tokens
|
| 305 |
|
| 306 |
+
def _encode_exact_context_tokens(self, loop: au.Waveform) -> np.ndarray:
|
| 307 |
+
"""Build *exactly* context_length_frames worth of tokens (e.g., 250 @ 25fps),
|
| 308 |
+
while ensuring the *end* of the audio lands on a bar boundary.
|
| 309 |
+
Strategy: take the largest integer number of bars <= ctx_seconds as the tail,
|
| 310 |
+
then left-fill from just before that tail (wrapping if needed) to reach exactly
|
| 311 |
+
ctx_seconds; finally, pad/trim to exact samples and, as a last resort, pad/trim
|
| 312 |
+
tokens to the expected frame count.
|
| 313 |
+
"""
|
| 314 |
+
wav = loop.as_stereo().resample(self._model_sr)
|
| 315 |
+
data = wav.samples.astype(np.float32, copy=False)
|
| 316 |
+
if data.ndim == 1:
|
| 317 |
+
data = data[:, None]
|
| 318 |
+
|
| 319 |
+
spb = self._bar_clock.seconds_per_bar()
|
| 320 |
+
ctx_sec = float(self._ctx_seconds)
|
| 321 |
+
sr = int(self._model_sr)
|
| 322 |
+
|
| 323 |
+
# bars that fit fully inside ctx_sec (at least 1)
|
| 324 |
+
bars_fit = max(1, int(ctx_sec // spb))
|
| 325 |
+
tail_len_samps = int(round(bars_fit * spb * sr))
|
| 326 |
+
|
| 327 |
+
# ensure we have enough source by tiling
|
| 328 |
+
need = int(round(ctx_sec * sr)) + tail_len_samps
|
| 329 |
+
if data.shape[0] == 0:
|
| 330 |
+
data = np.zeros((1, 2), dtype=np.float32)
|
| 331 |
+
reps = int(np.ceil(need / float(data.shape[0])))
|
| 332 |
+
tiled = np.tile(data, (reps, 1))
|
| 333 |
+
|
| 334 |
+
end = tiled.shape[0]
|
| 335 |
+
tail = tiled[end - tail_len_samps:end]
|
| 336 |
+
|
| 337 |
+
# left-fill to reach exact ctx samples (keeps end-of-bar alignment)
|
| 338 |
+
ctx_samps = int(round(ctx_sec * sr))
|
| 339 |
+
pad_len = ctx_samps - tail.shape[0]
|
| 340 |
+
if pad_len > 0:
|
| 341 |
+
pre = tiled[end - tail_len_samps - pad_len:end - tail_len_samps]
|
| 342 |
+
ctx = np.concatenate([pre, tail], axis=0)
|
| 343 |
+
else:
|
| 344 |
+
ctx = tail[-ctx_samps:]
|
| 345 |
+
|
| 346 |
+
# final snap to *exact* ctx samples
|
| 347 |
+
if ctx.shape[0] < ctx_samps:
|
| 348 |
+
pad = np.zeros((ctx_samps - ctx.shape[0], ctx.shape[1]), dtype=np.float32)
|
| 349 |
+
ctx = np.concatenate([pad, ctx], axis=0)
|
| 350 |
+
elif ctx.shape[0] > ctx_samps:
|
| 351 |
+
ctx = ctx[-ctx_samps:]
|
| 352 |
+
|
| 353 |
+
exact = au.Waveform(ctx, sr)
|
| 354 |
+
tokens_full = self.mrt.codec.encode(exact).astype(np.int32)
|
| 355 |
+
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
| 356 |
+
tokens = tokens_full[:, :depth]
|
| 357 |
+
|
| 358 |
+
# Last defense: force expected frame count
|
| 359 |
+
frames = tokens.shape[0]
|
| 360 |
+
exp = int(self._ctx_frames)
|
| 361 |
+
if frames < exp:
|
| 362 |
+
# repeat last frame
|
| 363 |
+
pad = np.repeat(tokens[-1:, :], exp - frames, axis=0)
|
| 364 |
+
tokens = np.concatenate([pad, tokens], axis=0)
|
| 365 |
+
elif frames > exp:
|
| 366 |
+
tokens = tokens[-exp:, :]
|
| 367 |
+
return tokens
|
| 368 |
+
|
| 369 |
+
|
| 370 |
def _install_context_from_loop(self, loop: au.Waveform):
|
| 371 |
# Build exact-length, bar-locked context tokens
|
| 372 |
context_tokens = self._encode_exact_context_tokens(loop)
|
| 373 |
s = self.mrt.init_state()
|
| 374 |
s.context_tokens = context_tokens
|
| 375 |
self.state = s
|
| 376 |
+
self._original_context_tokens = np.copy(context_tokens)
|
| 377 |
|
| 378 |
def reseed_from_waveform(self, wav: au.Waveform):
|
| 379 |
"""Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
|
|
|
|
| 383 |
s.context_tokens = context_tokens
|
| 384 |
self.state = s
|
| 385 |
self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
|
| 386 |
+
self._original_context_tokens = np.copy(context_tokens)
|
|
|
|
| 387 |
|
| 388 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 389 |
+
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
| 390 |
+
We compute a fresh, bar-locked context token tensor of exact length
|
| 391 |
+
(e.g., 250 frames), then splice only the *tail* corresponding to
|
| 392 |
+
`anchor_bars` so generation continues smoothly without resetting state.
|
| 393 |
+
"""
|
| 394 |
new_ctx = self._encode_exact_context_tokens(recent_wav) # coerce to (F,D)
|
| 395 |
F, D = self._expected_token_shape()
|
| 396 |
|
|
|
|
| 419 |
"tokens": spliced,
|
| 420 |
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
| 421 |
}
|
| 422 |
+
|
| 423 |
|
| 424 |
+
|
| 425 |
+
def reseed_from_waveform(self, wav: au.Waveform):
|
| 426 |
+
"""Immediate reseed: replace context from provided wave (bar-aligned tail)."""
|
| 427 |
+
wav = wav.as_stereo().resample(self._model_sr)
|
| 428 |
+
tail = take_bar_aligned_tail(wav, self.params.bpm, self.params.beats_per_bar, self._ctx_seconds)
|
| 429 |
+
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 430 |
+
depth = int(self.mrt.config.decoder_codec_rvq_depth)
|
| 431 |
+
context_tokens = tokens_full[:, :depth]
|
| 432 |
+
|
| 433 |
+
s = self.mrt.init_state()
|
| 434 |
+
s.context_tokens = context_tokens
|
| 435 |
+
self.state = s
|
| 436 |
+
# reset model stream so next generate starts cleanly
|
| 437 |
+
self._model_stream = None
|
| 438 |
+
|
| 439 |
+
# optional loudness match will be applied per-chunk on emission
|
| 440 |
+
|
| 441 |
+
# also remember this as new "original"
|
| 442 |
+
self._original_context_tokens = np.copy(context_tokens)
|
| 443 |
+
|
| 444 |
+
# ---------- core streaming helpers ----------
|
| 445 |
|
| 446 |
def _append_model_chunk_and_spool(self, wav: au.Waveform) -> None:
|
| 447 |
"""
|
| 448 |
+
Conservative boundary fix:
|
| 449 |
+
- Emit body+tail immediately (target SR), unchanged from your original behavior.
|
| 450 |
+
- On *next* call, compute the mixed overlap (prev tail ⨉ cos + new head ⨉ sin),
|
| 451 |
+
resample it, and overwrite the last `_pending_tail_target_len` samples in the
|
| 452 |
+
target-SR spool with that mixed overlap. Then emit THIS chunk's body+tail and
|
| 453 |
+
remember THIS chunk's tail length at target SR for the next correction.
|
| 454 |
+
|
| 455 |
+
This keeps external timing and bar alignment identical, but removes the audible
|
| 456 |
+
fade-to-zero at chunk ends.
|
| 457 |
"""
|
| 458 |
+
|
| 459 |
+
# ---- unpack model-rate samples ----
|
| 460 |
s = wav.samples.astype(np.float32, copy=False)
|
| 461 |
if s.ndim == 1:
|
| 462 |
s = s[:, None]
|
|
|
|
| 464 |
if n_samps == 0:
|
| 465 |
return
|
| 466 |
|
| 467 |
+
# crossfade length in model samples
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
try:
|
| 469 |
xfade_s = float(self.mrt.config.crossfade_length)
|
| 470 |
except Exception:
|
| 471 |
xfade_s = 0.0
|
| 472 |
xfade_n = int(round(max(0.0, xfade_s) * float(self._model_sr)))
|
| 473 |
|
| 474 |
+
# helper: resample to target SR via your streaming resampler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
def to_target(y: np.ndarray) -> np.ndarray:
|
| 476 |
return y if self._rs is None else self._rs.process(y, final=False)
|
| 477 |
|
| 478 |
+
# ------------------------------------------
|
| 479 |
+
# (A) If we have a pending model tail, fix the last emitted tail at target SR
|
| 480 |
+
# ------------------------------------------
|
| 481 |
+
if self._pending_tail_model is not None and self._pending_tail_model.shape[0] == xfade_n and xfade_n > 0 and n_samps >= xfade_n:
|
| 482 |
+
head = s[:xfade_n, :]
|
| 483 |
+
|
| 484 |
+
print(f"[model] head len={head.shape[0]} rms={_dbg_rms_dbfs_model(head):+.1f} dBFS")
|
| 485 |
+
|
| 486 |
+
t = np.linspace(0.0, np.pi/2.0, xfade_n, endpoint=False, dtype=np.float32)[:, None]
|
| 487 |
+
cosw = np.cos(t, dtype=np.float32)
|
| 488 |
+
sinw = np.sin(t, dtype=np.float32)
|
| 489 |
+
mixed_model = (self._pending_tail_model * cosw) + (head * sinw) # [xfade_n, C] at model SR
|
| 490 |
+
|
| 491 |
+
y_mixed = to_target(mixed_model.astype(np.float32))
|
| 492 |
+
Lcorr = int(y_mixed.shape[0]) # exact target-SR samples to write
|
| 493 |
+
|
| 494 |
+
# DEBUG: corrected overlap RMS (what we intend to hear at the boundary)
|
| 495 |
+
if y_mixed.size:
|
| 496 |
+
print(f"[append] mixedOverlap len={y_mixed.shape[0]} rms={_dbg_rms_dbfs(y_mixed):+.1f} dBFS")
|
| 497 |
+
|
| 498 |
+
# Overwrite the last `_pending_tail_target_len` samples of the spool with `y_mixed`.
|
| 499 |
+
# Use the *smaller* of the two lengths to be safe.
|
| 500 |
+
Lpop = min(self._pending_tail_target_len, self._spool.shape[0], Lcorr)
|
| 501 |
+
if Lpop > 0 and self._spool.size:
|
| 502 |
+
# Trim last Lpop samples
|
| 503 |
+
self._spool = self._spool[:-Lpop, :]
|
| 504 |
+
self._spool_written -= Lpop
|
| 505 |
+
# Append corrected overlap (trim/pad to Lpop to avoid drift)
|
| 506 |
+
if Lcorr != Lpop:
|
| 507 |
+
if Lcorr > Lpop:
|
| 508 |
+
y_m = y_mixed[-Lpop:, :]
|
| 509 |
+
else:
|
| 510 |
+
pad = np.zeros((Lpop - Lcorr, y_mixed.shape[1]), dtype=np.float32)
|
| 511 |
+
y_m = np.concatenate([y_mixed, pad], axis=0)
|
| 512 |
+
else:
|
| 513 |
+
y_m = y_mixed
|
| 514 |
+
self._spool = np.concatenate([self._spool, y_m], axis=0) if self._spool.size else y_m
|
| 515 |
+
self._spool_written += y_m.shape[0]
|
| 516 |
+
|
| 517 |
+
# For internal continuity, update _model_stream like before
|
| 518 |
+
if self._model_stream is None or self._model_stream.shape[0] < xfade_n:
|
| 519 |
+
self._model_stream = s[xfade_n:].copy()
|
| 520 |
+
else:
|
| 521 |
+
self._model_stream = np.concatenate([self._model_stream[:-xfade_n], mixed_model, s[xfade_n:]], axis=0)
|
| 522 |
else:
|
| 523 |
+
# First-ever call or too-short to mix: maintain _model_stream minimally
|
| 524 |
+
if xfade_n > 0 and n_samps > xfade_n:
|
| 525 |
+
self._model_stream = s[xfade_n:].copy() if self._model_stream is None else np.concatenate([self._model_stream, s[xfade_n:]], axis=0)
|
| 526 |
+
else:
|
| 527 |
+
self._model_stream = s.copy() if self._model_stream is None else np.concatenate([self._model_stream, s], axis=0)
|
| 528 |
+
|
| 529 |
+
# ------------------------------------------
|
| 530 |
+
# (B) Emit THIS chunk's body and tail (same external behavior)
|
| 531 |
+
# ------------------------------------------
|
| 532 |
+
if xfade_n > 0 and n_samps >= (2 * xfade_n):
|
| 533 |
+
body = s[xfade_n:-xfade_n, :]
|
| 534 |
+
print(f"[model] body len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
|
| 535 |
+
if body.size:
|
| 536 |
+
y_body = to_target(body.astype(np.float32))
|
| 537 |
+
if y_body.size:
|
| 538 |
+
# DEBUG: body RMS we are actually appending
|
| 539 |
+
print(f"[append] body len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
| 540 |
+
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 541 |
+
self._spool_written += y_body.shape[0]
|
| 542 |
+
else:
|
| 543 |
+
# If chunk too short for head+tail split, treat all (minus preroll) as body
|
| 544 |
+
if xfade_n > 0 and n_samps > xfade_n:
|
| 545 |
+
body = s[xfade_n:, :]
|
| 546 |
+
print(f"[model] body(S) len={body.shape[0]} rms={_dbg_rms_dbfs_model(body):+.1f} dBFS")
|
| 547 |
+
y_body = to_target(body.astype(np.float32))
|
| 548 |
+
if y_body.size:
|
| 549 |
+
# DEBUG: body RMS in short-chunk path
|
| 550 |
+
print(f"[append] body(len=short) len={y_body.shape[0]} rms={_dbg_rms_dbfs(y_body):+.1f} dBFS")
|
| 551 |
+
self._spool = np.concatenate([self._spool, y_body], axis=0) if self._spool.size else y_body
|
| 552 |
+
self._spool_written += y_body.shape[0]
|
| 553 |
+
# No tail to remember this round
|
| 554 |
+
self._pending_tail_model = None
|
| 555 |
+
self._pending_tail_target_len = 0
|
| 556 |
+
return
|
| 557 |
+
|
| 558 |
+
# Tail (always remember how many TARGET samples we append)
|
| 559 |
if xfade_n > 0 and n_samps >= xfade_n:
|
| 560 |
+
tail = s[-xfade_n:, :]
|
| 561 |
+
print(f"[model] tail len={tail.shape[0]} rms={_dbg_rms_dbfs_model(tail):+.1f} dBFS")
|
| 562 |
+
y_tail = to_target(tail.astype(np.float32))
|
| 563 |
+
Ltail = int(y_tail.shape[0])
|
| 564 |
+
if Ltail:
|
| 565 |
+
# DEBUG: tail RMS we are appending now (to be corrected next call)
|
| 566 |
+
print(f"[append] tail len={y_tail.shape[0]} rms={_dbg_rms_dbfs(y_tail):+.1f} dBFS")
|
| 567 |
+
self._spool = np.concatenate([self._spool, y_tail], axis=0) if self._spool.size else y_tail
|
| 568 |
+
self._spool_written += Ltail
|
| 569 |
+
self._pending_tail_model = tail.copy()
|
| 570 |
+
self._pending_tail_target_len = Ltail
|
| 571 |
+
else:
|
| 572 |
+
# Nothing appended (resampler returned nothing yet) — keep model tail but mark zero target len
|
| 573 |
+
self._pending_tail_model = tail.copy()
|
| 574 |
+
self._pending_tail_target_len = 0
|
| 575 |
else:
|
| 576 |
+
self._pending_tail_model = None
|
| 577 |
+
self._pending_tail_target_len = 0
|
| 578 |
+
|
| 579 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
|
| 581 |
def _should_generate_next_chunk(self) -> bool:
|
| 582 |
# Allow running ahead relative to whichever is larger: last *consumed*
|
|
|
|
| 613 |
"guidance_weight": float(self.params.guidance_weight),
|
| 614 |
"temperature": float(self.params.temperature),
|
| 615 |
"topk": int(self.params.topk),
|
|
|
|
| 616 |
}
|
| 617 |
chunk = JamChunk(index=self.idx, audio_base64=audio_b64, metadata=meta)
|
| 618 |
|
|
|
|
| 637 |
# inplace update (no reset)
|
| 638 |
self.state.context_tokens = spliced
|
| 639 |
self._pending_token_splice = None
|
|
|
|
| 640 |
except Exception:
|
| 641 |
# fallback: full reseed using spliced tokens
|
| 642 |
new_state = self.mrt.init_state()
|
|
|
|
| 644 |
self.state = new_state
|
| 645 |
self._model_stream = None
|
| 646 |
self._pending_token_splice = None
|
|
|
|
| 647 |
elif self._pending_reseed is not None:
|
| 648 |
ctx = self._coerce_tokens(self._pending_reseed["ctx"])
|
| 649 |
new_state = self.mrt.init_state()
|
|
|
|
| 651 |
self.state = new_state
|
| 652 |
self._model_stream = None
|
| 653 |
self._pending_reseed = None
|
|
|
|
| 654 |
|
| 655 |
# ---------- main loop ----------
|
| 656 |
|
|
|
|
| 687 |
self._emit_ready()
|
| 688 |
|
| 689 |
# finalize resampler (flush) — not strictly necessary here
|
| 690 |
+
tail = self._rs.process(np.zeros((0,2), np.float32), final=True)
|
| 691 |
+
if tail.size:
|
| 692 |
+
self._spool = np.concatenate([self._spool, tail], axis=0)
|
| 693 |
+
self._spool_written += tail.shape[0]
|
|
|
|
| 694 |
# one last emit attempt
|
| 695 |
+
self._emit_ready()
|