Julian Bilcke
committed on
Commit
·
2264c6e
1
Parent(s):
7bce2a2
tentative fix
Browse files
- vms/ui/app_ui.py +2 -3
- vms/ui/project/services/training.py +3 -6
- vms/ui/project/tabs/manage_tab.py +3 -47
- vms/ui/project/tabs/train_tab.py +13 -49
vms/ui/app_ui.py
CHANGED
|
@@ -399,12 +399,11 @@ class AppUI:
|
|
| 399 |
outputs=[
|
| 400 |
self.project_tabs["train_tab"].components["status_box"],
|
| 401 |
self.project_tabs["train_tab"].components["log_box"],
|
| 402 |
-
self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None
|
| 403 |
-
self.project_tabs["manage_tab"].components["download_model_btn"],
|
| 404 |
-
self.project_tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 405 |
]
|
| 406 |
)
|
| 407 |
|
|
|
|
| 408 |
# Button update timer for button components (every 1 second)
|
| 409 |
button_timer = gr.Timer(value=1)
|
| 410 |
button_outputs = [
|
|
|
|
| 399 |
outputs=[
|
| 400 |
self.project_tabs["train_tab"].components["status_box"],
|
| 401 |
self.project_tabs["train_tab"].components["log_box"],
|
| 402 |
+
self.project_tabs["train_tab"].components["current_task_box"] if "current_task_box" in self.project_tabs["train_tab"].components else None
|
|
|
|
|
|
|
| 403 |
]
|
| 404 |
)
|
| 405 |
|
| 406 |
+
|
| 407 |
# Button update timer for button components (every 1 second)
|
| 408 |
button_timer = gr.Timer(value=1)
|
| 409 |
button_outputs = [
|
vms/ui/project/services/training.py
CHANGED
|
@@ -1823,12 +1823,9 @@ class TrainingService:
|
|
| 1823 |
try:
|
| 1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 1825 |
if not checkpoints:
|
| 1826 |
-
return "
|
| 1827 |
|
| 1828 |
-
|
| 1829 |
-
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
| 1830 |
-
step_num = int(latest_checkpoint.name.split("_")[-1])
|
| 1831 |
-
return f"📥 Download checkpoints (step {step_num})"
|
| 1832 |
except Exception as e:
|
| 1833 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
| 1834 |
-
return "
|
|
|
|
| 1823 |
try:
|
| 1824 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 1825 |
if not checkpoints:
|
| 1826 |
+
return "No checkpoints available"
|
| 1827 |
|
| 1828 |
+
return f"💽 Download checkpoints"
|
|
|
|
|
|
|
|
|
|
| 1829 |
except Exception as e:
|
| 1830 |
logger.warning(f"Error getting checkpoint info for button text: {e}")
|
| 1831 |
+
return "No checkpoints available"
|
vms/ui/project/tabs/manage_tab.py
CHANGED
|
@@ -25,50 +25,6 @@ class ManageTab(BaseTab):
|
|
| 25 |
self.id = "manage_tab"
|
| 26 |
self.title = "5️⃣ Storage"
|
| 27 |
|
| 28 |
-
def get_download_button_text(self) -> str:
|
| 29 |
-
"""Get the dynamic text for the download button based on current model state"""
|
| 30 |
-
try:
|
| 31 |
-
model_info = self.app.training.get_model_output_info()
|
| 32 |
-
if model_info["path"] and model_info["steps"]:
|
| 33 |
-
return f"🧠 Download weights ({model_info['steps']} steps)"
|
| 34 |
-
elif model_info["path"]:
|
| 35 |
-
return "🧠 Download weights (.safetensors)"
|
| 36 |
-
else:
|
| 37 |
-
return "🧠 Download weights (not available)"
|
| 38 |
-
except Exception as e:
|
| 39 |
-
logger.warning(f"Error getting model info for button text: {e}")
|
| 40 |
-
return "🧠 Download weights (.safetensors)"
|
| 41 |
-
|
| 42 |
-
def get_checkpoint_button_text(self) -> str:
|
| 43 |
-
"""Get the dynamic text for the download checkpoint button"""
|
| 44 |
-
try:
|
| 45 |
-
return self.app.training.get_checkpoint_button_text()
|
| 46 |
-
except Exception as e:
|
| 47 |
-
logger.warning(f"Error getting checkpoint button text: {e}")
|
| 48 |
-
return "📥 Download checkpoints (not available)"
|
| 49 |
-
|
| 50 |
-
def update_download_button_text(self) -> gr.update:
|
| 51 |
-
"""Update the download button text"""
|
| 52 |
-
return gr.update(value=self.get_download_button_text())
|
| 53 |
-
|
| 54 |
-
def update_checkpoint_button_text(self) -> gr.update:
|
| 55 |
-
"""Update the checkpoint button text"""
|
| 56 |
-
return gr.update(value=self.get_checkpoint_button_text())
|
| 57 |
-
|
| 58 |
-
def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
|
| 59 |
-
"""Update both download button texts"""
|
| 60 |
-
return (
|
| 61 |
-
gr.update(value=self.get_download_button_text()),
|
| 62 |
-
gr.update(value=self.get_checkpoint_button_text())
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
def download_and_update_button(self):
|
| 66 |
-
"""Handle download and return updated button with current text"""
|
| 67 |
-
# Get the safetensors path for download
|
| 68 |
-
path = self.app.training.get_model_output_safetensors()
|
| 69 |
-
# For DownloadButton, we need to return the file path directly for download
|
| 70 |
-
# The button text will be updated on next render
|
| 71 |
-
return path
|
| 72 |
|
| 73 |
def create(self, parent=None) -> gr.TabItem:
|
| 74 |
"""Create the Manage tab UI components"""
|
|
@@ -90,19 +46,19 @@ class ManageTab(BaseTab):
|
|
| 90 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
| 91 |
|
| 92 |
self.components["download_model_btn"] = gr.DownloadButton(
|
| 93 |
-
|
| 94 |
variant="secondary",
|
| 95 |
size="lg"
|
| 96 |
)
|
| 97 |
|
| 98 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
| 99 |
-
|
| 100 |
variant="secondary",
|
| 101 |
size="lg"
|
| 102 |
)
|
| 103 |
|
| 104 |
self.components["download_output_btn"] = gr.DownloadButton(
|
| 105 |
-
"📁 Download output
|
| 106 |
variant="secondary",
|
| 107 |
size="lg",
|
| 108 |
visible=False
|
|
|
|
| 25 |
self.id = "manage_tab"
|
| 26 |
self.title = "5️⃣ Storage"
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
def create(self, parent=None) -> gr.TabItem:
|
| 30 |
"""Create the Manage tab UI components"""
|
|
|
|
| 46 |
gr.Markdown("📦 Training dataset download disabled for large datasets")
|
| 47 |
|
| 48 |
self.components["download_model_btn"] = gr.DownloadButton(
|
| 49 |
+
"🧠 Download LoRA weights",
|
| 50 |
variant="secondary",
|
| 51 |
size="lg"
|
| 52 |
)
|
| 53 |
|
| 54 |
self.components["download_checkpoint_btn"] = gr.DownloadButton(
|
| 55 |
+
"💽 Download Checkpoints",
|
| 56 |
variant="secondary",
|
| 57 |
size="lg"
|
| 58 |
)
|
| 59 |
|
| 60 |
self.components["download_output_btn"] = gr.DownloadButton(
|
| 61 |
+
"📁 Download output/ (.zip)",
|
| 62 |
variant="secondary",
|
| 63 |
size="lg",
|
| 64 |
visible=False
|
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -494,12 +494,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 494 |
save_iterations, repo_id, progress
|
| 495 |
)
|
| 496 |
|
| 497 |
-
|
| 498 |
-
manage_tab = self.app.tabs["manage_tab"]
|
| 499 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 500 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 501 |
-
|
| 502 |
-
return status, logs, download_btn_text, checkpoint_btn_text
|
| 503 |
|
| 504 |
def handle_resume_training(
|
| 505 |
self, model_type, model_version, training_type,
|
|
@@ -511,10 +506,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 511 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 512 |
|
| 513 |
if not checkpoints:
|
| 514 |
-
|
| 515 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 516 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 517 |
-
return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
|
| 518 |
|
| 519 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 520 |
|
|
@@ -526,12 +518,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 526 |
resume_from_checkpoint="latest"
|
| 527 |
)
|
| 528 |
|
| 529 |
-
|
| 530 |
-
manage_tab = self.app.tabs["manage_tab"]
|
| 531 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 532 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 533 |
-
|
| 534 |
-
return status, logs, download_btn_text, checkpoint_btn_text
|
| 535 |
|
| 536 |
def handle_start_from_lora_training(
|
| 537 |
self, model_type, model_version, training_type,
|
|
@@ -542,26 +529,22 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 542 |
# Find the latest LoRA weights
|
| 543 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 544 |
|
| 545 |
-
manage_tab = self.app.tabs["manage_tab"]
|
| 546 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 547 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 548 |
-
|
| 549 |
if not lora_weights_path.exists():
|
| 550 |
-
return "No LoRA weights found", "Please train a model first or start a new training session"
|
| 551 |
|
| 552 |
# Find the latest LoRA checkpoint directory
|
| 553 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 554 |
key=lambda x: int(x.name), reverse=True)
|
| 555 |
|
| 556 |
if not lora_dirs:
|
| 557 |
-
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
| 558 |
|
| 559 |
latest_lora_dir = lora_dirs[0]
|
| 560 |
|
| 561 |
# Verify the LoRA weights file exists
|
| 562 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 563 |
if not lora_weights_file.exists():
|
| 564 |
-
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
| 565 |
|
| 566 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 567 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
@@ -582,11 +565,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 582 |
save_iterations, repo_id, progress,
|
| 583 |
)
|
| 584 |
|
| 585 |
-
|
| 586 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 587 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 588 |
-
|
| 589 |
-
return status, logs, download_btn_text, checkpoint_btn_text
|
| 590 |
|
| 591 |
def connect_events(self) -> None:
|
| 592 |
"""Connect event handlers to UI components"""
|
|
@@ -769,9 +748,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 769 |
],
|
| 770 |
outputs=[
|
| 771 |
self.components["status_box"],
|
| 772 |
-
self.components["log_box"]
|
| 773 |
-
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 774 |
-
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 775 |
]
|
| 776 |
)
|
| 777 |
|
|
@@ -791,9 +768,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 791 |
],
|
| 792 |
outputs=[
|
| 793 |
self.components["status_box"],
|
| 794 |
-
self.components["log_box"]
|
| 795 |
-
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 796 |
-
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 797 |
]
|
| 798 |
)
|
| 799 |
|
|
@@ -813,9 +788,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 813 |
],
|
| 814 |
outputs=[
|
| 815 |
self.components["status_box"],
|
| 816 |
-
self.components["log_box"]
|
| 817 |
-
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 818 |
-
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 819 |
]
|
| 820 |
)
|
| 821 |
|
|
@@ -831,9 +804,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 831 |
self.components["current_task_box"],
|
| 832 |
self.components["start_btn"],
|
| 833 |
self.components["stop_btn"],
|
| 834 |
-
third_btn
|
| 835 |
-
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 836 |
-
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 837 |
]
|
| 838 |
)
|
| 839 |
|
|
@@ -845,9 +816,7 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 845 |
self.components["current_task_box"],
|
| 846 |
self.components["start_btn"],
|
| 847 |
self.components["stop_btn"],
|
| 848 |
-
third_btn
|
| 849 |
-
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 850 |
-
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 851 |
]
|
| 852 |
)
|
| 853 |
|
|
@@ -1201,12 +1170,7 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
|
|
| 1201 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
| 1202 |
current_task = self.app.log_parser.get_current_task_display()
|
| 1203 |
|
| 1204 |
-
|
| 1205 |
-
manage_tab = self.app.tabs["manage_tab"]
|
| 1206 |
-
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 1207 |
-
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 1208 |
-
|
| 1209 |
-
return message, logs, current_task, download_btn_text, checkpoint_btn_text
|
| 1210 |
|
| 1211 |
def get_button_updates(self):
|
| 1212 |
"""Get button updates (with variant property)"""
|
|
|
|
| 494 |
save_iterations, repo_id, progress
|
| 495 |
)
|
| 496 |
|
| 497 |
+
return status, logs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
def handle_resume_training(
|
| 500 |
self, model_type, model_version, training_type,
|
|
|
|
| 506 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 507 |
|
| 508 |
if not checkpoints:
|
| 509 |
+
return "No checkpoints found to resume from", "Please start a new training session instead"
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 512 |
|
|
|
|
| 518 |
resume_from_checkpoint="latest"
|
| 519 |
)
|
| 520 |
|
| 521 |
+
return status, logs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
def handle_start_from_lora_training(
|
| 524 |
self, model_type, model_version, training_type,
|
|
|
|
| 529 |
# Find the latest LoRA weights
|
| 530 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
if not lora_weights_path.exists():
|
| 533 |
+
return "No LoRA weights found", "Please train a model first or start a new training session"
|
| 534 |
|
| 535 |
# Find the latest LoRA checkpoint directory
|
| 536 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 537 |
key=lambda x: int(x.name), reverse=True)
|
| 538 |
|
| 539 |
if not lora_dirs:
|
| 540 |
+
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
| 541 |
|
| 542 |
latest_lora_dir = lora_dirs[0]
|
| 543 |
|
| 544 |
# Verify the LoRA weights file exists
|
| 545 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 546 |
if not lora_weights_file.exists():
|
| 547 |
+
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
| 548 |
|
| 549 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 550 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
|
|
| 565 |
save_iterations, repo_id, progress,
|
| 566 |
)
|
| 567 |
|
| 568 |
+
return status, logs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
|
| 570 |
def connect_events(self) -> None:
|
| 571 |
"""Connect event handlers to UI components"""
|
|
|
|
| 748 |
],
|
| 749 |
outputs=[
|
| 750 |
self.components["status_box"],
|
| 751 |
+
self.components["log_box"]
|
|
|
|
|
|
|
| 752 |
]
|
| 753 |
)
|
| 754 |
|
|
|
|
| 768 |
],
|
| 769 |
outputs=[
|
| 770 |
self.components["status_box"],
|
| 771 |
+
self.components["log_box"]
|
|
|
|
|
|
|
| 772 |
]
|
| 773 |
)
|
| 774 |
|
|
|
|
| 788 |
],
|
| 789 |
outputs=[
|
| 790 |
self.components["status_box"],
|
| 791 |
+
self.components["log_box"]
|
|
|
|
|
|
|
| 792 |
]
|
| 793 |
)
|
| 794 |
|
|
|
|
| 804 |
self.components["current_task_box"],
|
| 805 |
self.components["start_btn"],
|
| 806 |
self.components["stop_btn"],
|
| 807 |
+
third_btn
|
|
|
|
|
|
|
| 808 |
]
|
| 809 |
)
|
| 810 |
|
|
|
|
| 816 |
self.components["current_task_box"],
|
| 817 |
self.components["start_btn"],
|
| 818 |
self.components["stop_btn"],
|
| 819 |
+
third_btn
|
|
|
|
|
|
|
| 820 |
]
|
| 821 |
)
|
| 822 |
|
|
|
|
| 1170 |
if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
|
| 1171 |
current_task = self.app.log_parser.get_current_task_display()
|
| 1172 |
|
| 1173 |
+
return message, logs, current_task
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1174 |
|
| 1175 |
def get_button_updates(self):
|
| 1176 |
"""Get button updates (with variant property)"""
|