Commit
·
85e8363
1
Parent(s):
0577e3b
new SPACE_MODE env variable so the template can stay up
Browse files
app.py
CHANGED
|
@@ -1,13 +1,21 @@
|
|
| 1 |
import os
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 13 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
@@ -32,7 +40,7 @@ except Exception:
|
|
| 32 |
|
| 33 |
from magenta_rt import system, audio as au
|
| 34 |
import numpy as np
|
| 35 |
-
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query
|
| 36 |
import tempfile, io, base64, math, threading
|
| 37 |
from fastapi.middleware.cors import CORSMiddleware
|
| 38 |
from contextlib import contextmanager
|
|
@@ -76,6 +84,35 @@ from pydantic import BaseModel
|
|
| 76 |
|
| 77 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 80 |
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
| 81 |
_ASSETS_REPO_ID: str | None = None
|
|
@@ -1108,7 +1145,44 @@ def jam_status(session_id: str):
|
|
| 1108 |
|
| 1109 |
@app.get("/health")
|
| 1110 |
def health():
|
| 1111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1112 |
|
| 1113 |
@app.middleware("http")
|
| 1114 |
async def log_requests(request: Request, call_next):
|
|
|
|
| 1 |
import os
|
| 2 |
+
|
| 3 |
+
# ---- Space mode gating (place above any JAX import!) ----
|
| 4 |
+
SPACE_MODE = os.getenv("SPACE_MODE", "serve") # "serve" | "template"
|
| 5 |
+
|
| 6 |
+
if SPACE_MODE != "serve":
|
| 7 |
+
# In template mode, force JAX to CPU so it won't try to load CUDA plugins
|
| 8 |
+
os.environ.setdefault("JAX_PLATFORMS", "cpu")
|
| 9 |
+
else:
|
| 10 |
+
# Only set GPU-friendly XLA flags when we actually intend to serve on GPU
|
| 11 |
+
os.environ.setdefault(
|
| 12 |
+
"XLA_FLAGS",
|
| 13 |
+
" ".join([
|
| 14 |
+
"--xla_gpu_enable_triton_gemm=true",
|
| 15 |
+
"--xla_gpu_enable_latency_hiding_scheduler=true",
|
| 16 |
+
"--xla_gpu_autotune_level=2",
|
| 17 |
+
])
|
| 18 |
+
)
|
| 19 |
|
| 20 |
# Optional: persist JAX compile cache across restarts (reduces warmup time)
|
| 21 |
os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
|
|
| 40 |
|
| 41 |
from magenta_rt import system, audio as au
|
| 42 |
import numpy as np
|
| 43 |
+
from fastapi import FastAPI, UploadFile, File, Form, Body, HTTPException, Response, Request, WebSocket, WebSocketDisconnect, Query, JSONResponse
|
| 44 |
import tempfile, io, base64, math, threading
|
| 45 |
from fastapi.middleware.cors import CORSMiddleware
|
| 46 |
from contextlib import contextmanager
|
|
|
|
| 84 |
|
| 85 |
from model_management import CheckpointManager, AssetManager, ModelSelector, ModelSelect
|
| 86 |
|
| 87 |
+
def _gpu_probe() -> dict:
|
| 88 |
+
"""
|
| 89 |
+
Returns:
|
| 90 |
+
{
|
| 91 |
+
"ok": bool,
|
| 92 |
+
"backend": str | None, # "gpu" | "cpu" | "tpu" | None
|
| 93 |
+
"has_gpu": bool,
|
| 94 |
+
"devices": list[str], # e.g. ["gpu:0", "gpu:1"]
|
| 95 |
+
"error": str | None,
|
| 96 |
+
}
|
| 97 |
+
"""
|
| 98 |
+
try:
|
| 99 |
+
import jax
|
| 100 |
+
try:
|
| 101 |
+
backend = jax.default_backend() # "gpu", "cpu", "tpu"
|
| 102 |
+
except Exception:
|
| 103 |
+
from jax.lib import xla_bridge
|
| 104 |
+
backend = getattr(xla_bridge.get_backend(), "platform", None)
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
devices = jax.devices()
|
| 108 |
+
has_gpu = any(getattr(d, "platform", "") in ("gpu", "cuda", "rocm") for d in devices)
|
| 109 |
+
dev_list = [f"{getattr(d, 'platform', '?')}:{getattr(d, 'id', '?')}" for d in devices]
|
| 110 |
+
return {"ok": True, "backend": backend, "has_gpu": has_gpu, "devices": dev_list, "error": None}
|
| 111 |
+
except Exception as e:
|
| 112 |
+
return {"ok": False, "backend": backend, "has_gpu": False, "devices": [], "error": f"jax.devices failed: {e}"}
|
| 113 |
+
except Exception as e:
|
| 114 |
+
return {"ok": False, "backend": None, "has_gpu": False, "devices": [], "error": f"jax import failed: {e}"}
|
| 115 |
+
|
| 116 |
# ---- Finetune assets (mean & centroids) --------------------------------------
|
| 117 |
# _FINETUNE_REPO_DEFAULT = os.getenv("MRT_ASSETS_REPO", "thepatch/magenta-ft")
|
| 118 |
_ASSETS_REPO_ID: str | None = None
|
|
|
|
| 1145 |
|
| 1146 |
@app.get("/health")
|
| 1147 |
def health():
|
| 1148 |
+
# 1) Template mode → not ready (encourage duplication on GPU)
|
| 1149 |
+
if SPACE_MODE != "serve":
|
| 1150 |
+
return JSONResponse(
|
| 1151 |
+
status_code=503,
|
| 1152 |
+
content={
|
| 1153 |
+
"ok": False,
|
| 1154 |
+
"status": "template_mode",
|
| 1155 |
+
"message": "This Space is a GPU template. Duplicate it and select an L40s/A100-class runtime to use the API.",
|
| 1156 |
+
"mode": SPACE_MODE,
|
| 1157 |
+
},
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
# 2) Runtime hardware probe
|
| 1161 |
+
probe = _gpu_probe()
|
| 1162 |
+
if not probe["ok"] or not probe["has_gpu"] or probe.get("backend") != "gpu":
|
| 1163 |
+
return JSONResponse(
|
| 1164 |
+
status_code=503,
|
| 1165 |
+
content={
|
| 1166 |
+
"ok": False,
|
| 1167 |
+
"status": "gpu_unavailable",
|
| 1168 |
+
"message": "GPU is not visible to JAX. Select a GPU runtime (e.g., L40s) to serve.",
|
| 1169 |
+
"probe": probe,
|
| 1170 |
+
"mode": SPACE_MODE,
|
| 1171 |
+
},
|
| 1172 |
+
)
|
| 1173 |
+
|
| 1174 |
+
# 3) Ready; include operational hints
|
| 1175 |
+
warmed = bool(_WARMED)
|
| 1176 |
+
with jam_lock:
|
| 1177 |
+
active_jams = sum(1 for w in jam_registry.values() if w.is_alive())
|
| 1178 |
+
return {
|
| 1179 |
+
"ok": True,
|
| 1180 |
+
"status": "ready" if warmed else "initializing",
|
| 1181 |
+
"mode": SPACE_MODE,
|
| 1182 |
+
"warmed": warmed,
|
| 1183 |
+
"active_jams": active_jams,
|
| 1184 |
+
"probe": probe,
|
| 1185 |
+
}
|
| 1186 |
|
| 1187 |
@app.middleware("http")
|
| 1188 |
async def log_requests(request: Request, call_next):
|