Spaces:
Running
on
Zero
Running
on
Zero
Add hyperparameter overriding functionality in _prepare_script of app.py to allow safer and more flexible adjustments of learning rate, network dimension, seed, max training epochs, and save frequency. This enhances configurability for training scripts.
Browse files
app.py
CHANGED
|
@@ -345,6 +345,27 @@ def _prepare_script(
|
|
| 345 |
if override_seed is not None:
|
| 346 |
txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
# Write to a temp file alongside this repo for easier inspection
|
| 349 |
run_dir = TRAINING_DIR / ".gradio_runs"
|
| 350 |
run_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 345 |
if override_seed is not None:
|
| 346 |
txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
|
| 347 |
|
| 348 |
+
# Prefer overriding variable definitions at top of script (safer than CLI regex)
|
| 349 |
+
def _set_var(name: str, value: str) -> None:
|
| 350 |
+
nonlocal txt
|
| 351 |
+
pattern = rf"(?m)^\s*{name}\s*=.*$"
|
| 352 |
+
replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
|
| 353 |
+
if re.search(pattern, txt):
|
| 354 |
+
txt = re.sub(pattern, replacement, txt)
|
| 355 |
+
else:
|
| 356 |
+
txt = f"{replacement}\n" + txt
|
| 357 |
+
|
| 358 |
+
if override_learning_rate:
|
| 359 |
+
_set_var('LEARNING_RATE', override_learning_rate)
|
| 360 |
+
if override_network_dim is not None:
|
| 361 |
+
_set_var('NETWORK_DIM', str(override_network_dim))
|
| 362 |
+
if override_seed is not None:
|
| 363 |
+
_set_var('SEED', str(override_seed))
|
| 364 |
+
if override_max_epochs is not None and override_max_epochs > 0:
|
| 365 |
+
_set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
|
| 366 |
+
if override_save_every is not None and override_save_every > 0:
|
| 367 |
+
_set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
|
| 368 |
+
|
| 369 |
# Write to a temp file alongside this repo for easier inspection
|
| 370 |
run_dir = TRAINING_DIR / ".gradio_runs"
|
| 371 |
run_dir.mkdir(parents=True, exist_ok=True)
|