Commit
·
87147f5
1
Parent(s):
56dfd15
warmup
Browse files
app.py
CHANGED
|
@@ -15,6 +15,8 @@ from utils import (
|
|
| 15 |
|
| 16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
| 17 |
import uuid, threading
|
|
|
|
|
|
|
| 18 |
|
| 19 |
import gradio as gr
|
| 20 |
from typing import Optional
|
|
@@ -358,6 +360,82 @@ def get_mrt():
|
|
| 358 |
_MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
|
| 359 |
return _MRT
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
@app.post("/generate")
|
| 362 |
def generate(
|
| 363 |
loop_audio: UploadFile = File(...),
|
|
|
|
| 15 |
|
| 16 |
from jam_worker import JamWorker, JamParams, JamChunk
|
| 17 |
import uuid, threading
|
| 18 |
+
import os
|
| 19 |
+
import logging
|
| 20 |
|
| 21 |
import gradio as gr
|
| 22 |
from typing import Optional
|
|
|
|
| 360 |
_MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)
|
| 361 |
return _MRT
|
| 362 |
|
| 363 |
+
_WARMED = False
|
| 364 |
+
_WARMUP_LOCK = threading.Lock()
|
| 365 |
+
|
| 366 |
+
def _mrt_warmup():
|
| 367 |
+
"""
|
| 368 |
+
Build a minimal, bar-aligned silent context and run one 2s generate_chunk
|
| 369 |
+
to trigger XLA JIT & autotune so first real request is fast.
|
| 370 |
+
"""
|
| 371 |
+
global _WARMED
|
| 372 |
+
with _WARMUP_LOCK:
|
| 373 |
+
if _WARMED:
|
| 374 |
+
return
|
| 375 |
+
try:
|
| 376 |
+
mrt = get_mrt()
|
| 377 |
+
|
| 378 |
+
# --- derive timing from model config ---
|
| 379 |
+
codec_fps = float(mrt.codec.frame_rate)
|
| 380 |
+
ctx_seconds = float(mrt.config.context_length_frames) / codec_fps
|
| 381 |
+
sr = int(mrt.sample_rate)
|
| 382 |
+
|
| 383 |
+
# We'll align to 120 BPM, 4/4, and generate one ~2s chunk
|
| 384 |
+
bpm = 120.0
|
| 385 |
+
beats_per_bar = 4
|
| 386 |
+
|
| 387 |
+
# --- build a silent, stereo context of ctx_seconds ---
|
| 388 |
+
import numpy as np, soundfile as sf
|
| 389 |
+
samples = int(max(1, round(ctx_seconds * sr)))
|
| 390 |
+
silent = np.zeros((samples, 2), dtype=np.float32)
|
| 391 |
+
|
| 392 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
| 393 |
+
sf.write(tmp.name, silent, sr, subtype="PCM_16")
|
| 394 |
+
tmp_path = tmp.name
|
| 395 |
+
|
| 396 |
+
try:
|
| 397 |
+
# Load as Waveform and take a tail of exactly ctx_seconds
|
| 398 |
+
loop = au.Waveform.from_file(tmp_path).resample(sr).as_stereo()
|
| 399 |
+
seconds_per_bar = beats_per_bar * (60.0 / bpm)
|
| 400 |
+
ctx_tail = take_bar_aligned_tail(loop, bpm, beats_per_bar, ctx_seconds)
|
| 401 |
+
|
| 402 |
+
# Tokens for context window
|
| 403 |
+
tokens_full = mrt.codec.encode(ctx_tail).astype(np.int32)
|
| 404 |
+
tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
|
| 405 |
+
context_tokens = make_bar_aligned_context(
|
| 406 |
+
tokens,
|
| 407 |
+
bpm=bpm,
|
| 408 |
+
fps=int(mrt.codec.frame_rate),
|
| 409 |
+
ctx_frames=mrt.config.context_length_frames,
|
| 410 |
+
beats_per_bar=beats_per_bar,
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Init state and a basic style vector (text token is fine)
|
| 414 |
+
state = mrt.init_state()
|
| 415 |
+
state.context_tokens = context_tokens
|
| 416 |
+
style_vec = mrt.embed_style("warmup")
|
| 417 |
+
|
| 418 |
+
# --- one throwaway chunk (~2s) ---
|
| 419 |
+
_wav, _state = mrt.generate_chunk(state=state, style=style_vec)
|
| 420 |
+
|
| 421 |
+
logging.info("MagentaRT warmup complete.")
|
| 422 |
+
finally:
|
| 423 |
+
try:
|
| 424 |
+
os.unlink(tmp_path)
|
| 425 |
+
except Exception:
|
| 426 |
+
pass
|
| 427 |
+
|
| 428 |
+
_WARMED = True
|
| 429 |
+
except Exception as e:
|
| 430 |
+
# Never crash on warmup errors; log and continue serving
|
| 431 |
+
logging.exception("MagentaRT warmup failed (continuing without warmup): %s", e)
|
| 432 |
+
|
| 433 |
+
# Kick it off in the background on server start
|
| 434 |
+
@app.on_event("startup")
|
| 435 |
+
def _kickoff_warmup():
|
| 436 |
+
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 437 |
+
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 438 |
+
|
| 439 |
@app.post("/generate")
|
| 440 |
def generate(
|
| 441 |
loop_audio: UploadFile = File(...),
|