Commit
·
946680d
1
Parent(s):
2d4dad3
reseed splice fix
Browse files- jam_worker.py +63 -6
jam_worker.py
CHANGED
|
@@ -71,6 +71,7 @@ class BarClock:
|
|
| 71 |
# -----------------------------
|
| 72 |
|
| 73 |
class JamWorker(threading.Thread):
|
|
|
|
| 74 |
"""Generates continuous audio with MagentaRT, spools it at target SR,
|
| 75 |
and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
|
| 76 |
|
|
@@ -93,6 +94,7 @@ class JamWorker(threading.Thread):
|
|
| 93 |
|
| 94 |
# codec/setup
|
| 95 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
|
|
|
| 96 |
self._ctx_frames = int(self.mrt.config.context_length_frames)
|
| 97 |
self._ctx_seconds = self._ctx_frames / self._codec_fps
|
| 98 |
|
|
@@ -121,8 +123,9 @@ class JamWorker(threading.Thread):
|
|
| 121 |
self._stop_event = threading.Event()
|
| 122 |
self._max_buffer_ahead = 5
|
| 123 |
|
| 124 |
-
# reseed
|
| 125 |
-
self._pending_reseed: Optional[dict] = None
|
|
|
|
| 126 |
|
| 127 |
# Prepare initial context from combined loop (best musical alignment)
|
| 128 |
if self.params.combined_loop is not None:
|
|
@@ -254,10 +257,49 @@ class JamWorker(threading.Thread):
|
|
| 254 |
self._original_context_tokens = np.copy(context_tokens)
|
| 255 |
|
| 256 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 257 |
-
"""Queue a
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
with self._lock:
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
|
| 263 |
def reseed_from_waveform(self, wav: au.Waveform):
|
|
@@ -376,7 +418,22 @@ class JamWorker(threading.Thread):
|
|
| 376 |
|
| 377 |
# If a reseed is queued, install it *right after* we finish a chunk
|
| 378 |
with self._lock:
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
new_state = self.mrt.init_state()
|
| 381 |
new_state.context_tokens = self._pending_reseed["ctx"]
|
| 382 |
self.state = new_state
|
|
|
|
| 71 |
# -----------------------------
|
| 72 |
|
| 73 |
class JamWorker(threading.Thread):
|
| 74 |
+
FRAMES_PER_SECOND: float | None = None # filled in __init__ once codec is available
|
| 75 |
"""Generates continuous audio with MagentaRT, spools it at target SR,
|
| 76 |
and emits *sample-accurate*, bar-aligned chunks (no FPS drift)."""
|
| 77 |
|
|
|
|
| 94 |
|
| 95 |
# codec/setup
|
| 96 |
self._codec_fps = float(self.mrt.codec.frame_rate)
|
| 97 |
+
JamWorker.FRAMES_PER_SECOND = self._codec_fps
|
| 98 |
self._ctx_frames = int(self.mrt.config.context_length_frames)
|
| 99 |
self._ctx_seconds = self._ctx_frames / self._codec_fps
|
| 100 |
|
|
|
|
| 123 |
self._stop_event = threading.Event()
|
| 124 |
self._max_buffer_ahead = 5
|
| 125 |
|
| 126 |
+
# reseed queues (install at next bar boundary after emission)
|
| 127 |
+
self._pending_reseed: Optional[dict] = None # legacy full reset path (kept for fallback)
|
| 128 |
+
self._pending_token_splice: Optional[dict] = None # seamless token splice
|
| 129 |
|
| 130 |
# Prepare initial context from combined loop (best musical alignment)
|
| 131 |
if self.params.combined_loop is not None:
|
|
|
|
| 257 |
self._original_context_tokens = np.copy(context_tokens)
|
| 258 |
|
| 259 |
def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
|
| 260 |
+
"""Queue a *seamless* reseed by token splicing instead of full restart.
|
| 261 |
+
We compute a fresh, bar-locked context token tensor of exact length
|
| 262 |
+
(e.g., 250 frames), then splice only the *tail* corresponding to
|
| 263 |
+
`anchor_bars` so generation continues smoothly without resetting state.
|
| 264 |
+
"""
|
| 265 |
+
new_ctx = self._encode_exact_context_tokens(recent_wav) # (F,D)
|
| 266 |
+
F = int(self._ctx_frames)
|
| 267 |
+
D = int(self.mrt.config.decoder_codec_rvq_depth)
|
| 268 |
+
assert new_ctx.shape == (F, D), f"expected {(F, D)}, got {new_ctx.shape}"
|
| 269 |
+
|
| 270 |
+
# how many frames correspond to the requested anchor bars
|
| 271 |
+
spb = self._bar_clock.seconds_per_bar()
|
| 272 |
+
frames_per_bar = int(round(self._codec_fps * spb))
|
| 273 |
+
splice_frames = int(round(max(1, anchor_bars) * frames_per_bar))
|
| 274 |
+
splice_frames = max(1, min(splice_frames, F))
|
| 275 |
+
|
| 276 |
with self._lock:
|
| 277 |
+
# snapshot current context
|
| 278 |
+
cur = getattr(self.state, "context_tokens", None)
|
| 279 |
+
if cur is None:
|
| 280 |
+
# if state has no context yet, fall back to full reseed
|
| 281 |
+
self._pending_reseed = {"ctx": new_ctx}
|
| 282 |
+
return
|
| 283 |
+
if cur.shape != (F, D):
|
| 284 |
+
# safety: coerce by trim/pad
|
| 285 |
+
if cur.shape[0] > F:
|
| 286 |
+
cur = cur[-F:, :]
|
| 287 |
+
elif cur.shape[0] < F:
|
| 288 |
+
pad = np.repeat(cur[0:1, :], F - cur.shape[0], axis=0)
|
| 289 |
+
cur = np.concatenate([pad, cur], axis=0)
|
| 290 |
+
if cur.shape[1] != D:
|
| 291 |
+
cur = cur[:, :D]
|
| 292 |
+
|
| 293 |
+
# build the spliced tensor: keep left (F - splice) from cur, take right (splice) from new
|
| 294 |
+
left = cur[:F - splice_frames, :]
|
| 295 |
+
right = new_ctx[F - splice_frames:, :]
|
| 296 |
+
spliced = np.concatenate([left, right], axis=0)
|
| 297 |
+
|
| 298 |
+
# queue for install at the *next bar boundary* right after emission
|
| 299 |
+
self._pending_token_splice = {
|
| 300 |
+
"tokens": spliced,
|
| 301 |
+
"debug": {"F": F, "D": D, "splice_frames": splice_frames, "frames_per_bar": frames_per_bar}
|
| 302 |
+
}
|
| 303 |
|
| 304 |
|
| 305 |
def reseed_from_waveform(self, wav: au.Waveform):
|
|
|
|
| 418 |
|
| 419 |
# If a reseed is queued, install it *right after* we finish a chunk
|
| 420 |
with self._lock:
|
| 421 |
+
# Prefer seamless token splice when available
|
| 422 |
+
if self._pending_token_splice is not None:
|
| 423 |
+
try:
|
| 424 |
+
spliced = self._pending_token_splice["tokens"]
|
| 425 |
+
self.state.context_tokens = spliced # in-place, no reset
|
| 426 |
+
self._pending_token_splice = None
|
| 427 |
+
# do NOT reset self._model_stream — keep continuity
|
| 428 |
+
# leave params/style as-is
|
| 429 |
+
except Exception as e:
|
| 430 |
+
# fallback: full reseed if setter rejects
|
| 431 |
+
new_state = self.mrt.init_state()
|
| 432 |
+
new_state.context_tokens = spliced
|
| 433 |
+
self.state = new_state
|
| 434 |
+
self._model_stream = None
|
| 435 |
+
self._pending_token_splice = None
|
| 436 |
+
elif self._pending_reseed is not None:
|
| 437 |
new_state = self.mrt.init_state()
|
| 438 |
new_state.context_tokens = self._pending_reseed["ctx"]
|
| 439 |
self.state = new_state
|