Julian Bilcke
committed on
Commit
·
5a793ee
1
Parent(s):
aeb51a1
small fix (not tested yet)
Browse files
vms/ui/project/tabs/manage_tab.py
CHANGED
|
@@ -6,7 +6,7 @@ import gradio as gr
|
|
| 6 |
import logging
|
| 7 |
import shutil
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Dict, Any, List, Optional
|
| 10 |
from gradio_modal import Modal
|
| 11 |
|
| 12 |
from vms.utils import BaseTab, validate_model_repo
|
|
@@ -51,6 +51,17 @@ class ManageTab(BaseTab):
|
|
| 51 |
"""Update the download button text"""
|
| 52 |
return gr.update(value=self.get_download_button_text())
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def download_and_update_button(self):
|
| 55 |
"""Handle download and return updated button with current text"""
|
| 56 |
# Get the safetensors path for download
|
|
|
|
| 6 |
import logging
|
| 7 |
import shutil
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Dict, Any, List, Optional, Tuple
|
| 10 |
from gradio_modal import Modal
|
| 11 |
|
| 12 |
from vms.utils import BaseTab, validate_model_repo
|
|
|
|
| 51 |
"""Update the download button text"""
|
| 52 |
return gr.update(value=self.get_download_button_text())
|
| 53 |
|
| 54 |
+
def update_checkpoint_button_text(self) -> gr.update:
|
| 55 |
+
"""Update the checkpoint button text"""
|
| 56 |
+
return gr.update(value=self.get_checkpoint_button_text())
|
| 57 |
+
|
| 58 |
+
def update_both_download_buttons(self) -> Tuple[gr.update, gr.update]:
|
| 59 |
+
"""Update both download button texts"""
|
| 60 |
+
return (
|
| 61 |
+
gr.update(value=self.get_download_button_text()),
|
| 62 |
+
gr.update(value=self.get_checkpoint_button_text())
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
def download_and_update_button(self):
|
| 66 |
"""Handle download and return updated button with current text"""
|
| 67 |
# Get the safetensors path for download
|
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -341,9 +341,12 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 341 |
## ⚗️ Train your model on your dataset
|
| 342 |
- **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
|
| 343 |
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
|
| 344 |
-
- **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
|
| 345 |
""")
|
| 346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
with gr.Row():
|
| 348 |
# Check for existing checkpoints to determine button text
|
| 349 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
|
@@ -485,11 +488,18 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 485 |
self.app.training.append_log("Cleared previous checkpoints for new training session")
|
| 486 |
|
| 487 |
# Start training normally
|
| 488 |
-
|
| 489 |
model_type, model_version, training_type,
|
| 490 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 491 |
save_iterations, repo_id, progress
|
| 492 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
def handle_resume_training(
|
| 495 |
self, model_type, model_version, training_type,
|
|
@@ -501,17 +511,27 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 501 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 502 |
|
| 503 |
if not checkpoints:
|
| 504 |
-
|
|
|
|
|
|
|
|
|
|
| 505 |
|
| 506 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 507 |
|
| 508 |
# Start training with the checkpoint
|
| 509 |
-
|
| 510 |
model_type, model_version, training_type,
|
| 511 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 512 |
save_iterations, repo_id, progress,
|
| 513 |
resume_from_checkpoint="latest"
|
| 514 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
def handle_start_from_lora_training(
|
| 517 |
self, model_type, model_version, training_type,
|
|
@@ -522,22 +542,26 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 522 |
# Find the latest LoRA weights
|
| 523 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
if not lora_weights_path.exists():
|
| 526 |
-
return "No LoRA weights found", "Please train a model first or start a new training session"
|
| 527 |
|
| 528 |
# Find the latest LoRA checkpoint directory
|
| 529 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 530 |
key=lambda x: int(x.name), reverse=True)
|
| 531 |
|
| 532 |
if not lora_dirs:
|
| 533 |
-
return "No LoRA weight directories found", "Please train a model first or start a new training session"
|
| 534 |
|
| 535 |
latest_lora_dir = lora_dirs[0]
|
| 536 |
|
| 537 |
# Verify the LoRA weights file exists
|
| 538 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 539 |
if not lora_weights_file.exists():
|
| 540 |
-
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory"
|
| 541 |
|
| 542 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 543 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
@@ -552,11 +576,17 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 552 |
self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
|
| 553 |
|
| 554 |
# Start training with the LoRA weights
|
| 555 |
-
|
| 556 |
model_type, model_version, training_type,
|
| 557 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 558 |
save_iterations, repo_id, progress,
|
| 559 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
|
| 561 |
def connect_events(self) -> None:
|
| 562 |
"""Connect event handlers to UI components"""
|
|
@@ -739,7 +769,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 739 |
],
|
| 740 |
outputs=[
|
| 741 |
self.components["status_box"],
|
| 742 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 743 |
]
|
| 744 |
)
|
| 745 |
|
|
@@ -759,7 +791,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 759 |
],
|
| 760 |
outputs=[
|
| 761 |
self.components["status_box"],
|
| 762 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 763 |
]
|
| 764 |
)
|
| 765 |
|
|
@@ -779,7 +813,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 779 |
],
|
| 780 |
outputs=[
|
| 781 |
self.components["status_box"],
|
| 782 |
-
self.components["log_box"]
|
|
|
|
|
|
|
| 783 |
]
|
| 784 |
)
|
| 785 |
|
|
@@ -795,7 +831,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 795 |
self.components["current_task_box"],
|
| 796 |
self.components["start_btn"],
|
| 797 |
self.components["stop_btn"],
|
| 798 |
-
third_btn
|
|
|
|
|
|
|
| 799 |
]
|
| 800 |
)
|
| 801 |
|
|
@@ -807,7 +845,9 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 807 |
self.components["current_task_box"],
|
| 808 |
self.components["start_btn"],
|
| 809 |
self.components["stop_btn"],
|
| 810 |
-
third_btn
|
|
|
|
|
|
|
| 811 |
]
|
| 812 |
)
|
| 813 |
|
|
@@ -1200,7 +1240,12 @@ Full finetune mode trains all parameters of the model, requiring more VRAM but p
|
|
| 1200 |
variant="stop"
|
| 1201 |
)
|
| 1202 |
|
| 1203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1204 |
|
| 1205 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
| 1206 |
"""Update UI components based on training state"""
|
|
|
|
| 341 |
## ⚗️ Train your model on your dataset
|
| 342 |
- **🚀 Start new training**: Begins training from scratch (clears previous checkpoints)
|
| 343 |
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint
|
|
|
|
| 344 |
""")
|
| 345 |
+
|
| 346 |
+
#Finetrainers doesn't support recovery of a training session using a LoRA,
|
| 347 |
+
#so this feature doesn't work; I've disabled the corresponding documentation line:
|
| 348 |
+
#- **🔄 Start over using latest LoRA weights**: Start fresh training but use existing LoRA weights as initialization
|
| 349 |
+
|
| 350 |
with gr.Row():
|
| 351 |
# Check for existing checkpoints to determine button text
|
| 352 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
|
|
|
| 488 |
self.app.training.append_log("Cleared previous checkpoints for new training session")
|
| 489 |
|
| 490 |
# Start training normally
|
| 491 |
+
status, logs = self.handle_training_start(
|
| 492 |
model_type, model_version, training_type,
|
| 493 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 494 |
save_iterations, repo_id, progress
|
| 495 |
)
|
| 496 |
+
|
| 497 |
+
# Update download button texts
|
| 498 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 499 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 500 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 501 |
+
|
| 502 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 503 |
|
| 504 |
def handle_resume_training(
|
| 505 |
self, model_type, model_version, training_type,
|
|
|
|
| 511 |
checkpoints = list(self.app.output_path.glob("finetrainers_step_*"))
|
| 512 |
|
| 513 |
if not checkpoints:
|
| 514 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 515 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 516 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 517 |
+
return "No checkpoints found to resume from", "Please start a new training session instead", download_btn_text, checkpoint_btn_text
|
| 518 |
|
| 519 |
self.app.training.append_log(f"Resuming training from latest checkpoint")
|
| 520 |
|
| 521 |
# Start training with the checkpoint
|
| 522 |
+
status, logs = self.handle_training_start(
|
| 523 |
model_type, model_version, training_type,
|
| 524 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 525 |
save_iterations, repo_id, progress,
|
| 526 |
resume_from_checkpoint="latest"
|
| 527 |
)
|
| 528 |
+
|
| 529 |
+
# Update download button texts
|
| 530 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 531 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 532 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 533 |
+
|
| 534 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 535 |
|
| 536 |
def handle_start_from_lora_training(
|
| 537 |
self, model_type, model_version, training_type,
|
|
|
|
| 542 |
# Find the latest LoRA weights
|
| 543 |
lora_weights_path = self.app.output_path / "lora_weights"
|
| 544 |
|
| 545 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 546 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 547 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 548 |
+
|
| 549 |
if not lora_weights_path.exists():
|
| 550 |
+
return "No LoRA weights found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
| 551 |
|
| 552 |
# Find the latest LoRA checkpoint directory
|
| 553 |
lora_dirs = sorted([d for d in lora_weights_path.iterdir() if d.is_dir()],
|
| 554 |
key=lambda x: int(x.name), reverse=True)
|
| 555 |
|
| 556 |
if not lora_dirs:
|
| 557 |
+
return "No LoRA weight directories found", "Please train a model first or start a new training session", download_btn_text, checkpoint_btn_text
|
| 558 |
|
| 559 |
latest_lora_dir = lora_dirs[0]
|
| 560 |
|
| 561 |
# Verify the LoRA weights file exists
|
| 562 |
lora_weights_file = latest_lora_dir / "pytorch_lora_weights.safetensors"
|
| 563 |
if not lora_weights_file.exists():
|
| 564 |
+
return f"LoRA weights file not found in {latest_lora_dir}", "Please check your LoRA weights directory", download_btn_text, checkpoint_btn_text
|
| 565 |
|
| 566 |
# Clear checkpoints to start fresh (but keep LoRA weights)
|
| 567 |
for checkpoint in self.app.output_path.glob("finetrainers_step_*"):
|
|
|
|
| 576 |
self.app.training.append_log(f"Starting training from LoRA weights: {latest_lora_dir}")
|
| 577 |
|
| 578 |
# Start training with the LoRA weights
|
| 579 |
+
status, logs = self.handle_training_start(
|
| 580 |
model_type, model_version, training_type,
|
| 581 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 582 |
save_iterations, repo_id, progress,
|
| 583 |
)
|
| 584 |
+
|
| 585 |
+
# Update download button texts
|
| 586 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 587 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 588 |
+
|
| 589 |
+
return status, logs, download_btn_text, checkpoint_btn_text
|
| 590 |
|
| 591 |
def connect_events(self) -> None:
|
| 592 |
"""Connect event handlers to UI components"""
|
|
|
|
| 769 |
],
|
| 770 |
outputs=[
|
| 771 |
self.components["status_box"],
|
| 772 |
+
self.components["log_box"],
|
| 773 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 774 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 775 |
]
|
| 776 |
)
|
| 777 |
|
|
|
|
| 791 |
],
|
| 792 |
outputs=[
|
| 793 |
self.components["status_box"],
|
| 794 |
+
self.components["log_box"],
|
| 795 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 796 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 797 |
]
|
| 798 |
)
|
| 799 |
|
|
|
|
| 813 |
],
|
| 814 |
outputs=[
|
| 815 |
self.components["status_box"],
|
| 816 |
+
self.components["log_box"],
|
| 817 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 818 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 819 |
]
|
| 820 |
)
|
| 821 |
|
|
|
|
| 831 |
self.components["current_task_box"],
|
| 832 |
self.components["start_btn"],
|
| 833 |
self.components["stop_btn"],
|
| 834 |
+
third_btn,
|
| 835 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 836 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 837 |
]
|
| 838 |
)
|
| 839 |
|
|
|
|
| 845 |
self.components["current_task_box"],
|
| 846 |
self.components["start_btn"],
|
| 847 |
self.components["stop_btn"],
|
| 848 |
+
third_btn,
|
| 849 |
+
self.app.tabs["manage_tab"].components["download_model_btn"],
|
| 850 |
+
self.app.tabs["manage_tab"].components["download_checkpoint_btn"]
|
| 851 |
]
|
| 852 |
)
|
| 853 |
|
|
|
|
| 1240 |
variant="stop"
|
| 1241 |
)
|
| 1242 |
|
| 1243 |
+
# Update download button texts
|
| 1244 |
+
manage_tab = self.app.tabs["manage_tab"]
|
| 1245 |
+
download_btn_text = gr.update(value=manage_tab.get_download_button_text())
|
| 1246 |
+
checkpoint_btn_text = gr.update(value=manage_tab.get_checkpoint_button_text())
|
| 1247 |
+
|
| 1248 |
+
return start_btn, resume_btn, stop_btn, delete_checkpoints_btn, download_btn_text, checkpoint_btn_text
|
| 1249 |
|
| 1250 |
def update_training_ui(self, training_state: Dict[str, Any]):
|
| 1251 |
"""Update UI components based on training state"""
|