Commit
Β·
184daaa
1
Parent(s):
a36f465
fixing jam seamlessness
Browse files- jam_worker.py +95 -59
jam_worker.py
CHANGED
|
@@ -3,6 +3,7 @@ import threading, time, base64, io, uuid
|
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
import numpy as np
|
| 5 |
import soundfile as sf
|
|
|
|
| 6 |
|
| 7 |
from utils import (
|
| 8 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
|
@@ -155,77 +156,112 @@ class JamWorker(threading.Thread):
|
|
| 155 |
}
|
| 156 |
return b64, meta
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
def run(self):
|
| 159 |
-
"""
|
|
|
|
| 160 |
spb = self._seconds_per_bar()
|
| 161 |
-
chunk_secs = self.params.bars_per_chunk * spb
|
|
|
|
| 162 |
xfade = self.mrt.config.crossfade_length
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
while not self._stop_event.is_set():
|
| 167 |
-
#
|
| 168 |
-
if not self._should_generate_next_chunk():
|
| 169 |
-
# We're ahead enough, wait a bit for frontend to catch up
|
| 170 |
-
print(f"βΈοΈ Buffer full, waiting for consumption...")
|
| 171 |
-
time.sleep(0.5)
|
| 172 |
-
continue
|
| 173 |
-
|
| 174 |
-
# Generate the next chunk
|
| 175 |
with self._lock:
|
|
|
|
|
|
|
|
|
|
| 176 |
style_vec = self.params.style_vec
|
| 177 |
self.mrt.guidance_weight = self.params.guidance_weight
|
| 178 |
-
self.mrt.temperature
|
| 179 |
-
self.mrt.topk
|
| 180 |
-
next_idx = self.idx + 1
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
# Generate enough model chunks to cover chunk_secs
|
| 185 |
-
need = chunk_secs
|
| 186 |
-
chunks = []
|
| 187 |
self.last_chunk_started_at = time.time()
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
)
|
| 208 |
-
else:
|
| 209 |
-
apply_micro_fades(y, 3)
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
self.idx = next_idx
|
| 221 |
-
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
| 222 |
-
|
| 223 |
-
# Keep outbox bounded (remove old chunks)
|
| 224 |
-
if len(self.outbox) > 10:
|
| 225 |
-
# Remove chunks that are way behind the delivery point
|
| 226 |
-
self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
-
print("π JamWorker stopped")
|
|
|
|
| 3 |
from dataclasses import dataclass, field
|
| 4 |
import numpy as np
|
| 5 |
import soundfile as sf
|
| 6 |
+
from magenta_rt import audio as au
|
| 7 |
|
| 8 |
from utils import (
|
| 9 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
|
|
|
| 156 |
}
|
| 157 |
return b64, meta
|
| 158 |
|
| 159 |
+
def _append_model_chunk_to_stream(self, wav):
|
| 160 |
+
"""Incrementally append a model chunk with equal-power crossfade."""
|
| 161 |
+
xfade_s = float(self.mrt.config.crossfade_length)
|
| 162 |
+
sr = int(self.mrt.sample_rate)
|
| 163 |
+
xfade_n = int(round(xfade_s * sr))
|
| 164 |
+
|
| 165 |
+
s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
|
| 166 |
+
|
| 167 |
+
if getattr(self, "_stream", None) is None:
|
| 168 |
+
# First chunk: drop model pre-roll (xfade head)
|
| 169 |
+
if s.shape[0] > xfade_n:
|
| 170 |
+
self._stream = s[xfade_n:].astype(np.float32, copy=True)
|
| 171 |
+
else:
|
| 172 |
+
self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
|
| 173 |
+
self._next_emit_start = 0 # pointer into _stream (model SR samples)
|
| 174 |
+
return
|
| 175 |
+
|
| 176 |
+
# Crossfade last xfade_n samples of _stream with head of new s
|
| 177 |
+
if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
|
| 178 |
+
# Degenerate safeguard
|
| 179 |
+
self._stream = np.concatenate([self._stream, s], axis=0)
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
tail = self._stream[-xfade_n:]
|
| 183 |
+
head = s[:xfade_n]
|
| 184 |
+
|
| 185 |
+
# Equal-power envelopes
|
| 186 |
+
t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
|
| 187 |
+
eq_in, eq_out = np.sin(t), np.cos(t)
|
| 188 |
+
mixed = tail * eq_out + head * eq_in
|
| 189 |
+
|
| 190 |
+
self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
|
| 191 |
+
|
| 192 |
def run(self):
|
| 193 |
+
"""Continuous stream + sliding 8-bar window emitter."""
|
| 194 |
+
sr_model = int(self.mrt.sample_rate)
|
| 195 |
spb = self._seconds_per_bar()
|
| 196 |
+
chunk_secs = float(self.params.bars_per_chunk) * spb
|
| 197 |
+
chunk_n_model = int(round(chunk_secs * sr_model))
|
| 198 |
xfade = self.mrt.config.crossfade_length
|
| 199 |
|
| 200 |
+
# Streaming state
|
| 201 |
+
self._stream = None # np.ndarray [S, C] at model SR
|
| 202 |
+
self._next_emit_start = 0 # sample pointer for next 8-bar cut
|
| 203 |
+
|
| 204 |
+
print("π JamWorker (streaming) started...")
|
| 205 |
+
|
| 206 |
while not self._stop_event.is_set():
|
| 207 |
+
# Flow control: don't get too far ahead of the consumer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
with self._lock:
|
| 209 |
+
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
| 210 |
+
time.sleep(0.25)
|
| 211 |
+
continue
|
| 212 |
style_vec = self.params.style_vec
|
| 213 |
self.mrt.guidance_weight = self.params.guidance_weight
|
| 214 |
+
self.mrt.temperature = self.params.temperature
|
| 215 |
+
self.mrt.topk = self.params.topk
|
|
|
|
| 216 |
|
| 217 |
+
# Generate ONE model chunk and append to the continuous stream
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
self.last_chunk_started_at = time.time()
|
| 219 |
+
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
| 220 |
+
self._append_model_chunk_to_stream(wav)
|
| 221 |
+
self.last_chunk_completed_at = time.time()
|
| 222 |
+
|
| 223 |
+
# While we have at least one full 8-bar window available, emit it
|
| 224 |
+
while (getattr(self, "_stream", None) is not None and
|
| 225 |
+
self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
|
| 226 |
+
not self._stop_event.is_set()):
|
| 227 |
+
|
| 228 |
+
seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
|
| 229 |
+
|
| 230 |
+
# Wrap as Waveform at model SR
|
| 231 |
+
y = au.Waveform(seg.astype(np.float32, copy=False), sr_model).as_stereo()
|
| 232 |
+
|
| 233 |
+
# Post-processing:
|
| 234 |
+
# - First emitted chunk: loudness-match to ref_loop
|
| 235 |
+
# - No micro-fades on mid-stream windows (they cause dips)
|
| 236 |
+
next_idx = self.idx + 1
|
| 237 |
+
if next_idx == 1 and self.params.ref_loop is not None:
|
| 238 |
+
y, _ = match_loudness_to_reference(
|
| 239 |
+
self.params.ref_loop, y,
|
| 240 |
+
method=self.params.loudness_mode,
|
| 241 |
+
headroom_db=self.params.headroom_db
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# Resample + snap + encode exactly chunk_secs long
|
| 245 |
+
b64, meta = self._snap_and_encode(
|
| 246 |
+
y, seconds=chunk_secs,
|
| 247 |
+
target_sr=self.params.target_sr,
|
| 248 |
+
bars=self.params.bars_per_chunk
|
| 249 |
)
|
|
|
|
|
|
|
| 250 |
|
| 251 |
+
with self._lock:
|
| 252 |
+
self.idx = next_idx
|
| 253 |
+
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
| 254 |
+
# Bound the outbox
|
| 255 |
+
if len(self.outbox) > 10:
|
| 256 |
+
self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
|
| 257 |
|
| 258 |
+
# Advance window pointer to the next 8-bar slot
|
| 259 |
+
self._next_emit_start += chunk_n_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
# Trim old samples to keep memory bounded (keep a little guard)
|
| 262 |
+
keep_from = max(0, self._next_emit_start - chunk_n_model) # keep 1 extra window
|
| 263 |
+
if keep_from > 0:
|
| 264 |
+
self._stream = self._stream[keep_from:]
|
| 265 |
+
self._next_emit_start -= keep_from
|
| 266 |
|
| 267 |
+
print("π JamWorker (streaming) stopped")
|