Commit
·
5af7cde
1
Parent(s):
c1e4dcd
Fix bar-aligned context inside the /generate route, matching the fix already applied to jam_worker
Browse files
utils.py
CHANGED
|
@@ -111,44 +111,40 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
|
| 113 |
"""
|
| 114 |
-
Return a ctx_frames-long slice of `tokens` whose **end** lands on
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
|
| 118 |
-
bpm: float
|
| 119 |
-
fps: float (codec frames per second; keep this as float)
|
| 120 |
-
ctx_frames: int (length of context window in codec frames)
|
| 121 |
-
beats_per_bar: int
|
| 122 |
"""
|
| 123 |
-
|
| 124 |
|
| 125 |
if tokens is None:
|
| 126 |
raise ValueError("tokens is None")
|
| 127 |
tokens = np.asarray(tokens)
|
| 128 |
if tokens.ndim == 1:
|
| 129 |
-
tokens = tokens[:, None]
|
| 130 |
|
| 131 |
T = tokens.shape[0]
|
| 132 |
if T == 0:
|
| 133 |
return tokens
|
| 134 |
|
| 135 |
fps = float(fps)
|
| 136 |
-
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
|
| 137 |
|
| 138 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
| 140 |
tiled = np.tile(tokens, (reps, 1))
|
| 141 |
total = tiled.shape[0]
|
| 142 |
|
| 143 |
-
# How many whole bars fit?
|
| 144 |
-
k_bars =
|
| 145 |
if k_bars <= 0:
|
| 146 |
-
|
| 147 |
-
window = tiled[-ctx_frames:]
|
| 148 |
-
return window
|
| 149 |
|
| 150 |
-
# Snap END
|
| 151 |
-
end_idx = int(
|
| 152 |
end_idx = min(max(end_idx, ctx_frames), total)
|
| 153 |
start_idx = end_idx - ctx_frames
|
| 154 |
if start_idx < 0:
|
|
@@ -157,7 +153,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
|
|
| 157 |
|
| 158 |
window = tiled[start_idx:end_idx]
|
| 159 |
|
| 160 |
-
# Guard
|
| 161 |
if window.shape[0] < ctx_frames:
|
| 162 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
| 163 |
window = np.vstack([window, pad])[:ctx_frames]
|
|
@@ -167,6 +163,7 @@ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_ba
|
|
| 167 |
return window
|
| 168 |
|
| 169 |
|
|
|
|
| 170 |
def take_bar_aligned_tail(
|
| 171 |
wav: au.Waveform,
|
| 172 |
bpm: float,
|
|
|
|
| 111 |
# ---------- Token context helpers ----------
|
| 112 |
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
|
| 113 |
"""
|
| 114 |
+
Return a ctx_frames-long slice of `tokens` whose **end** lands on an integer
|
| 115 |
+
bar boundary in codec-frame space (model runs at `fps`, typically 25).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
"""
|
|
|
|
| 117 |
|
| 118 |
if tokens is None:
|
| 119 |
raise ValueError("tokens is None")
|
| 120 |
tokens = np.asarray(tokens)
|
| 121 |
if tokens.ndim == 1:
|
| 122 |
+
tokens = tokens[:, None]
|
| 123 |
|
| 124 |
T = tokens.shape[0]
|
| 125 |
if T == 0:
|
| 126 |
return tokens
|
| 127 |
|
| 128 |
fps = float(fps)
|
|
|
|
| 129 |
|
| 130 |
+
# float frames per bar (e.g., ~65.934 at 91 BPM for 4/4 @ 25fps)
|
| 131 |
+
frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps
|
| 132 |
+
|
| 133 |
+
# >>> KEY FIX: quantize bar length to an integer number of codec frames
|
| 134 |
+
frames_per_bar_i = max(1, int(round(frames_per_bar_f)))
|
| 135 |
+
|
| 136 |
+
# Tile so we can always snap the *end* to a bar boundary and still have ctx_frames
|
| 137 |
reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
|
| 138 |
tiled = np.tile(tokens, (reps, 1))
|
| 139 |
total = tiled.shape[0]
|
| 140 |
|
| 141 |
+
# How many whole integer bars fit in the tiled sequence?
|
| 142 |
+
k_bars = total // frames_per_bar_i
|
| 143 |
if k_bars <= 0:
|
| 144 |
+
return tiled[-ctx_frames:]
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Snap END to an exact integer multiple of frames_per_bar_i
|
| 147 |
+
end_idx = int(k_bars * frames_per_bar_i)
|
| 148 |
end_idx = min(max(end_idx, ctx_frames), total)
|
| 149 |
start_idx = end_idx - ctx_frames
|
| 150 |
if start_idx < 0:
|
|
|
|
| 153 |
|
| 154 |
window = tiled[start_idx:end_idx]
|
| 155 |
|
| 156 |
+
# Guard off-by-one
|
| 157 |
if window.shape[0] < ctx_frames:
|
| 158 |
pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
|
| 159 |
window = np.vstack([window, pad])[:ctx_frames]
|
|
|
|
| 163 |
return window
|
| 164 |
|
| 165 |
|
| 166 |
+
|
| 167 |
def take_bar_aligned_tail(
|
| 168 |
wav: au.Waveform,
|
| 169 |
bpm: float,
|