Commit
·
f0823f1
1
Parent(s):
5b910df
finetune assets loading fix
Browse files
app.py
CHANGED
|
@@ -134,11 +134,12 @@ _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
|
| 134 |
asset_manager = AssetManager()
|
| 135 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
| 136 |
|
| 137 |
-
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
|
|
|
| 142 |
|
| 143 |
def _any_jam_running() -> bool:
|
| 144 |
with jam_lock:
|
|
@@ -335,15 +336,20 @@ def get_mrt():
|
|
| 335 |
if _MRT is None:
|
| 336 |
with _MRT_LOCK:
|
| 337 |
if _MRT is None:
|
| 338 |
-
|
| 339 |
-
ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # ← Updated call
|
| 340 |
_MRT = system.MagentaRT(
|
| 341 |
tag=os.getenv("MRT_SIZE", "large"),
|
| 342 |
guidance_weight=5.0,
|
| 343 |
device="gpu",
|
| 344 |
checkpoint_dir=ckpt_dir,
|
| 345 |
-
lazy=False
|
| 346 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
return _MRT
|
| 348 |
|
| 349 |
_WARMED = False
|
|
@@ -420,9 +426,18 @@ def _mrt_warmup():
|
|
| 420 |
# startup and model selection
|
| 421 |
# ----------------------------
|
| 422 |
|
| 423 |
-
# Kick it off in the background on server start
|
| 424 |
@app.on_event("startup")
|
| 425 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 427 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 428 |
|
|
@@ -556,6 +571,8 @@ def model_select(req: ModelSelect):
|
|
| 556 |
if "error" in validation_result:
|
| 557 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
| 558 |
return {"ok": False, **validation_result}
|
|
|
|
|
|
|
| 559 |
|
| 560 |
# Augment response surface
|
| 561 |
validation_result["active_jam"] = _any_jam_running()
|
|
@@ -563,6 +580,10 @@ def model_select(req: ModelSelect):
|
|
| 563 |
# Dry-run path
|
| 564 |
if req.dry_run:
|
| 565 |
return {"ok": True, "dry_run": True, **validation_result}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
|
| 567 |
# 2) Handle jam policy
|
| 568 |
if _any_jam_running():
|
|
|
|
| 134 |
asset_manager = AssetManager()
|
| 135 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
| 136 |
|
| 137 |
+
def _sync_assets_globals_from_manager():
|
| 138 |
+
# Copy the asset manager's current state into the module-level globals so /model/config stays in sync with it
|
| 139 |
+
global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
|
| 140 |
+
_MEAN_EMBED = asset_manager.mean_embed
|
| 141 |
+
_CENTROIDS = asset_manager.centroids
|
| 142 |
+
_ASSETS_REPO_ID = asset_manager.assets_repo_id
|
| 143 |
|
| 144 |
def _any_jam_running() -> bool:
|
| 145 |
with jam_lock:
|
|
|
|
| 336 |
if _MRT is None:
|
| 337 |
with _MRT_LOCK:
|
| 338 |
if _MRT is None:
|
| 339 |
+
ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # uses MRT_CKPT_REPO/STEP if present
|
|
|
|
| 340 |
_MRT = system.MagentaRT(
|
| 341 |
tag=os.getenv("MRT_SIZE", "large"),
|
| 342 |
guidance_weight=5.0,
|
| 343 |
device="gpu",
|
| 344 |
checkpoint_dir=ckpt_dir,
|
| 345 |
+
lazy=False
|
| 346 |
)
|
| 347 |
+
# If no assets are loaded yet and a repo is configured, load them now.
|
| 348 |
+
if asset_manager.mean_embed is None and asset_manager.centroids is None:
|
| 349 |
+
repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
|
| 350 |
+
if repo:
|
| 351 |
+
asset_manager.load_finetune_assets_from_hf(repo, None)
|
| 352 |
+
_sync_assets_globals_from_manager()
|
| 353 |
return _MRT
|
| 354 |
|
| 355 |
_WARMED = False
|
|
|
|
| 426 |
# startup and model selection
|
| 427 |
# ----------------------------
|
| 428 |
|
|
|
|
| 429 |
@app.on_event("startup")
|
| 430 |
+
def _boot():
|
| 431 |
+
# 1) Load finetune assets up front (only if MRT_ASSETS_REPO or MRT_CKPT_REPO is set)
|
| 432 |
+
repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
|
| 433 |
+
if repo:
|
| 434 |
+
ok, msg = asset_manager.load_finetune_assets_from_hf(repo, None)
|
| 435 |
+
_sync_assets_globals_from_manager() # keep /model/config in sync
|
| 436 |
+
logging.info("Startup asset load from %s: %s", repo, "ok" if ok else msg)
|
| 437 |
+
else:
|
| 438 |
+
logging.info("Startup asset load: no repo env set; skipping.")
|
| 439 |
+
|
| 440 |
+
# 2) Start warmup in the background (unchanged behavior)
|
| 441 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
| 442 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
| 443 |
|
|
|
|
| 571 |
if "error" in validation_result:
|
| 572 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
| 573 |
return {"ok": False, **validation_result}
|
| 574 |
+
|
| 575 |
+
|
| 576 |
|
| 577 |
# Augment response surface
|
| 578 |
validation_result["active_jam"] = _any_jam_running()
|
|
|
|
| 580 |
# Dry-run path
|
| 581 |
if req.dry_run:
|
| 582 |
return {"ok": True, "dry_run": True, **validation_result}
|
| 583 |
+
|
| 584 |
+
if req.ckpt_step == "none": # user asked for stock base
|
| 585 |
+
asset_manager.clear_assets() # TODO: AssetManager.clear_assets() must be implemented to reset mean_embed and centroids to None
|
| 586 |
+
_sync_assets_globals_from_manager()
|
| 587 |
|
| 588 |
# 2) Handle jam policy
|
| 589 |
if _any_jam_running():
|