Julian Bilcke
committed on
Commit
·
743eda6
1
Parent(s):
446e79f
improve UI persistence
Browse files- app.py +21 -8
- vms/training_service.py +14 -3
app.py
CHANGED
|
@@ -144,7 +144,9 @@ class VideoTrainerUI:
|
|
| 144 |
"""Load UI state values for initializing form fields"""
|
| 145 |
ui_state = self.trainer.load_ui_state()
|
| 146 |
|
| 147 |
-
#
|
|
|
|
|
|
|
| 148 |
ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
|
| 149 |
ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
|
| 150 |
ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
|
|
@@ -866,9 +868,12 @@ class VideoTrainerUI:
|
|
| 866 |
)
|
| 867 |
|
| 868 |
def update_training_params(self, preset_name: str) -> Tuple:
|
| 869 |
-
"""Update UI components based on selected preset"""
|
| 870 |
preset = TRAINING_PRESETS[preset_name]
|
| 871 |
|
|
|
|
|
|
|
|
|
|
| 872 |
# Find the display name that maps to our model type
|
| 873 |
model_display_name = next(
|
| 874 |
key for key, value in MODEL_TYPES.items()
|
|
@@ -888,14 +893,22 @@ class VideoTrainerUI:
|
|
| 888 |
info_text = f"{description}{bucket_info}"
|
| 889 |
|
| 890 |
# Return values in the same order as the output components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 891 |
return (
|
| 892 |
model_display_name,
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
info_text
|
| 900 |
)
|
| 901 |
|
|
|
|
| 144 |
"""Load UI state values for initializing form fields"""
|
| 145 |
ui_state = self.trainer.load_ui_state()
|
| 146 |
|
| 147 |
+
# Ensure proper type conversion for numeric values
|
| 148 |
+
ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
|
| 149 |
+
ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
|
| 150 |
ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
|
| 151 |
ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
|
| 152 |
ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
|
|
|
|
| 868 |
)
|
| 869 |
|
| 870 |
def update_training_params(self, preset_name: str) -> Tuple:
|
| 871 |
+
"""Update UI components based on selected preset while preserving custom settings"""
|
| 872 |
preset = TRAINING_PRESETS[preset_name]
|
| 873 |
|
| 874 |
+
# Load current UI state to check if user has customized values
|
| 875 |
+
current_state = self.load_ui_values()
|
| 876 |
+
|
| 877 |
# Find the display name that maps to our model type
|
| 878 |
model_display_name = next(
|
| 879 |
key for key, value in MODEL_TYPES.items()
|
|
|
|
| 893 |
info_text = f"{description}{bucket_info}"
|
| 894 |
|
| 895 |
# Return values in the same order as the output components
|
| 896 |
+
# Use preset defaults but preserve user-modified values if they exist
|
| 897 |
+
lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
|
| 898 |
+
lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
|
| 899 |
+
num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
|
| 900 |
+
batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
|
| 901 |
+
learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
|
| 902 |
+
save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
|
| 903 |
+
|
| 904 |
return (
|
| 905 |
model_display_name,
|
| 906 |
+
lora_rank_val,
|
| 907 |
+
lora_alpha_val,
|
| 908 |
+
num_epochs_val,
|
| 909 |
+
batch_size_val,
|
| 910 |
+
learning_rate_val,
|
| 911 |
+
save_iterations_val,
|
| 912 |
info_text
|
| 913 |
)
|
| 914 |
|
vms/training_service.py
CHANGED
|
@@ -114,19 +114,30 @@ class TrainingService:
|
|
| 114 |
"model_type": list(MODEL_TYPES.keys())[0],
|
| 115 |
"lora_rank": "128",
|
| 116 |
"lora_alpha": "128",
|
| 117 |
-
"num_epochs":
|
| 118 |
"batch_size": 1,
|
| 119 |
"learning_rate": 3e-5,
|
| 120 |
-
"save_iterations":
|
| 121 |
"training_preset": list(TRAINING_PRESETS.keys())[0]
|
| 122 |
}
|
| 123 |
|
| 124 |
if not ui_state_file.exists():
|
| 125 |
return default_state
|
| 126 |
-
|
| 127 |
try:
|
| 128 |
with open(ui_state_file, 'r') as f:
|
| 129 |
saved_state = json.load(f)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
# Make sure we have all keys (in case structure changed)
|
| 131 |
merged_state = default_state.copy()
|
| 132 |
merged_state.update(saved_state)
|
|
|
|
| 114 |
"model_type": list(MODEL_TYPES.keys())[0],
|
| 115 |
"lora_rank": "128",
|
| 116 |
"lora_alpha": "128",
|
| 117 |
+
"num_epochs": 50,
|
| 118 |
"batch_size": 1,
|
| 119 |
"learning_rate": 3e-5,
|
| 120 |
+
"save_iterations": 200,
|
| 121 |
"training_preset": list(TRAINING_PRESETS.keys())[0]
|
| 122 |
}
|
| 123 |
|
| 124 |
if not ui_state_file.exists():
|
| 125 |
return default_state
|
| 126 |
+
|
| 127 |
try:
|
| 128 |
with open(ui_state_file, 'r') as f:
|
| 129 |
saved_state = json.load(f)
|
| 130 |
+
|
| 131 |
+
# Convert numeric values to appropriate types
|
| 132 |
+
if "num_epochs" in saved_state:
|
| 133 |
+
saved_state["num_epochs"] = int(saved_state["num_epochs"])
|
| 134 |
+
if "batch_size" in saved_state:
|
| 135 |
+
saved_state["batch_size"] = int(saved_state["batch_size"])
|
| 136 |
+
if "learning_rate" in saved_state:
|
| 137 |
+
saved_state["learning_rate"] = float(saved_state["learning_rate"])
|
| 138 |
+
if "save_iterations" in saved_state:
|
| 139 |
+
saved_state["save_iterations"] = int(saved_state["save_iterations"])
|
| 140 |
+
|
| 141 |
# Make sure we have all keys (in case structure changed)
|
| 142 |
merged_state = default_state.copy()
|
| 143 |
merged_state.update(saved_state)
|