yeq6x commited on
Commit
0a49e69
·
1 Parent(s): 6fabdaf

Add hyperparameter overriding functionality in _prepare_script of app.py to allow safer and more flexible adjustments of learning rate, network dimension, seed, max training epochs, and save frequency. This enhances configurability for training scripts.

Browse files
Files changed (1) hide show
  1. app.py +21 -0
app.py CHANGED
@@ -345,6 +345,27 @@ def _prepare_script(
345
  if override_seed is not None:
346
  txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  # Write to a temp file alongside this repo for easier inspection
349
  run_dir = TRAINING_DIR / ".gradio_runs"
350
  run_dir.mkdir(parents=True, exist_ok=True)
 
345
  if override_seed is not None:
346
  txt = re.sub(r"--seed\s+\d+", f"--seed {override_seed}", txt)
347
 
348
+ # Prefer overriding variable definitions at top of script (safer than CLI regex)
349
+ def _set_var(name: str, value: str) -> None:
350
+ nonlocal txt
351
+ pattern = rf"(?m)^\s*{name}\s*=.*$"
352
+ replacement = f'{name}="{value}"' if not str(value).isdigit() else f'{name}={value}'
353
+ if re.search(pattern, txt):
354
+ txt = re.sub(pattern, replacement, txt)
355
+ else:
356
+ txt = f"{replacement}\n" + txt
357
+
358
+ if override_learning_rate:
359
+ _set_var('LEARNING_RATE', override_learning_rate)
360
+ if override_network_dim is not None:
361
+ _set_var('NETWORK_DIM', str(override_network_dim))
362
+ if override_seed is not None:
363
+ _set_var('SEED', str(override_seed))
364
+ if override_max_epochs is not None and override_max_epochs > 0:
365
+ _set_var('MAX_TRAIN_EPOCHS', str(override_max_epochs))
366
+ if override_save_every is not None and override_save_every > 0:
367
+ _set_var('SAVE_EVERY_N_EPOCHS', str(override_save_every))
368
+
369
  # Write to a temp file alongside this repo for easier inspection
370
  run_dir = TRAINING_DIR / ".gradio_runs"
371
  run_dir.mkdir(parents=True, exist_ok=True)