Commit
·
d41a575
1
Parent(s):
9a1b4dc
oh boy bar boundaries is hard
Browse files- jam_worker.py +3 -3
- utils.py +90 -21
jam_worker.py
CHANGED
|
@@ -80,7 +80,7 @@ class JamWorker(threading.Thread):
|
|
| 80 |
context_tokens = make_bar_aligned_context(
|
| 81 |
tokens,
|
| 82 |
bpm=self.params.bpm,
|
| 83 |
-
fps=
|
| 84 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 85 |
beats_per_bar=self.params.beats_per_bar
|
| 86 |
)
|
|
@@ -213,7 +213,7 @@ class JamWorker(threading.Thread):
|
|
| 213 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 214 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 215 |
context_tokens = make_bar_aligned_context(tokens,
|
| 216 |
-
bpm=self.params.bpm, fps=
|
| 217 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 218 |
beats_per_bar=self.params.beats_per_bar
|
| 219 |
)
|
|
@@ -242,7 +242,7 @@ class JamWorker(threading.Thread):
|
|
| 242 |
ctx = make_bar_aligned_context(
|
| 243 |
tokens,
|
| 244 |
bpm=self.params.bpm,
|
| 245 |
-
fps=
|
| 246 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 247 |
beats_per_bar=self.params.beats_per_bar
|
| 248 |
)
|
|
|
|
| 80 |
context_tokens = make_bar_aligned_context(
|
| 81 |
tokens,
|
| 82 |
bpm=self.params.bpm,
|
| 83 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
| 84 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 85 |
beats_per_bar=self.params.beats_per_bar
|
| 86 |
)
|
|
|
|
| 213 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
| 214 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
| 215 |
context_tokens = make_bar_aligned_context(tokens,
|
| 216 |
+
bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
|
| 217 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 218 |
beats_per_bar=self.params.beats_per_bar
|
| 219 |
)
|
|
|
|
| 242 |
ctx = make_bar_aligned_context(
|
| 243 |
tokens,
|
| 244 |
bpm=self.params.bpm,
|
| 245 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
| 246 |
ctx_frames=self.mrt.config.context_length_frames,
|
| 247 |
beats_per_bar=self.params.beats_per_bar
|
| 248 |
)
|
utils.py
CHANGED
|
@@ -109,30 +109,99 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
| 109 |
|
| 110 |
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
-
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
tiled = np.tile(tokens, (reps, 1))
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if max_bars is not None:
|
| 130 |
-
bars_needed = min(bars_needed, max_bars)
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
return wav
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
# ---------- SR normalize + snap ----------
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
+
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
    """
    Return a ctx_frames-long window of `tokens` whose END lands on the nearest
    whole-bar boundary in codec-frame space, even when frames-per-bar is fractional.

    `tokens` is treated as a loop: it is conceptually repeated end to end (enough
    times to cover ctx_frames plus slack), bar boundaries are counted from the
    start of that repetition, and the window ends on the last boundary that fits.

    tokens: array-like of shape (T, D) or (T,) where T = codec frames.
        1-D input is promoted to (T, 1), so the result is always at least 2-D.
    bpm: float
    fps: float (codec frames per second; keep this as float)
    ctx_frames: int (length of context window in codec frames)
    beats_per_bar: int

    Returns an ndarray with ctx_frames rows taken cyclically from `tokens`.
    Empty `tokens` is returned unchanged and ctx_frames <= 0 yields zero rows.
    If no whole bar fits (or bpm / fps / beats_per_bar is not a positive finite
    number) the window simply ends at the end of the repeated loop.

    Raises ValueError if `tokens` is None.
    """
    if tokens is None:
        raise ValueError("tokens is None")
    tokens = np.asarray(tokens)
    if tokens.ndim == 1:
        tokens = tokens[:, None]  # promote to (T, 1) so rows are always frames

    T = tokens.shape[0]
    if T == 0:
        return tokens

    ctx_frames = int(ctx_frames)
    if ctx_frames <= 0:
        # A negative-zero slice (`x[-0:]`) would return everything; be explicit.
        return tokens[:0]

    # Length of the virtual loop: a little more than we need so the END can
    # always be snapped back to a bar boundary. Always >= ctx_frames + 2 * T.
    reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
    total = reps * T

    # Fractional codec frames per bar; 0.0 marks "no usable bar grid".
    bpm_f = float(bpm)
    fps_f = float(fps)
    beats_f = float(beats_per_bar)
    if bpm_f > 0 and fps_f > 0 and beats_f > 0:
        frames_per_bar_f = (beats_f * 60.0 / bpm_f) * fps_f
    else:
        frames_per_bar_f = 0.0

    # How many whole bars fit in the loop?
    k_bars = int(np.floor(total / frames_per_bar_f)) if frames_per_bar_f > 0 else 0

    if k_bars > 0:
        # Snap END to the nearest integer frame at a whole-bar boundary, but
        # never so early that the window would start before frame 0.
        end_idx = int(round(k_bars * frames_per_bar_f))
        end_idx = min(max(end_idx, ctx_frames), total)
    else:
        # Fallback: just take the last ctx_frames of the loop.
        end_idx = total

    # Read the window straight out of `tokens` with wrap-around indices rather
    # than materialising the repeated array; the selected rows are the same.
    frame_idx = np.arange(end_idx - ctx_frames, end_idx) % T
    return tokens[frame_idx]
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def take_bar_aligned_tail(
    wav: au.Waveform,
    bpm: float,
    beats_per_bar: int,
    ctx_seconds: float,
    max_bars=None
) -> au.Waveform:
    """
    Take a tail whose length is an integer number of bars, with the END aligned
    to the end of `wav`. Uses ceil for the bar count so the tail never
    under-fills ctx_seconds.

    wav: waveform to cut; `wav.samples` is indexed along axis 0 and
        `wav.sample_rate` converts seconds to samples.
    bpm: tempo in beats per minute.
    beats_per_bar: beats in one bar.
    ctx_seconds: minimum duration the tail should cover, in seconds.
    max_bars: optional upper limit on the number of bars taken; values below 1
        are treated as 1 so the result is never an empty waveform.

    Returns a new waveform holding the last N whole bars of `wav`, or `wav`
    itself when it is not longer than the requested tail or when bpm /
    beats_per_bar do not describe a positive bar length.
    """
    import math

    # seconds per bar
    bpm_f = float(bpm)
    beats_f = float(beats_per_bar)
    if not (bpm_f > 0 and beats_f > 0):
        # No meaningful bar length: return the audio as-is instead of dividing
        # by zero or slicing with a negative sample count.
        return wav
    spb = (60.0 / bpm_f) * beats_f

    # Pick enough whole bars to cover ctx_seconds (avoid underfilling on round-down).
    # The small epsilon avoids an extra bar due to FP jitter when ctx_seconds ~= k * spb.
    eps = 1e-9
    bars_needed = max(1, int(math.ceil((float(ctx_seconds) - eps) / spb)))

    if max_bars is not None:
        # Cap the bar count, but keep at least one bar so the tail is never empty.
        bars_needed = max(1, min(bars_needed, int(max_bars)))

    # Convert bars -> samples (do rounding once at the end for stability)
    samples_per_bar_f = spb * float(wav.sample_rate)
    n = int(round(bars_needed * samples_per_bar_f))

    total = int(wav.samples.shape[0])
    if n >= total:
        # Not enough audio to take that many bars; return as-is.
        return wav

    start = total - n
    return au.Waveform(wav.samples[start:], wav.sample_rate)
|
| 205 |
|
| 206 |
|
| 207 |
# ---------- SR normalize + snap ----------
|