Commit
·
c1e9a88
1
Parent(s):
5139a47
let's see if base model runs RT on L4
Browse files
app.py
CHANGED
|
@@ -15,7 +15,10 @@ os.environ.setdefault("JAX_CACHE_DIR", "/home/appuser/.cache/jax")
|
|
| 15 |
import jax
|
| 16 |
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
| 17 |
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# Initialize the on-disk compilation cache (best-effort)
|
| 21 |
try:
|
|
@@ -447,7 +450,7 @@ def get_mrt():
|
|
| 447 |
if _MRT is None:
|
| 448 |
with _MRT_LOCK:
|
| 449 |
if _MRT is None:
|
| 450 |
-
_MRT = system.MagentaRT(tag="large", guidance_weight=5.0, device="gpu", lazy=False)  [NOTE(review): removed line was truncated in page extraction after `tag="`; arguments reconstructed from the new-side line and the commit message switching to the base model — confirm against parent commit 5139a47]
|
| 451 |
return _MRT
|
| 452 |
|
| 453 |
_WARMED = False
|
|
|
|
| 15 |
import jax
|
| 16 |
# ✅ Valid choices include: "default", "high", "highest", "tensorfloat32", "float32", etc.
|
| 17 |
# TF32 is the sweet spot on Ampere/Ada GPUs for ~1.1–1.3× matmul speedups.
|
| 18 |
+
try:
|
| 19 |
+
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
|
| 20 |
+
except Exception:
|
| 21 |
+
jax.config.update("jax_default_matmul_precision", "high") # older alias
|
| 22 |
|
| 23 |
# Initialize the on-disk compilation cache (best-effort)
|
| 24 |
try:
|
|
|
|
| 450 |
if _MRT is None:
|
| 451 |
with _MRT_LOCK:
|
| 452 |
if _MRT is None:
|
| 453 |
+
_MRT = system.MagentaRT(tag="base", guidance_weight=5.0, device="gpu", lazy=False)
|
| 454 |
return _MRT
|
| 455 |
|
| 456 |
_WARMED = False
|