Commit
·
147fd8b
1
Parent(s):
406bd0f
attempting to use finetunes
Browse files
app.py
CHANGED
|
@@ -66,6 +66,51 @@ except Exception:
|
|
| 66 |
class ClientDisconnected(Exception): # fallback
|
| 67 |
pass
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
| 70 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
| 71 |
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
|
|
@@ -569,7 +614,14 @@ def get_mrt():
|
|
| 569 |
if _MRT is None:
|
| 570 |
with _MRT_LOCK:
|
| 571 |
if _MRT is None:
|
| 572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
return _MRT
|
| 574 |
|
| 575 |
_WARMED = False
|
|
@@ -648,6 +700,31 @@ def _kickoff_warmup():
|
|
| 648 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 649 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
@app.post("/generate")
|
| 652 |
def generate(
|
| 653 |
loop_audio: UploadFile = File(...),
|
|
|
|
| 66 |
class ClientDisconnected(Exception): # fallback
|
| 67 |
pass
|
| 68 |
|
| 69 |
+
import re
|
| 70 |
+
from pathlib import Path
|
| 71 |
+
|
| 72 |
+
def _resolve_checkpoint_dir() -> str | None:
|
| 73 |
+
"""
|
| 74 |
+
Returns a local directory path for MagentaRT(checkpoint_dir=...),
|
| 75 |
+
using a Hugging Face model repo that contains subfolders like:
|
| 76 |
+
checkpoint_1861001/, checkpoint_1862001/, ...
|
| 77 |
+
"""
|
| 78 |
+
repo_id = os.getenv("MRT_CKPT_REPO")
|
| 79 |
+
if not repo_id:
|
| 80 |
+
return None # fall back to builtin 'base'/'large' assets
|
| 81 |
+
|
| 82 |
+
step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
|
| 83 |
+
allow = None
|
| 84 |
+
if step:
|
| 85 |
+
# only pull that step + optional centroid files
|
| 86 |
+
allow = [f"checkpoint_{step}/**", "cluster_centroids.npy", "mean_style_embed.npy"]
|
| 87 |
+
|
| 88 |
+
from huggingface_hub import snapshot_download
|
| 89 |
+
local = snapshot_download(
|
| 90 |
+
repo_id=repo_id,
|
| 91 |
+
repo_type="model",
|
| 92 |
+
local_dir="/home/appuser/.cache/mrt_ckpt/repo",
|
| 93 |
+
local_dir_use_symlinks=False,
|
| 94 |
+
allow_patterns=allow or ["*"], # whole repo if no step provided
|
| 95 |
+
)
|
| 96 |
+
root = Path(local)
|
| 97 |
+
|
| 98 |
+
# If a step is specified, return that subfolder
|
| 99 |
+
if step:
|
| 100 |
+
cand = root / f"checkpoint_{step}"
|
| 101 |
+
if cand.is_dir():
|
| 102 |
+
return str(cand)
|
| 103 |
+
|
| 104 |
+
# Otherwise pick the numerically latest checkpoint_* folder
|
| 105 |
+
step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
|
| 106 |
+
if step_dirs:
|
| 107 |
+
pick = max(step_dirs, key=lambda d: int(d.name.split("_")[-1]))
|
| 108 |
+
return str(pick)
|
| 109 |
+
|
| 110 |
+
# Fallback: repo itself might already be a single checkpoint directory
|
| 111 |
+
return str(root)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
async def send_json_safe(ws: WebSocket, obj) -> bool:
|
| 115 |
"""Try to send. Returns False if the socket is (or becomes) closed."""
|
| 116 |
if ws.client_state == WebSocketState.DISCONNECTED or ws.application_state == WebSocketState.DISCONNECTED:
|
|
|
|
| 614 |
if _MRT is None:
|
| 615 |
with _MRT_LOCK:
|
| 616 |
if _MRT is None:
|
| 617 |
+
ckpt_dir = _resolve_checkpoint_dir() # ← points to checkpoint_1863001
|
| 618 |
+
_MRT = system.MagentaRT(
|
| 619 |
+
tag=os.getenv("MRT_SIZE", "large"), # keep 'large' if finetuned from large
|
| 620 |
+
guidance_weight=5.0,
|
| 621 |
+
device="gpu",
|
| 622 |
+
checkpoint_dir=ckpt_dir, # ← uses your finetune
|
| 623 |
+
lazy=False,
|
| 624 |
+
)
|
| 625 |
return _MRT
|
| 626 |
|
| 627 |
_WARMED = False
|
|
|
|
| 700 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 701 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 702 |
|
| 703 |
+
@app.get("/model/status")
|
| 704 |
+
def model_status():
|
| 705 |
+
mrt = get_mrt()
|
| 706 |
+
return {
|
| 707 |
+
"tag": getattr(mrt, "_tag", "unknown"),
|
| 708 |
+
"using_checkpoint_dir": True,
|
| 709 |
+
"codec_frame_rate": float(mrt.codec.frame_rate),
|
| 710 |
+
"decoder_rvq_depth": int(mrt.config.decoder_codec_rvq_depth),
|
| 711 |
+
"context_seconds": float(mrt.config.context_length),
|
| 712 |
+
"chunk_seconds": float(mrt.config.chunk_length),
|
| 713 |
+
"crossfade_seconds": float(mrt.config.crossfade_length),
|
| 714 |
+
"selected_step": os.getenv("MRT_CKPT_STEP"),
|
| 715 |
+
"repo": os.getenv("MRT_CKPT_REPO"),
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
@app.post("/model/swap")
|
| 719 |
+
def model_swap(step: int = Form(...)):
|
| 720 |
+
# stop any active jam if you want to be strict (not shown)
|
| 721 |
+
os.environ["MRT_CKPT_STEP"] = str(step)
|
| 722 |
+
global _MRT
|
| 723 |
+
with _MRT_LOCK:
|
| 724 |
+
_MRT = None # force re-create on next get_mrt()
|
| 725 |
+
# optionally pre-warm here by calling get_mrt()
|
| 726 |
+
return {"reloaded": True, "step": step}
|
| 727 |
+
|
| 728 |
@app.post("/generate")
|
| 729 |
def generate(
|
| 730 |
loop_audio: UploadFile = File(...),
|