Julian Bilcke
committed on
Commit
·
892fa67
1
Parent(s):
d78dede
working on fixes for session recovery
Browse files
- vms/services/trainer.py +179 -93
- vms/ui/video_trainer_ui.py +56 -11
vms/services/trainer.py
CHANGED
|
@@ -361,8 +361,14 @@ class TrainingService:
|
|
| 361 |
if model_type not in MODEL_TYPES.values():
|
| 362 |
raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
|
| 363 |
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
try:
|
| 368 |
# Get absolute paths
|
|
@@ -395,7 +401,7 @@ class TrainingService:
|
|
| 395 |
return error_msg, "No training data available"
|
| 396 |
|
| 397 |
|
| 398 |
-
|
| 399 |
preset = TRAINING_PRESETS[preset_name]
|
| 400 |
training_buckets = preset["training_buckets"]
|
| 401 |
|
|
@@ -524,13 +530,12 @@ class TrainingService:
|
|
| 524 |
return success_msg, self.get_logs()
|
| 525 |
|
| 526 |
except Exception as e:
|
| 527 |
-
error_msg = f"Error starting training: {str(e)}"
|
| 528 |
self.append_log(error_msg)
|
| 529 |
logger.exception("Training startup failed")
|
| 530 |
-
traceback.print_exc()
|
| 531 |
-
return "Error starting training", error_msg
|
| 532 |
-
|
| 533 |
-
|
| 534 |
def stop_training(self) -> Tuple[str, str]:
|
| 535 |
"""Stop training process"""
|
| 536 |
if not self.pid_file.exists():
|
|
@@ -631,123 +636,204 @@ class TrainingService:
|
|
| 631 |
status = self.get_status()
|
| 632 |
ui_updates = {}
|
| 633 |
|
| 634 |
-
#
|
| 635 |
-
|
| 636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
|
| 638 |
# Get the latest checkpoint
|
| 639 |
last_session = self.load_session()
|
|
|
|
| 640 |
if not last_session:
|
| 641 |
-
logger.warning("No session data found for recovery")
|
| 642 |
-
#
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
logger.warning("No checkpoints found for recovery")
|
| 654 |
# Set buttons for no active training
|
| 655 |
ui_updates = {
|
| 656 |
-
"start_btn": {"interactive": True, "variant": "primary"},
|
| 657 |
-
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 658 |
-
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 659 |
}
|
| 660 |
return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
|
| 661 |
-
|
| 662 |
-
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 663 |
-
checkpoint_step = int(latest_checkpoint.name.split("-")[1])
|
| 664 |
-
|
| 665 |
-
logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")
|
| 666 |
|
| 667 |
# Extract parameters from the saved session (not current UI state)
|
| 668 |
# This ensures we use the original training parameters
|
| 669 |
params = last_session.get('params', {})
|
| 670 |
-
initial_ui_state = last_session.get('initial_ui_state', {})
|
| 671 |
|
| 672 |
# Add UI updates to restore the training parameters in the UI
|
| 673 |
# This shows the user what values are being used for the resumed training
|
| 674 |
ui_updates.update({
|
| 675 |
-
"model_type":
|
| 676 |
-
"lora_rank":
|
| 677 |
-
"lora_alpha":
|
| 678 |
-
"num_epochs":
|
| 679 |
-
"batch_size":
|
| 680 |
-
"learning_rate":
|
| 681 |
-
"save_iterations":
|
| 682 |
-
"training_preset":
|
| 683 |
})
|
| 684 |
|
| 685 |
-
#
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
ui_updates.update({
|
| 729 |
-
"start_btn": {"interactive": True, "variant": "primary"},
|
| 730 |
-
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 731 |
-
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 732 |
})
|
| 733 |
-
return {"status": "
|
|
|
|
| 734 |
elif self.is_training_running():
|
| 735 |
# Process is still running, set buttons accordingly
|
| 736 |
ui_updates = {
|
| 737 |
-
"start_btn": {"interactive": False, "variant": "secondary"},
|
| 738 |
-
"stop_btn": {"interactive": True, "variant": "
|
| 739 |
-
"pause_resume_btn": {"interactive":
|
| 740 |
}
|
| 741 |
return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
|
| 742 |
else:
|
| 743 |
# No training process, set buttons to default state
|
|
|
|
| 744 |
ui_updates = {
|
| 745 |
-
"start_btn": {"interactive": True, "variant": "primary"},
|
| 746 |
-
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 747 |
-
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 748 |
}
|
| 749 |
return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 751 |
def clear_training_data(self) -> str:
|
| 752 |
"""Clear all training data"""
|
| 753 |
if self.is_training_running():
|
|
|
|
| 361 |
if model_type not in MODEL_TYPES.values():
|
| 362 |
raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
|
| 363 |
|
| 364 |
+
# Check if we're resuming or starting new
|
| 365 |
+
is_resuming = resume_from_checkpoint is not None
|
| 366 |
+
log_prefix = "Resuming" if is_resuming else "Initializing"
|
| 367 |
+
logger.info(f"{log_prefix} training with model_type={model_type}")
|
| 368 |
+
self.append_log(f"{log_prefix} training with model_type={model_type}")
|
| 369 |
+
|
| 370 |
+
if is_resuming:
|
| 371 |
+
self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
|
| 372 |
|
| 373 |
try:
|
| 374 |
# Get absolute paths
|
|
|
|
| 401 |
return error_msg, "No training data available"
|
| 402 |
|
| 403 |
|
| 404 |
+
# Get preset configuration
|
| 405 |
preset = TRAINING_PRESETS[preset_name]
|
| 406 |
training_buckets = preset["training_buckets"]
|
| 407 |
|
|
|
|
| 530 |
return success_msg, self.get_logs()
|
| 531 |
|
| 532 |
except Exception as e:
|
| 533 |
+
error_msg = f"Error {'resuming' if is_resuming else 'starting'} training: {str(e)}"
|
| 534 |
self.append_log(error_msg)
|
| 535 |
logger.exception("Training startup failed")
|
| 536 |
+
traceback.print_exc()
|
| 537 |
+
return f"Error {'resuming' if is_resuming else 'starting'} training", error_msg
|
| 538 |
+
|
|
|
|
| 539 |
def stop_training(self) -> Tuple[str, str]:
|
| 540 |
"""Stop training process"""
|
| 541 |
if not self.pid_file.exists():
|
|
|
|
| 636 |
status = self.get_status()
|
| 637 |
ui_updates = {}
|
| 638 |
|
| 639 |
+
# Check for any checkpoints, even if status doesn't indicate training
|
| 640 |
+
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
| 641 |
+
has_checkpoints = len(checkpoints) > 0
|
| 642 |
+
|
| 643 |
+
# If status indicates training but process isn't running, or if we have checkpoints
|
| 644 |
+
# and no active training process, try to recover
|
| 645 |
+
if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
|
| 646 |
+
(has_checkpoints and not self.is_training_running()):
|
| 647 |
+
|
| 648 |
+
logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
|
| 649 |
|
| 650 |
# Get the latest checkpoint
|
| 651 |
last_session = self.load_session()
|
| 652 |
+
|
| 653 |
if not last_session:
|
| 654 |
+
logger.warning("No session data found for recovery, but will check for checkpoints")
|
| 655 |
+
# Try to create a default session based on UI state if we have checkpoints
|
| 656 |
+
if has_checkpoints:
|
| 657 |
+
ui_state = self.load_ui_state()
|
| 658 |
+
# Create a default session using UI state values
|
| 659 |
+
last_session = {
|
| 660 |
+
"params": {
|
| 661 |
+
"model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
|
| 662 |
+
"lora_rank": ui_state.get("lora_rank", "128"),
|
| 663 |
+
"lora_alpha": ui_state.get("lora_alpha", "128"),
|
| 664 |
+
"num_epochs": ui_state.get("num_epochs", 70),
|
| 665 |
+
"batch_size": ui_state.get("batch_size", 1),
|
| 666 |
+
"learning_rate": ui_state.get("learning_rate", 3e-5),
|
| 667 |
+
"save_iterations": ui_state.get("save_iterations", 500),
|
| 668 |
+
"preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
| 669 |
+
"repo_id": "" # Default empty repo ID
|
| 670 |
+
}
|
| 671 |
+
}
|
| 672 |
+
logger.info("Created default session from UI state for recovery")
|
| 673 |
+
else:
|
| 674 |
+
# Set buttons for no active training
|
| 675 |
+
ui_updates = {
|
| 676 |
+
"start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
|
| 677 |
+
"stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
|
| 678 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 679 |
+
}
|
| 680 |
+
return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
|
| 681 |
+
|
| 682 |
+
# Find the latest checkpoint if we have checkpoints
|
| 683 |
+
latest_checkpoint = None
|
| 684 |
+
checkpoint_step = 0
|
| 685 |
+
|
| 686 |
+
if has_checkpoints:
|
| 687 |
+
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 688 |
+
checkpoint_step = int(latest_checkpoint.name.split("-")[1])
|
| 689 |
+
logger.info(f"Found checkpoint at step {checkpoint_step}")
|
| 690 |
+
else:
|
| 691 |
logger.warning("No checkpoints found for recovery")
|
| 692 |
# Set buttons for no active training
|
| 693 |
ui_updates = {
|
| 694 |
+
"start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
|
| 695 |
+
"stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
|
| 696 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 697 |
}
|
| 698 |
return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
|
| 700 |
# Extract parameters from the saved session (not current UI state)
|
| 701 |
# This ensures we use the original training parameters
|
| 702 |
params = last_session.get('params', {})
|
|
|
|
| 703 |
|
| 704 |
# Add UI updates to restore the training parameters in the UI
|
| 705 |
# This shows the user what values are being used for the resumed training
|
| 706 |
ui_updates.update({
|
| 707 |
+
"model_type": params.get('model_type', list(MODEL_TYPES.keys())[0]),
|
| 708 |
+
"lora_rank": params.get('lora_rank', "128"),
|
| 709 |
+
"lora_alpha": params.get('lora_alpha', "128"),
|
| 710 |
+
"num_epochs": params.get('num_epochs', 70),
|
| 711 |
+
"batch_size": params.get('batch_size', 1),
|
| 712 |
+
"learning_rate": params.get('learning_rate', 3e-5),
|
| 713 |
+
"save_iterations": params.get('save_iterations', 500),
|
| 714 |
+
"training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
|
| 715 |
})
|
| 716 |
|
| 717 |
+
# Check if we should auto-recover (immediate restart)
|
| 718 |
+
auto_recover = True # Always auto-recover on startup
|
| 719 |
+
|
| 720 |
+
if auto_recover:
|
| 721 |
+
# Attempt to resume training using the ORIGINAL parameters
|
| 722 |
+
try:
|
| 723 |
+
# Extract required parameters from the session
|
| 724 |
+
model_type = params.get('model_type')
|
| 725 |
+
lora_rank = params.get('lora_rank')
|
| 726 |
+
lora_alpha = params.get('lora_alpha')
|
| 727 |
+
num_epochs = params.get('num_epochs')
|
| 728 |
+
batch_size = params.get('batch_size')
|
| 729 |
+
learning_rate = params.get('learning_rate')
|
| 730 |
+
save_iterations = params.get('save_iterations')
|
| 731 |
+
repo_id = params.get('repo_id', '')
|
| 732 |
+
preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
|
| 733 |
+
|
| 734 |
+
# Log the recovery attempt
|
| 735 |
+
self.append_log(f"Auto-recovering training from checkpoint {checkpoint_step}")
|
| 736 |
+
gr.Info(f"Automatically resuming training from checkpoint {checkpoint_step}")
|
| 737 |
+
|
| 738 |
+
# Attempt to resume training
|
| 739 |
+
result = self.start_training(
|
| 740 |
+
model_type=model_type,
|
| 741 |
+
lora_rank=lora_rank,
|
| 742 |
+
lora_alpha=lora_alpha,
|
| 743 |
+
num_epochs=num_epochs,
|
| 744 |
+
batch_size=batch_size,
|
| 745 |
+
learning_rate=learning_rate,
|
| 746 |
+
save_iterations=save_iterations,
|
| 747 |
+
repo_id=repo_id,
|
| 748 |
+
preset_name=preset_name,
|
| 749 |
+
resume_from_checkpoint=str(latest_checkpoint)
|
| 750 |
+
)
|
| 751 |
+
|
| 752 |
+
# Set buttons for active training
|
| 753 |
+
ui_updates.update({
|
| 754 |
+
"start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
|
| 755 |
+
"stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
|
| 756 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 757 |
+
})
|
| 758 |
+
|
| 759 |
+
return {
|
| 760 |
+
"status": "recovered",
|
| 761 |
+
"message": f"Training resumed from checkpoint {checkpoint_step}",
|
| 762 |
+
"result": result,
|
| 763 |
+
"ui_updates": ui_updates
|
| 764 |
+
}
|
| 765 |
+
except Exception as e:
|
| 766 |
+
logger.error(f"Failed to auto-resume training: {str(e)}")
|
| 767 |
+
# Set buttons for manual recovery
|
| 768 |
+
ui_updates.update({
|
| 769 |
+
"start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
|
| 770 |
+
"stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
|
| 771 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 772 |
+
})
|
| 773 |
+
return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
|
| 774 |
+
else:
|
| 775 |
+
# Set up UI for manual recovery
|
| 776 |
ui_updates.update({
|
| 777 |
+
"start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
|
| 778 |
+
"stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
|
| 779 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 780 |
})
|
| 781 |
+
return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates}
|
| 782 |
+
|
| 783 |
elif self.is_training_running():
|
| 784 |
# Process is still running, set buttons accordingly
|
| 785 |
ui_updates = {
|
| 786 |
+
"start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"},
|
| 787 |
+
"stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
|
| 788 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 789 |
}
|
| 790 |
return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
|
| 791 |
else:
|
| 792 |
# No training process, set buttons to default state
|
| 793 |
+
button_text = "Continue Training" if has_checkpoints else "Start Training"
|
| 794 |
ui_updates = {
|
| 795 |
+
"start_btn": {"interactive": True, "variant": "primary", "value": button_text},
|
| 796 |
+
"stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
|
| 797 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
|
| 798 |
}
|
| 799 |
return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
|
| 800 |
+
|
| 801 |
+
def delete_all_checkpoints(self) -> str:
|
| 802 |
+
"""Delete all checkpoints in the output directory.
|
| 803 |
+
|
| 804 |
+
Returns:
|
| 805 |
+
Status message
|
| 806 |
+
"""
|
| 807 |
+
if self.is_training_running():
|
| 808 |
+
return "Cannot delete checkpoints while training is running. Stop training first."
|
| 809 |
|
| 810 |
+
try:
|
| 811 |
+
# Find all checkpoint directories
|
| 812 |
+
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
| 813 |
+
|
| 814 |
+
if not checkpoints:
|
| 815 |
+
return "No checkpoints found to delete."
|
| 816 |
+
|
| 817 |
+
# Delete each checkpoint directory
|
| 818 |
+
for checkpoint in checkpoints:
|
| 819 |
+
if checkpoint.is_dir():
|
| 820 |
+
shutil.rmtree(checkpoint)
|
| 821 |
+
|
| 822 |
+
# Also delete session.json which contains previous training info
|
| 823 |
+
if self.session_file.exists():
|
| 824 |
+
self.session_file.unlink()
|
| 825 |
+
|
| 826 |
+
# Reset status file to idle
|
| 827 |
+
self.save_status(state='idle', message='No training in progress')
|
| 828 |
+
|
| 829 |
+
self.append_log(f"Deleted {len(checkpoints)} checkpoint(s)")
|
| 830 |
+
return f"Successfully deleted {len(checkpoints)} checkpoint(s)"
|
| 831 |
+
|
| 832 |
+
except Exception as e:
|
| 833 |
+
error_msg = f"Error deleting checkpoints: {str(e)}"
|
| 834 |
+
self.append_log(error_msg)
|
| 835 |
+
return error_msg
|
| 836 |
+
|
| 837 |
def clear_training_data(self) -> str:
|
| 838 |
"""Clear all training data"""
|
| 839 |
if self.is_training_running():
|
vms/ui/video_trainer_ui.py
CHANGED
|
@@ -36,7 +36,7 @@ class VideoTrainerUI:
|
|
| 36 |
|
| 37 |
# Initialize log parser
|
| 38 |
self.log_parser = TrainingLogParser()
|
| 39 |
-
|
| 40 |
# Shared state for tabs
|
| 41 |
self.state = {
|
| 42 |
"recovery_result": recovery_result
|
|
@@ -45,6 +45,9 @@ class VideoTrainerUI:
|
|
| 45 |
# Initialize tabs dictionary (will be populated in create_ui)
|
| 46 |
self.tabs = {}
|
| 47 |
self.tabs_component = None
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def create_ui(self):
|
| 50 |
"""Create the main Gradio UI"""
|
|
@@ -104,7 +107,7 @@ class VideoTrainerUI:
|
|
| 104 |
self.tabs["train_tab"].components["log_box"],
|
| 105 |
self.tabs["train_tab"].components["start_btn"],
|
| 106 |
self.tabs["train_tab"].components["stop_btn"],
|
| 107 |
-
self.tabs["train_tab"].components["
|
| 108 |
]
|
| 109 |
)
|
| 110 |
|
|
@@ -135,14 +138,33 @@ class VideoTrainerUI:
|
|
| 135 |
video_list = self.tabs["split_tab"].list_unprocessed_videos()
|
| 136 |
training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
|
| 137 |
|
| 138 |
-
# Get button states
|
| 139 |
button_states = self.get_initial_button_states()
|
| 140 |
start_btn = button_states[0]
|
| 141 |
stop_btn = button_states[1]
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
#
|
| 145 |
ui_state = self.load_ui_values()
|
|
|
|
| 146 |
training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
|
| 147 |
model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
|
| 148 |
lora_rank_val = ui_state.get("lora_rank", "128")
|
|
@@ -158,7 +180,7 @@ class VideoTrainerUI:
|
|
| 158 |
training_dataset,
|
| 159 |
start_btn,
|
| 160 |
stop_btn,
|
| 161 |
-
pause_resume_btn
|
| 162 |
training_preset,
|
| 163 |
model_type_val,
|
| 164 |
lora_rank_val,
|
|
@@ -210,16 +232,39 @@ class VideoTrainerUI:
|
|
| 210 |
# Add this new method to get initial button states:
|
| 211 |
def get_initial_button_states(self):
|
| 212 |
"""Get the initial states for training buttons based on recovery status"""
|
| 213 |
-
recovery_result = self.trainer.recover_interrupted_training()
|
| 214 |
ui_updates = recovery_result.get("ui_updates", {})
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# Return button states in the correct order
|
| 217 |
return (
|
| 218 |
-
gr.Button(**
|
| 219 |
-
gr.Button(**
|
| 220 |
-
gr.Button(**
|
| 221 |
)
|
| 222 |
-
|
| 223 |
def update_titles(self) -> Tuple[Any]:
|
| 224 |
"""Update all dynamic titles with current counts
|
| 225 |
|
|
|
|
| 36 |
|
| 37 |
# Initialize log parser
|
| 38 |
self.log_parser = TrainingLogParser()
|
| 39 |
+
|
| 40 |
# Shared state for tabs
|
| 41 |
self.state = {
|
| 42 |
"recovery_result": recovery_result
|
|
|
|
| 45 |
# Initialize tabs dictionary (will be populated in create_ui)
|
| 46 |
self.tabs = {}
|
| 47 |
self.tabs_component = None
|
| 48 |
+
|
| 49 |
+
# Log recovery status
|
| 50 |
+
logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")
|
| 51 |
|
| 52 |
def create_ui(self):
|
| 53 |
"""Create the main Gradio UI"""
|
|
|
|
| 107 |
self.tabs["train_tab"].components["log_box"],
|
| 108 |
self.tabs["train_tab"].components["start_btn"],
|
| 109 |
self.tabs["train_tab"].components["stop_btn"],
|
| 110 |
+
self.tabs["train_tab"].components["delete_checkpoints_btn"] # Replace pause_resume_btn
|
| 111 |
]
|
| 112 |
)
|
| 113 |
|
|
|
|
| 138 |
video_list = self.tabs["split_tab"].list_unprocessed_videos()
|
| 139 |
training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
|
| 140 |
|
| 141 |
+
# Get button states based on recovery status
|
| 142 |
button_states = self.get_initial_button_states()
|
| 143 |
start_btn = button_states[0]
|
| 144 |
stop_btn = button_states[1]
|
| 145 |
+
delete_checkpoints_btn = button_states[2] # This replaces pause_resume_btn in the response tuple
|
| 146 |
+
|
| 147 |
+
# Get UI form values - possibly from the recovery
|
| 148 |
+
if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
|
| 149 |
+
recovery_ui = self.state["recovery_result"]["ui_updates"]
|
| 150 |
+
|
| 151 |
+
# If we recovered training parameters from the original session
|
| 152 |
+
ui_state = {}
|
| 153 |
+
for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
|
| 154 |
+
"batch_size", "learning_rate", "save_iterations", "training_preset"]:
|
| 155 |
+
if param in recovery_ui:
|
| 156 |
+
ui_state[param] = recovery_ui[param]
|
| 157 |
+
|
| 158 |
+
# Merge with existing UI state if needed
|
| 159 |
+
if ui_state:
|
| 160 |
+
current_state = self.load_ui_values()
|
| 161 |
+
current_state.update(ui_state)
|
| 162 |
+
self.trainer.save_ui_state(current_state)
|
| 163 |
+
logger.info(f"Updated UI state from recovery: {ui_state}")
|
| 164 |
|
| 165 |
+
# Load values (potentially with recovery updates applied)
|
| 166 |
ui_state = self.load_ui_values()
|
| 167 |
+
|
| 168 |
training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
|
| 169 |
model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
|
| 170 |
lora_rank_val = ui_state.get("lora_rank", "128")
|
|
|
|
| 180 |
training_dataset,
|
| 181 |
start_btn,
|
| 182 |
stop_btn,
|
| 183 |
+
delete_checkpoints_btn, # Replaces pause_resume_btn
|
| 184 |
training_preset,
|
| 185 |
model_type_val,
|
| 186 |
lora_rank_val,
|
|
|
|
| 232 |
# Add this new method to get initial button states:
|
| 233 |
def get_initial_button_states(self):
|
| 234 |
"""Get the initial states for training buttons based on recovery status"""
|
| 235 |
+
recovery_result = self.state.get("recovery_result") or self.trainer.recover_interrupted_training()
|
| 236 |
ui_updates = recovery_result.get("ui_updates", {})
|
| 237 |
|
| 238 |
+
# Check for checkpoints to determine start button text
|
| 239 |
+
has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
|
| 240 |
+
|
| 241 |
+
# Default button states if recovery didn't provide any
|
| 242 |
+
if not ui_updates or not ui_updates.get("start_btn"):
|
| 243 |
+
is_training = self.trainer.is_training_running()
|
| 244 |
+
|
| 245 |
+
if is_training:
|
| 246 |
+
# Active training detected
|
| 247 |
+
start_btn_props = {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"}
|
| 248 |
+
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
| 249 |
+
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 250 |
+
else:
|
| 251 |
+
# No active training
|
| 252 |
+
start_btn_props = {"interactive": True, "variant": "primary", "value": "Continue Training" if has_checkpoints else "Start Training"}
|
| 253 |
+
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
| 254 |
+
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 255 |
+
else:
|
| 256 |
+
# Use button states from recovery
|
| 257 |
+
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start Training"})
|
| 258 |
+
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
| 259 |
+
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
| 260 |
+
|
| 261 |
# Return button states in the correct order
|
| 262 |
return (
|
| 263 |
+
gr.Button(**start_btn_props),
|
| 264 |
+
gr.Button(**stop_btn_props),
|
| 265 |
+
gr.Button(**delete_btn_props)
|
| 266 |
)
|
| 267 |
+
|
| 268 |
def update_titles(self) -> Tuple[Any]:
|
| 269 |
"""Update all dynamic titles with current counts
|
| 270 |
|