Commit
Β·
1b98b73
1
Parent(s):
783cbeb
fixing continuity
Browse files- jam_worker.py +106 -72
- utils.py +4 -2
jam_worker.py
CHANGED
|
@@ -350,88 +350,122 @@ class JamWorker(threading.Thread):
|
|
| 350 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
| 351 |
|
| 352 |
def run(self):
|
| 353 |
-
"""
|
| 354 |
-
sr_model = int(self.mrt.sample_rate)
|
| 355 |
spb = self._seconds_per_bar()
|
| 356 |
-
chunk_secs =
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
#
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
-
print("π JamWorker
|
| 365 |
|
| 366 |
while not self._stop_event.is_set():
|
| 367 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
with self._lock:
|
| 369 |
-
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
| 370 |
-
time.sleep(0.25)
|
| 371 |
-
continue
|
| 372 |
style_vec = self.params.style_vec
|
| 373 |
-
self.mrt.guidance_weight = self.params.guidance_weight
|
| 374 |
-
self.mrt.temperature = self.params.temperature
|
| 375 |
-
self.mrt.topk = self.params.topk
|
|
|
|
| 376 |
|
| 377 |
-
|
| 378 |
self.last_chunk_started_at = time.time()
|
| 379 |
-
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 380 |
-
self._append_model_chunk_to_stream(wav)
|
| 381 |
-
if getattr(self, "_needs_bar_realign", False):
|
| 382 |
-
self._realign_emit_pointer_to_bar(sr_model)
|
| 383 |
-
self._needs_bar_realign = False
|
| 384 |
-
# DEBUG
|
| 385 |
-
bar_samps = int(round(self._seconds_per_bar() * sr_model))
|
| 386 |
-
if bar_samps > 0 and (self._next_emit_start % bar_samps) != 0:
|
| 387 |
-
print(f"β οΈ emit pointer not aligned: phase={self._next_emit_start % bar_samps}")
|
| 388 |
-
else:
|
| 389 |
-
print("β
emit pointer aligned to bar")
|
| 390 |
-
|
| 391 |
-
self.last_chunk_completed_at = time.time()
|
| 392 |
-
|
| 393 |
-
# While we have at least one full 8-bar window available, emit it
|
| 394 |
-
while (getattr(self, "_stream", None) is not None and
|
| 395 |
-
self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
|
| 396 |
-
not self._stop_event.is_set()):
|
| 397 |
-
|
| 398 |
-
seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
|
| 399 |
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
self._next_emit_start += chunk_n_model
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
if keep_from > 0:
|
| 434 |
-
self._stream = self._stream[keep_from:]
|
| 435 |
-
self._next_emit_start -= keep_from
|
| 436 |
|
| 437 |
-
print("π JamWorker
|
|
|
|
| 350 |
self._pending_drop_intro_bars = getattr(self, "_pending_drop_intro_bars", 0) + 1
|
| 351 |
|
| 352 |
def run(self):
|
| 353 |
+
"""Main worker loop - generate chunks continuously but don't get too far ahead"""
|
|
|
|
| 354 |
spb = self._seconds_per_bar()
|
| 355 |
+
chunk_secs = self.params.bars_per_chunk * spb
|
| 356 |
+
xfade = float(self.mrt.config.crossfade_length) # seconds
|
| 357 |
+
|
| 358 |
+
# local fallback stitcher that *keeps* the first head if utils.stitch_generated
|
| 359 |
+
# doesn't yet support drop_first_pre_roll
|
| 360 |
+
def _stitch_keep_head(chunks, sr: int, xfade_s: float):
|
| 361 |
+
from magenta_rt import audio as au
|
| 362 |
+
import numpy as _np
|
| 363 |
+
if not chunks:
|
| 364 |
+
raise ValueError("no chunks to stitch")
|
| 365 |
+
xfade_n = int(round(max(0.0, xfade_s) * sr))
|
| 366 |
+
# Fast-path: no crossfade
|
| 367 |
+
if xfade_n <= 0:
|
| 368 |
+
out = _np.concatenate([c.samples for c in chunks], axis=0)
|
| 369 |
+
return au.Waveform(out, sr)
|
| 370 |
+
# build equal-power curves
|
| 371 |
+
t = _np.linspace(0, _np.pi / 2, xfade_n, endpoint=False, dtype=_np.float32)
|
| 372 |
+
eq_in, eq_out = _np.sin(t)[:, None], _np.cos(t)[:, None]
|
| 373 |
+
|
| 374 |
+
first = chunks[0].samples
|
| 375 |
+
if first.shape[0] < xfade_n:
|
| 376 |
+
raise ValueError("chunk shorter than crossfade prefix")
|
| 377 |
+
out = first.copy() # π keep the head for live seam
|
| 378 |
+
|
| 379 |
+
for i in range(1, len(chunks)):
|
| 380 |
+
cur = chunks[i].samples
|
| 381 |
+
if cur.shape[0] < xfade_n:
|
| 382 |
+
# too short to crossfade; just butt-join
|
| 383 |
+
out = _np.concatenate([out, cur], axis=0)
|
| 384 |
+
continue
|
| 385 |
+
head, tail = cur[:xfade_n], cur[xfade_n:]
|
| 386 |
+
mixed = out[-xfade_n:] * eq_out + head * eq_in
|
| 387 |
+
out = _np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
|
| 388 |
+
return au.Waveform(out, sr)
|
| 389 |
|
| 390 |
+
print("π JamWorker started with flow control...")
|
| 391 |
|
| 392 |
while not self._stop_event.is_set():
|
| 393 |
+
# Donβt get too far ahead of the consumer
|
| 394 |
+
if not self._should_generate_next_chunk():
|
| 395 |
+
# We're ahead enough, wait a bit for frontend to catch up
|
| 396 |
+
# (kept short so stop() stays responsive)
|
| 397 |
+
time.sleep(0.5)
|
| 398 |
+
continue
|
| 399 |
+
|
| 400 |
+
# Snapshot knobs + compute index atomically
|
| 401 |
with self._lock:
|
|
|
|
|
|
|
|
|
|
| 402 |
style_vec = self.params.style_vec
|
| 403 |
+
self.mrt.guidance_weight = float(self.params.guidance_weight)
|
| 404 |
+
self.mrt.temperature = float(self.params.temperature)
|
| 405 |
+
self.mrt.topk = int(self.params.topk)
|
| 406 |
+
next_idx = self.idx + 1
|
| 407 |
|
| 408 |
+
print(f"πΉ Generating chunk {next_idx}...")
|
| 409 |
self.last_chunk_started_at = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
# ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
|
| 412 |
+
# Count the first chunk at full length L, and each subsequent at (L - xfade)
|
| 413 |
+
assembled = 0.0
|
| 414 |
+
chunks = []
|
| 415 |
+
|
| 416 |
+
while assembled < chunk_secs and not self._stop_event.is_set():
|
| 417 |
+
# generate_chunk returns (au.Waveform, new_state)
|
| 418 |
+
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 419 |
+
chunks.append(wav)
|
| 420 |
+
L = wav.samples.shape[0] / float(self.mrt.sample_rate)
|
| 421 |
+
assembled += L if len(chunks) == 1 else max(0.0, L - xfade)
|
| 422 |
+
|
| 423 |
+
if self._stop_event.is_set():
|
| 424 |
+
break
|
| 425 |
+
|
| 426 |
+
# ---- Stitch and trim at model SR (keep first head for seamless handoff) ----
|
| 427 |
+
try:
|
| 428 |
+
# Preferred path if you've added the new param in utils.stitch_generated
|
| 429 |
+
y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
|
| 430 |
+
except TypeError:
|
| 431 |
+
# Backward-compatible: local stitcher that keeps the head
|
| 432 |
+
y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
|
| 433 |
+
|
| 434 |
+
# Hard trim to the exact musical duration (still at model SR)
|
| 435 |
+
y = hard_trim_seconds(y, chunk_secs)
|
| 436 |
+
|
| 437 |
+
# ---- Post-processing ----
|
| 438 |
+
if next_idx == 1 and self.params.ref_loop is not None:
|
| 439 |
+
# match loudness to the provided reference on the very first audible chunk
|
| 440 |
+
y, _ = match_loudness_to_reference(
|
| 441 |
+
self.params.ref_loop, y,
|
| 442 |
+
method=self.params.loudness_mode,
|
| 443 |
+
headroom_db=self.params.headroom_db
|
| 444 |
)
|
| 445 |
+
else:
|
| 446 |
+
# light micro-fades to guard against clicks
|
| 447 |
+
apply_micro_fades(y, 3)
|
| 448 |
+
|
| 449 |
+
# ---- Resample + bar-snap + encode ----
|
| 450 |
+
b64, meta = self._snap_and_encode(
|
| 451 |
+
y,
|
| 452 |
+
seconds=chunk_secs,
|
| 453 |
+
target_sr=self.params.target_sr,
|
| 454 |
+
bars=self.params.bars_per_chunk
|
| 455 |
+
)
|
| 456 |
+
# small hint for the client if you want UI butter between chunks
|
| 457 |
+
meta["xfade_seconds"] = xfade
|
| 458 |
|
| 459 |
+
# ---- Publish the completed chunk ----
|
| 460 |
+
with self._lock:
|
| 461 |
+
self.idx = next_idx
|
| 462 |
+
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
| 463 |
+
# Keep outbox bounded (trim far-behind entries)
|
| 464 |
+
if len(self.outbox) > 10:
|
| 465 |
+
cutoff = self._last_delivered_index - 5
|
| 466 |
+
self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
|
|
|
|
| 467 |
|
| 468 |
+
self.last_chunk_completed_at = time.time()
|
| 469 |
+
print(f"β
Completed chunk {next_idx}")
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
+
print("π JamWorker stopped")
|
utils.py
CHANGED
|
@@ -69,7 +69,7 @@ def match_loudness_to_reference(
|
|
| 69 |
|
| 70 |
|
| 71 |
# ---------- Stitch / fades / trims ----------
|
| 72 |
-
def stitch_generated(chunks, sr: int, xfade_s: float
|
| 73 |
if not chunks:
|
| 74 |
raise ValueError("no chunks")
|
| 75 |
xfade_n = int(round(xfade_s * sr))
|
|
@@ -82,7 +82,9 @@ def stitch_generated(chunks, sr: int, xfade_s: float) -> au.Waveform:
|
|
| 82 |
first = chunks[0].samples
|
| 83 |
if first.shape[0] < xfade_n:
|
| 84 |
raise ValueError("chunk shorter than crossfade prefix")
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
|
| 87 |
for i in range(1, len(chunks)):
|
| 88 |
cur = chunks[i].samples
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
# ---------- Stitch / fades / trims ----------
|
| 72 |
+
def stitch_generated(chunks, sr: int, xfade_s: float, drop_first_pre_roll: bool = True):
|
| 73 |
if not chunks:
|
| 74 |
raise ValueError("no chunks")
|
| 75 |
xfade_n = int(round(xfade_s * sr))
|
|
|
|
| 82 |
first = chunks[0].samples
|
| 83 |
if first.shape[0] < xfade_n:
|
| 84 |
raise ValueError("chunk shorter than crossfade prefix")
|
| 85 |
+
|
| 86 |
+
# π§ key change:
|
| 87 |
+
out = first[xfade_n:].copy() if drop_first_pre_roll else first.copy()
|
| 88 |
|
| 89 |
for i in range(1, len(chunks)):
|
| 90 |
cur = chunks[i].samples
|