Commit
·
49b5fee
1
Parent(s):
d373851
full model switching logic
Browse files
app.py
CHANGED
|
@@ -51,7 +51,7 @@ import uuid, threading
|
|
| 51 |
import logging
|
| 52 |
|
| 53 |
import gradio as gr
|
| 54 |
-
from typing import Optional
|
| 55 |
|
| 56 |
|
| 57 |
import json, asyncio, base64
|
|
@@ -68,7 +68,9 @@ except Exception:
|
|
| 68 |
|
| 69 |
import re, tarfile
|
| 70 |
from pathlib import Path
|
| 71 |
-
from huggingface_hub import snapshot_download
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 74 |
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
|
@@ -76,6 +78,43 @@ _ASSETS_REPO_ID: str | None = None
|
|
| 76 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
| 77 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
| 80 |
"""
|
| 81 |
Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
|
@@ -927,6 +966,151 @@ def model_assets_status():
|
|
| 927 |
"embedding_dim": d,
|
| 928 |
}
|
| 929 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 930 |
@app.post("/generate")
|
| 931 |
def generate(
|
| 932 |
loop_audio: UploadFile = File(...),
|
|
|
|
| 51 |
import logging
|
| 52 |
|
| 53 |
import gradio as gr
|
| 54 |
+
from typing import Optional, Union, Literal
|
| 55 |
|
| 56 |
|
| 57 |
import json, asyncio, base64
|
|
|
|
| 68 |
|
| 69 |
import re, tarfile
|
| 70 |
from pathlib import Path
|
| 71 |
+
from huggingface_hub import snapshot_download, HfApi
|
| 72 |
+
|
| 73 |
+
from pydantic import BaseModel
|
| 74 |
|
| 75 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 76 |
_FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
|
|
|
| 78 |
_MEAN_EMBED: np.ndarray | None = None # shape (D,) dtype float32
|
| 79 |
_CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
| 80 |
|
| 81 |
+
# Matches a `checkpoint_<step>` path component that is either a directory
# (followed by "/", with anything after it), an archive whose name ends in
# .tar.gz / .tgz, or the bare final component of the path.
_STEP_RE = re.compile(r"(?:^|/)checkpoint_(\d+)(?:/|\.tar\.gz$|\.tgz$|$)")


def _steps_from_paths(paths: list[str]) -> list[int]:
    """Return the sorted, de-duplicated checkpoint steps named by repo file paths."""
    # `\d+` guarantees that int() cannot fail, so no exception handling is needed.
    return sorted({int(m.group(1)) for m in map(_STEP_RE.search, paths) if m})


def _list_ckpt_steps(repo_id: str, revision: str = "main") -> list[int]:
    """
    List available checkpoint steps in a HF model repo without downloading all weights.

    Looks for:
      checkpoint_<step>/...                      (any file inside the directory)
      checkpoint_<step>.tgz | .tar.gz
      archives/checkpoint_<step>.tgz | .tar.gz

    The repo listing contains file paths rather than directory entries, so a
    directory checkpoint is recognised through the files stored beneath it.

    Args:
        repo_id: Hugging Face model repo to inspect.
        revision: Branch, tag or commit to list (defaults to "main").

    Returns:
        Sorted list of unique step numbers; empty when nothing matches.

    Errors raised by the Hub client (missing repo, network failure, ...) are
    not caught here and propagate to the caller.
    """
    api = HfApi()
    files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
    return _steps_from_paths(files)
|
| 102 |
+
|
| 103 |
+
def _step_exists(repo_id: str, revision: str, step: int) -> bool:
    """Tell whether `step` is among the checkpoint steps listed for `repo_id` at `revision`."""
    available = _list_ckpt_steps(repo_id, revision)
    return step in available
|
| 105 |
+
|
| 106 |
+
def _any_jam_running() -> bool:
    """Report whether at least one worker in the jam registry is still alive."""
    # The registry is inspected while holding jam_lock.
    with jam_lock:
        for worker in jam_registry.values():
            if worker.is_alive():
                return True
        return False
|
| 109 |
+
|
| 110 |
+
def _stop_all_jams(timeout: float = 5.0):
    """
    Stop every live jam worker and drop it from the registry.

    Each worker that is still alive is asked to stop, joined for at most
    `timeout` seconds, and then removed from `jam_registry` (whether or not the
    join finished in time). Workers that are already dead are left untouched.
    `jam_lock` is held for the whole sweep.
    """
    with jam_lock:
        # Iterate over a snapshot because entries are removed during the loop.
        snapshot = tuple(jam_registry.items())
        for session_id, worker in snapshot:
            if not worker.is_alive():
                continue
            worker.stop()
            worker.join(timeout=timeout)
            jam_registry.pop(session_id, None)
|
| 117 |
+
|
| 118 |
def _load_finetune_assets_from_hf(repo_id: str | None) -> tuple[bool, str]:
|
| 119 |
"""
|
| 120 |
Download & load mean_style_embed.npy and cluster_centroids.npy from a HF model repo.
|
|
|
|
| 966 |
"embedding_dim": d,
|
| 967 |
}
|
| 968 |
|
| 969 |
+
@app.get("/model/config")
def model_config():
    """
    Report the currently selected model configuration.

    Returns a dict with the size / repo / revision / step read from the MRT_*
    environment variables, the checkpoint directory resolved by
    `_resolve_checkpoint_dir()`, and whether a model instance is currently
    cached in `_MRT`.

    This endpoint is read-only: it deliberately does NOT call `get_mrt()`,
    because that builds the model on demand (download + init). Calling it from
    a status route would make "loaded" meaningless and would immediately undo a
    lazy swap performed by /model/select with prewarm=false.
    """
    # Plain read of the module-level cache; no lock is taken so a status request
    # is not made to wait on _MRT_LOCK.
    loaded = _MRT is not None
    return {
        "size": os.getenv("MRT_SIZE", "large"),
        "repo": os.getenv("MRT_CKPT_REPO"),
        "revision": os.getenv("MRT_CKPT_REV", "main"),
        "selected_step": os.getenv("MRT_CKPT_STEP"),
        "resolved_ckpt_dir": _resolve_checkpoint_dir(),  # may be None if not yet downloaded
        "loaded": loaded,
    }
|
| 984 |
+
|
| 985 |
+
@app.get("/model/checkpoints")
def model_checkpoints(repo_id: str, revision: str = "main"):
    """
    List the checkpoint steps found in `repo_id` at `revision`.

    The response echoes the repo and revision, carries the steps returned by
    `_list_ckpt_steps`, and reports the last of them as "latest" (None when no
    step was found).
    """
    available = _list_ckpt_steps(repo_id, revision)
    newest = available[-1] if available else None
    return {
        "repo": repo_id,
        "revision": revision,
        "steps": available,
        "latest": newest,
    }
|
| 989 |
+
|
| 990 |
+
class ModelSelect(BaseModel):
    # Request body for POST /model/select. Fields left as None fall back to the
    # value currently held in the corresponding MRT_* environment variable.

    # Model size to switch to.
    size: Optional[Literal["base", "large"]] = None
    # Checkpoint repo on the Hugging Face Hub.
    repo_id: Optional[str] = None
    # Repo revision to read checkpoints from.
    revision: Optional[str] = "main"
    # Checkpoint step; the string "latest" is also allowed.
    step: Union[int, str, None] = None
    # Repo holding the finetune assets; by default it follows repo_id.
    assets_repo_id: Optional[str] = None
    # Load mean/centroids from the assets repo as part of the swap.
    sync_assets: bool = True
    # Call get_mrt() right away so the model is built immediately.
    prewarm: bool = False
    # Auto-stop running jams; when False a running jam yields a 409.
    stop_active: bool = True
    # Validate only, don't swap.
    dry_run: bool = False
|
| 1000 |
+
|
| 1001 |
+
def _parse_requested_step(raw) -> int | None:
    """Normalize a step selector: None, "" or "latest" mean "newest" (None); otherwise an int.

    Raises ValueError / TypeError when the selector is not an integer.
    """
    if raw is None:
        return None
    if isinstance(raw, str):
        text = raw.strip()
        if not text or text.lower() == "latest":
            return None
        return int(text)
    return int(raw)


def _restore_env(snapshot: dict) -> None:
    """Put environment variables back to a previously captured state (None means unset)."""
    for key, value in snapshot.items():
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value


@app.post("/model/select")
def model_select(req: ModelSelect):
    """
    Validate and (unless dry_run) switch the active model checkpoint.

    Flow:
      1. Resolve the target config from the request, falling back to the
         current MRT_* environment variables.
      2. List the checkpoint steps in the target repo and pick the requested
         step (or the newest one for "latest" / no step).
      3. Optionally probe the assets repo for the finetune asset files
         (listing only, nothing is downloaded).
      4. On dry_run, return the preview without touching any state.
      5. Otherwise enforce the jam policy, write the new MRT_* variables, drop
         the cached model so it is rebuilt, optionally load assets and pre-warm.
         Any failure restores the previous environment.

    Returns:
        On success: {"ok": True, ...preview, "assets_loaded": ...} where
        "assets_loaded" is None when sync_assets is false, else the
        {"ok", "message"} result of the asset load. With dry_run:
        {"ok": True, "dry_run": True, ...preview}.
        When validation against the repo fails: {"ok": False, "error": ...,
        "discovered_steps": [...]} (returned, not raised).

    Raises:
        HTTPException 400: no repo is known, or the step is not an integer / "latest".
        HTTPException 409: a jam is running and stop_active is false.
        HTTPException 500: the swap failed; the previous environment was restored.
    """
    global _MRT

    # Resolve desired target config (fall back to current env)
    cur = {
        "size": os.getenv("MRT_SIZE", "large"),
        "repo": os.getenv("MRT_CKPT_REPO"),
        "rev": os.getenv("MRT_CKPT_REV", "main"),
        "step": os.getenv("MRT_CKPT_STEP"),
        "assets": os.getenv("MRT_ASSETS_REPO", _FINETUNE_REPO_DEFAULT),
    }
    tgt = {
        "size": req.size or cur["size"],
        "repo": req.repo_id or cur["repo"],
        "rev": (req.revision if req.revision is not None else cur["rev"]),
        # Raw selector; parsed (and validated) once the repo listing is known.
        "step": (req.step if req.step is not None else cur["step"]),
        "assets": (req.assets_repo_id or req.repo_id or cur["assets"]),
    }

    if not tgt["repo"]:
        raise HTTPException(status_code=400, detail="repo_id is required at least once before selecting 'latest'.")

    # ---- Dry-run validation (no env changes) ----
    # 1) enumerate steps
    steps = _list_ckpt_steps(tgt["repo"], tgt["rev"])
    if not steps:
        return {"ok": False, "error": f"No checkpoint files found in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}

    # 2) choose step; a malformed selector is a client error, not a server crash
    try:
        wanted_step = _parse_requested_step(tgt["step"])
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=400,
            detail=f"step must be an integer or 'latest', got {tgt['step']!r}",
        ) from None
    chosen_step = wanted_step if wanted_step is not None else steps[-1]
    if chosen_step not in steps:
        return {"ok": False, "error": f"checkpoint_{chosen_step} not present in {tgt['repo']}@{tgt['rev']}", "discovered_steps": steps}

    # 3) optional: quick asset sanity (only list, don't download)
    assets_ok = True
    assets_msg = "skipped"
    if req.sync_assets:
        try:
            # a quick probe: ensure either file exists; if not, allow anyway (assets are optional)
            api = HfApi()
            files = set(api.list_repo_files(repo_id=tgt["assets"], repo_type="model"))
            if ("mean_style_embed.npy" not in files) and ("cluster_centroids.npy" not in files):
                assets_ok, assets_msg = False, f"No finetune asset files in {tgt['assets']}"
            else:
                assets_msg = "found"
        except Exception as e:
            assets_ok, assets_msg = False, f"probe failed: {e}"

    preview = {
        "target_size": tgt["size"],
        "target_repo": tgt["repo"],
        "target_revision": tgt["rev"],
        "target_step": chosen_step,
        "assets_repo": tgt["assets"] if req.sync_assets else None,
        "assets_probe": {"ok": assets_ok, "message": assets_msg},
        "active_jam": _any_jam_running(),
    }
    if req.dry_run:
        return {"ok": True, "dry_run": True, **preview}

    # ---- Enforce jam policy ----
    if _any_jam_running():
        if req.stop_active:
            _stop_all_jams()
        else:
            raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")

    # ---- Swap with rollback ----
    old_env = {
        key: os.getenv(key)
        for key in ("MRT_SIZE", "MRT_CKPT_REPO", "MRT_CKPT_REV", "MRT_CKPT_STEP", "MRT_ASSETS_REPO")
    }
    try:
        os.environ["MRT_SIZE"] = str(tgt["size"])
        os.environ["MRT_CKPT_REPO"] = str(tgt["repo"])
        os.environ["MRT_CKPT_REV"] = str(tgt["rev"])
        os.environ["MRT_CKPT_STEP"] = str(chosen_step)

        if req.sync_assets:
            os.environ["MRT_ASSETS_REPO"] = str(tgt["assets"])

        # force rebuild
        with _MRT_LOCK:
            _MRT = None

        # optionally sync+load finetune assets; assets are optional, so a failed
        # load is reported in the response instead of aborting the swap
        assets_loaded = None
        if req.sync_assets:
            load_ok, load_msg = _load_finetune_assets_from_hf(os.getenv("MRT_ASSETS_REPO"))
            assets_loaded = {"ok": load_ok, "message": load_msg}

        # optional pre-warm to amortize JIT
        if req.prewarm:
            get_mrt()  # triggers snapshot_download/resolve + init

        return {"ok": True, **preview, "assets_loaded": assets_loaded}
    except Exception as e:
        # rollback on error
        _restore_env(old_env)
        with _MRT_LOCK:
            _MRT = None
        try:
            get_mrt()
        except Exception:
            # Best effort: the previous model could not be rebuilt either; keep
            # reporting the original failure but leave a trace in the logs.
            logging.getLogger(__name__).exception("Rebuilding the previous model after a failed swap also failed")
        raise HTTPException(status_code=500, detail=f"Swap failed: {e}") from e
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
@app.post("/generate")
|
| 1115 |
def generate(
|
| 1116 |
loop_audio: UploadFile = File(...),
|