thecollabagepatch committed on
Commit
2446a8b
·
1 Parent(s): 7ae8a62

one shot generations start failing after a few successful ones...

Browse files
Files changed (2) hide show
  1. app.py +36 -0
  2. one_shot_generation.py +60 -45
app.py CHANGED
@@ -199,8 +199,44 @@ def _patch_t5x_for_gpu_coords():
199
  except Exception as e:
200
  import logging; logging.exception("t5x GPU-coords patch failed: %s", e)
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # Call the patch immediately at import time (before MagentaRT init)
203
  _patch_t5x_for_gpu_coords()
 
204
 
205
  jam_registry: dict[str, JamWorker] = {}
206
  jam_lock = threading.Lock()
 
199
  except Exception as e:
200
  import logging; logging.exception("t5x GPU-coords patch failed: %s", e)
201
 
202
+ def _patch_magenta_rt_asset_fetch():
203
+ """
204
+ Patch magenta_rt.asset._fetch_single_hf to handle None response gracefully.
205
+ Prevents AttributeError when network timeouts occur.
206
+ """
207
+ try:
208
+ from magenta_rt import asset
209
+ import logging
210
+
211
+ # Save original function
212
+ _original_fetch = asset._fetch_single_hf
213
+
214
+ def _fetch_single_hf_safe(*args, **kwargs):
215
+ """Wrapper that fixes the None response.status_code bug"""
216
+ try:
217
+ return _original_fetch(*args, **kwargs)
218
+ except Exception as e:
219
+ # This is the bug fix: check if response exists before accessing it
220
+ response = getattr(e, 'response', None)
221
+ if response is not None and hasattr(response, 'status_code'):
222
+ if response.status_code == 429:
223
+ # Original code's rate-limit handling would go here
224
+ logging.warning("Rate limited by HuggingFace Hub")
225
+ raise
226
+ # For all other cases (including timeout with no response), re-raise
227
+ logging.error(f"HuggingFace download failed: {e}")
228
+ raise
229
+
230
+ # Apply the patch
231
+ asset._fetch_single_hf = _fetch_single_hf_safe
232
+ logging.info("Patched magenta_rt.asset._fetch_single_hf for safer error handling.")
233
+
234
+ except Exception as e:
235
+ logging.exception("magenta_rt asset fetch patch failed: %s", e)
236
+
237
  # Call the patch immediately at import time (before MagentaRT init)
238
  _patch_t5x_for_gpu_coords()
239
+ _patch_magenta_rt_asset_fetch()
240
 
241
  jam_registry: dict[str, JamWorker] = {}
242
  jam_lock = threading.Lock()
one_shot_generation.py CHANGED
@@ -33,97 +33,112 @@ def generate_loop_continuation_with_mrt(
33
  ):
34
  """
35
  Generate a continuation of an input loop using MagentaRT.
 
36
 
37
- Args:
38
- mrt: MagentaRT instance
39
- input_wav_path: Path to input audio file
40
- bpm: Beats per minute
41
- extra_styles: List of additional text style prompts (optional)
42
- style_weights: List of weights for style prompts (optional)
43
- bars: Number of bars to generate
44
- beats_per_bar: Beats per bar (typically 4)
45
- loop_weight: Weight for the input loop's style embedding
46
- loudness_mode: Loudness matching method ("auto", "lufs", "rms", "none")
47
- loudness_headroom_db: Headroom in dB for peak limiting
48
- intro_bars_to_drop: Number of intro bars to generate then drop
49
- progress_cb: Braindead progress updates for JUCE
50
 
51
- Returns:
52
- Tuple of (au.Waveform output, dict loudness_stats)
53
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Load & prep (unchanged)
55
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
56
 
57
- # Use tail for context (your recent change)
58
  codec_fps = float(mrt.codec.frame_rate)
59
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
60
  loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
61
 
62
- tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
63
- tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
 
 
64
 
65
- # Bar-aligned token window (unchanged)
66
  context_tokens = make_bar_aligned_context(
67
  tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
68
  ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
69
  )
 
 
70
  state = mrt.init_state()
71
- state.context_tokens = context_tokens
 
 
 
 
 
 
 
72
 
73
- # STYLE embed (optional: switch to loop_for_context if you want stronger "recent" bias)
74
  loop_embed = mrt.embed_style(loop_for_context)
75
- embeds, weights = [loop_embed], [float(loop_weight)]
76
  if extra_styles:
77
  for i, s in enumerate(extra_styles):
78
  if s.strip():
79
- embeds.append(mrt.embed_style(s.strip()))
80
  w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
81
  weights.append(float(w))
82
  wsum = float(sum(weights)) or 1.0
83
  weights = [w / wsum for w in weights]
84
- combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype)
85
 
86
- # --- Length math ---
87
  seconds_per_bar = beats_per_bar * (60.0 / bpm)
88
  total_secs = bars * seconds_per_bar
89
  drop_bars = max(0, int(intro_bars_to_drop))
90
- drop_secs = min(drop_bars, bars) * seconds_per_bar # clamp to <= bars
91
- gen_total_secs = total_secs + drop_secs # generate extra
92
 
93
- # Chunk scheduling to cover gen_total_secs
94
- chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate # ~2.0
95
  steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
96
 
97
  if progress_cb:
98
- progress_cb(0, steps) # announce total before first chunk
99
 
100
- # Generate
101
  chunks = []
102
  for i in range(steps):
103
- wav, state = mrt.generate_chunk(state=state, style=combined_style)
 
104
  chunks.append(wav)
 
 
 
 
 
105
  if progress_cb:
106
- progress_cb(i + 1, steps) # <-- report chunk progress
 
107
 
108
- # Stitch continuous audio
109
  stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
110
-
111
- # Trim to generated length (bars + dropped bars)
112
  stitched = hard_trim_seconds(stitched, gen_total_secs)
113
 
114
- # πŸ‘‰ Drop the intro bars
115
  if drop_secs > 0:
116
  n_drop = int(round(drop_secs * stitched.sample_rate))
117
  stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
118
 
119
- # Final exact-length trim to requested bars
120
  out = hard_trim_seconds(stitched, total_secs)
121
 
122
- # Final polish AFTER drop
123
- # out = out.peak_normalize(0.95)
124
-
125
-
126
- # Loudness match to input (after drop) so bar 1 sits right
127
  out, loud_stats = apply_barwise_loudness_match(
128
  out=out,
129
  ref_loop=loop,
@@ -131,7 +146,7 @@ def generate_loop_continuation_with_mrt(
131
  beats_per_bar=beats_per_bar,
132
  method=loudness_mode,
133
  headroom_db=loudness_headroom_db,
134
- smooth_ms=50, # 50ms crossfade between bars
135
  )
136
 
137
  apply_micro_fades(out, 5)
 
33
  ):
34
  """
35
  Generate a continuation of an input loop using MagentaRT.
36
+ """
37
 
38
+ # ===== NEW: Force codec/model reset before generation =====
39
+ # Clear any accumulated state in the codec that might cause silence issues
40
+ try:
41
+ # Option 1: If codec has explicit reset
42
+ if hasattr(mrt.codec, 'reset') and callable(mrt.codec.reset):
43
+ mrt.codec.reset()
 
 
 
 
 
 
 
44
 
45
+ # Option 2: Force clear any cached codec state
46
+ if hasattr(mrt.codec, '_encode_cache'):
47
+ mrt.codec._encode_cache = None
48
+ if hasattr(mrt.codec, '_decode_cache'):
49
+ mrt.codec._decode_cache = None
50
+
51
+ # Option 3: Clear JAX compilation caches (nuclear but effective)
52
+ # Uncomment if issues persist:
53
+ # import jax
54
+ # jax.clear_caches()
55
+
56
+ except Exception as e:
57
+ import logging
58
+ logging.warning(f"Codec reset attempt failed (non-fatal): {e}")
59
+ # ============================================================
60
+
61
  # Load & prep (unchanged)
62
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
63
 
64
+ # Use tail for context
65
  codec_fps = float(mrt.codec.frame_rate)
66
  ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
67
  loop_for_context = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
68
 
69
+ # ===== NEW: Force fresh token copies =====
70
+ tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32, copy=True) # ← Added copy=True
71
+ tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth].copy() # ← Added .copy()
72
+ # ==========================================
73
 
74
+ # Bar-aligned token window
75
  context_tokens = make_bar_aligned_context(
76
  tokens, bpm=bpm, fps=float(mrt.codec.frame_rate),
77
  ctx_frames=mrt.config.context_length_frames, beats_per_bar=beats_per_bar
78
  )
79
+
80
+ # ===== NEW: More aggressive state initialization =====
81
  state = mrt.init_state()
82
+
83
+ # Ensure context_tokens is a fresh array, not a view
84
+ state.context_tokens = np.array(context_tokens, dtype=np.int32, copy=True)
85
+
86
+ # If there's any internal model state cache, clear it
87
+ if hasattr(state, '_cache'):
88
+ state._cache = None
89
+ # =====================================================
90
 
91
+ # STYLE embed (unchanged but ensure fresh embedding)
92
  loop_embed = mrt.embed_style(loop_for_context)
93
+ embeds, weights = [loop_embed.copy()], [float(loop_weight)] # ← Added .copy()
94
  if extra_styles:
95
  for i, s in enumerate(extra_styles):
96
  if s.strip():
97
+ embeds.append(mrt.embed_style(s.strip()).copy()) # ← Added .copy()
98
  w = style_weights[i] if (style_weights and i < len(style_weights)) else 1.0
99
  weights.append(float(w))
100
  wsum = float(sum(weights)) or 1.0
101
  weights = [w / wsum for w in weights]
102
+ combined_style = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(loop_embed.dtype, copy=True) # ← Added copy=True
103
 
104
+ # --- Length math (unchanged) ---
105
  seconds_per_bar = beats_per_bar * (60.0 / bpm)
106
  total_secs = bars * seconds_per_bar
107
  drop_bars = max(0, int(intro_bars_to_drop))
108
+ drop_secs = min(drop_bars, bars) * seconds_per_bar
109
+ gen_total_secs = total_secs + drop_secs
110
 
111
+ chunk_secs = mrt.config.chunk_length_frames * mrt.config.frame_length_samples / mrt.sample_rate
 
112
  steps = int(math.ceil(gen_total_secs / chunk_secs)) + 1
113
 
114
  if progress_cb:
115
+ progress_cb(0, steps)
116
 
117
+ # ===== NEW: Generation loop with explicit state refresh =====
118
  chunks = []
119
  for i in range(steps):
120
+ # Generate chunk with current state
121
+ wav, new_state = mrt.generate_chunk(state=state, style=combined_style)
122
  chunks.append(wav)
123
+
124
+ # CRITICAL: Replace state, don't mutate it
125
+ # This ensures we're not accumulating corrupted state
126
+ state = new_state
127
+
128
  if progress_cb:
129
+ progress_cb(i + 1, steps)
130
+ # ============================================================
131
 
132
+ # Rest of the function unchanged...
133
  stitched = stitch_generated(chunks, mrt.sample_rate, mrt.config.crossfade_length).as_stereo()
 
 
134
  stitched = hard_trim_seconds(stitched, gen_total_secs)
135
 
 
136
  if drop_secs > 0:
137
  n_drop = int(round(drop_secs * stitched.sample_rate))
138
  stitched = au.Waveform(stitched.samples[n_drop:], stitched.sample_rate)
139
 
 
140
  out = hard_trim_seconds(stitched, total_secs)
141
 
 
 
 
 
 
142
  out, loud_stats = apply_barwise_loudness_match(
143
  out=out,
144
  ref_loop=loop,
 
146
  beats_per_bar=beats_per_bar,
147
  method=loudness_mode,
148
  headroom_db=loudness_headroom_db,
149
+ smooth_ms=50,
150
  )
151
 
152
  apply_micro_fades(out, 5)