Julian Bilcke committed
Commit c8589f9 · Parent(s): adc5756

various fixes regarding session recovery

Files changed:
- finetrainers/dataset.py +1 -1
- vms/services/trainer.py +134 -132
- vms/tabs/train_tab.py +55 -9
- vms/ui/video_trainer_ui.py +30 -5
finetrainers/dataset.py
CHANGED
@@ -32,6 +32,7 @@ from .constants import (  # noqa
     PRECOMPUTED_LATENTS_DIR_NAME,
 )
 
+logger = get_logger(__name__)
 
 # Decord is causing us some issues!
 # Let's try to increase file descriptor limits to avoid this error:
@@ -49,7 +50,6 @@ try:
 except Exception as e:
     logger.warning(f"Could not check or update file descriptor limits: {e}")
 
-logger = get_logger(__name__)
 
 
 # TODO(aryan): This needs a refactor with separation of concerns.
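
Note: this is a use-before-definition fix. The try/except above this point already calls logger.warning, but the logger was only created after that block, so a failure while raising the file descriptor limit would surface as a NameError instead of a warning. A minimal standalone sketch of the corrected ordering (get_logger here stands in for the project's logger factory, and the resource calls mirror what the comments describe):

    import logging
    import resource

    def get_logger(name):
        # Stand-in for finetrainers' logger factory.
        return logging.getLogger(name)

    # Fixed ordering: create the logger before any code that may log.
    logger = get_logger(__name__)

    try:
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
    except Exception as e:
        # With the logger defined above, this warning now fires instead of
        # raising NameError as it did when `logger` was assigned further down.
        logger.warning(f"Could not check or update file descriptor limits: {e}")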
vms/services/trainer.py
CHANGED

@@ -637,149 +637,151 @@ class TrainingService:
         return False
 
     def recover_interrupted_training(self) -> Dict[str, Any]:
+        """Attempt to recover interrupted training
+
+        Returns:
+            Dict with recovery status and UI updates
+        """
+        status = self.get_status()
+        ui_updates = {}
+
+        # Check for any checkpoints, even if status doesn't indicate training
+        checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+        has_checkpoints = len(checkpoints) > 0
+
+        # If status indicates training but process isn't running, or if we have checkpoints
+        # and no active training process, try to recover
+        if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
+           (has_checkpoints and not self.is_training_running()):
 
+            logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
 
+            # Get the latest checkpoint
+            last_session = self.load_session()
 
+            if not last_session:
+                logger.warning("No session data found for recovery, but will check for checkpoints")
+                # Try to create a default session based on UI state if we have checkpoints
                 if has_checkpoints:
+                    ui_state = self.load_ui_state()
+                    # Create a default session using UI state values
+                    last_session = {
+                        "params": {
+                            "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
+                            "lora_rank": ui_state.get("lora_rank", "128"),
+                            "lora_alpha": ui_state.get("lora_alpha", "128"),
+                            "num_epochs": ui_state.get("num_epochs", 70),
+                            "batch_size": ui_state.get("batch_size", 1),
+                            "learning_rate": ui_state.get("learning_rate", 3e-5),
+                            "save_iterations": ui_state.get("save_iterations", 500),
+                            "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
+                            "repo_id": ""  # Default empty repo ID
+                        }
+                    }
+                    logger.info("Created default session from UI state for recovery")
                 else:
                     # Set buttons for no active training
                     ui_updates = {
                         "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
                         "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
                         "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
                     }
+                    return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
+
+            # Find the latest checkpoint if we have checkpoints
+            latest_checkpoint = None
+            checkpoint_step = 0
+
+            if has_checkpoints:
+                latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+                checkpoint_step = int(latest_checkpoint.name.split("-")[1])
+                logger.info(f"Found checkpoint at step {checkpoint_step}")
+            else:
+                logger.warning("No checkpoints found for recovery")
+                # Set buttons for no active training
+                ui_updates = {
+                    "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
+                    "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                    "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                }
+                return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
+
+            # Extract parameters from the saved session (not current UI state)
+            # This ensures we use the original training parameters
+            params = last_session.get('params', {})
+
+            # Map internal model type back to display name for UI
+            # This is the key fix for the "ltx_video" vs "LTX-Video (LoRA)" mismatch
+            model_type_internal = params.get('model_type')
+            model_type_display = model_type_internal
+
+            # Find the display name that maps to our internal model type
+            for display_name, internal_name in MODEL_TYPES.items():
+                if internal_name == model_type_internal:
+                    model_type_display = display_name
+                    logger.info(f"Mapped internal model type '{model_type_internal}' to display name '{model_type_display}'")
+                    break
+
+            # Add UI updates to restore the training parameters in the UI
+            # This shows the user what values are being used for the resumed training
+            ui_updates.update({
+                "model_type": model_type_display,  # Use the display name for the UI dropdown
+                "lora_rank": params.get('lora_rank', "128"),
+                "lora_alpha": params.get('lora_alpha', "128"),
+                "num_epochs": params.get('num_epochs', 70),
+                "batch_size": params.get('batch_size', 1),
+                "learning_rate": params.get('learning_rate', 3e-5),
+                "save_iterations": params.get('save_iterations', 500),
+                "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
+            })
+
+            # Check if we should auto-recover (immediate restart)
+            auto_recover = True  # Always auto-recover on startup
+
+            if auto_recover:
+                # Rest of the auto-recovery code remains unchanged
+                try:
+                    # Use the internal model_type for the actual training
+                    # But keep model_type_display for the UI
+                    result = self.start_training(
+                        model_type=model_type_internal,
+                        lora_rank=params.get('lora_rank', "128"),
+                        lora_alpha=params.get('lora_alpha', "128"),
+                        num_epochs=params.get('num_epochs', 70),
+                        batch_size=params.get('batch_size', 1),
+                        learning_rate=params.get('learning_rate', 3e-5),
+                        save_iterations=params.get('save_iterations', 500),
+                        repo_id=params.get('repo_id', ''),
+                        preset_name=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]),
+                        resume_from_checkpoint=str(latest_checkpoint)
+                    )
+
+                    # Set buttons for active training
+                    ui_updates.update({
+                        "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+
+                    return {
+                        "status": "recovered",
+                        "message": f"Training resumed from checkpoint {checkpoint_step}",
+                        "result": result,
+                        "ui_updates": ui_updates
+                    }
+                except Exception as e:
+                    logger.error(f"Failed to auto-resume training: {str(e)}")
+                    # Set buttons for manual recovery
+                    ui_updates.update({
+                        "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "delete_checkpoints_btn": {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+                    return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
             else:
                 # Set up UI for manual recovery
                 ui_updates.update({
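
Note: the recovery path above hinges on two small mechanics, choosing the most recently written checkpoint-* directory and parsing the step out of its name. A standalone sketch of that selection logic, under the same checkpoint-<step> naming convention (find_latest_checkpoint is an illustrative name, not a function in the repo):

    import os
    from pathlib import Path
    from typing import Optional, Tuple

    def find_latest_checkpoint(output_path: Path) -> Tuple[Optional[Path], int]:
        """Return (newest checkpoint dir, its step), or (None, 0) if none exist."""
        checkpoints = list(output_path.glob("checkpoint-*"))
        if not checkpoints:
            return None, 0
        # Same rule as recover_interrupted_training: newest by modification time,
        # which also works when step numbers don't sort lexicographically
        # ("checkpoint-1000" < "checkpoint-500" as strings).
        latest = max(checkpoints, key=os.path.getmtime)
        step = int(latest.name.split("-")[1])  # "checkpoint-500" -> 500
        return latest, step

Note that mtime, not the name, decides recency here, so a checkpoint directory rewritten later wins even if its step number is lower.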
vms/tabs/train_tab.py
CHANGED

@@ -8,7 +8,7 @@ from typing import Dict, Any, List, Optional, Tuple
 from pathlib import Path
 
 from .base_tab import BaseTab
-from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
+from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
 from ..utils import TrainingLogParser
 
 logger = logging.getLogger(__name__)
@@ -279,7 +279,7 @@ class TrainTab(BaseTab):
         )
 
     def handle_training_start(self, preset, model_type, *args):
-        """Handle training start with proper log parser reset"""
+        """Handle training start with proper log parser reset and checkpoint detection"""
        # Safely reset log parser if it exists
         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
             self.app.log_parser.reset()
@@ -288,12 +288,35 @@ class TrainTab(BaseTab):
             from ..utils import TrainingLogParser
             self.app.log_parser = TrainingLogParser()
 
+        # Check for latest checkpoint
+        checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+        resume_from = None
+
+        if checkpoints:
+            # Find the latest checkpoint
+            latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+            resume_from = str(latest_checkpoint)
+            logger.info(f"Found checkpoint at {resume_from}, will resume training")
+
+        # Convert model_type display name to internal name
+        model_internal_type = MODEL_TYPES.get(model_type)
+
+        if not model_internal_type:
+            logger.error(f"Invalid model type: {model_type}")
+            return f"Error: Invalid model type '{model_type}'", "Model type not recognized"
+
+        # Start training (it will automatically use the checkpoint if provided)
+        try:
+            return self.app.trainer.start_training(
+                model_internal_type,  # Use internal model type
+                *args,
+                preset_name=preset,
+                resume_from_checkpoint=resume_from
+            )
+        except Exception as e:
+            logger.exception("Error starting training")
+            return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
+
 
     def get_model_info(self, model_type: str) -> str:
         """Get information about the selected model type"""
@@ -455,6 +478,23 @@ class TrainTab(BaseTab):
         state = self.app.trainer.get_status()
         logs = self.app.trainer.get_logs()
 
+        # Check if training process died unexpectedly
+        training_died = False
+
+        if state["status"] == "training" and not self.app.trainer.is_training_running():
+            state["status"] = "error"
+            state["message"] = "Training process terminated unexpectedly."
+            training_died = True
+
+            # Look for error in logs
+            error_lines = []
+            for line in logs.splitlines():
+                if "Error:" in line or "Exception:" in line or "Traceback" in line:
+                    error_lines.append(line)
+
+            if error_lines:
+                state["message"] += f"\n\nPossible error: {error_lines[-1]}"
+
         # Ensure log parser is initialized
         if not hasattr(self.app, 'log_parser') or self.app.log_parser is None:
             from ..utils import TrainingLogParser
@@ -462,7 +502,7 @@ class TrainTab(BaseTab):
             logger.info("Initialized missing log parser")
 
         # Parse new log lines
-        if logs:
+        if logs and not training_died:
             last_state = None
             for line in logs.splitlines():
                 try:
@@ -480,6 +520,12 @@ class TrainTab(BaseTab):
         # Parse status for training state
         if "completed" in state["message"].lower():
             state["status"] = "completed"
+        elif "error" in state["message"].lower():
+            state["status"] = "error"
+        elif "failed" in state["message"].lower():
+            state["status"] = "error"
+        elif "stopped" in state["message"].lower():
+            state["status"] = "stopped"
 
         return (state["status"], state["message"], logs)
 
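
Note: the new liveness check boils down to "status says training, but the process is gone", followed by scraping the logs for the last error-looking line. A compact sketch of that heuristic (summarize_failure is an illustrative name, not part of the codebase):

    def summarize_failure(logs: str) -> str:
        """Mirror the dead-process message built above: last error-ish line wins."""
        error_lines = [
            line for line in logs.splitlines()
            if "Error:" in line or "Exception:" in line or "Traceback" in line
        ]
        message = "Training process terminated unexpectedly."
        if error_lines:
            message += f"\n\nPossible error: {error_lines[-1]}"
        return message

    print(summarize_failure("step 10 ok\nError: CUDA out of memory\nstep 11 lost"))

One thing worth checking: the new checkpoint scan in handle_training_start calls os.path.getmtime, so train_tab.py needs an os import above the hunk if it doesn't already have one.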
vms/ui/video_trainer_ui.py
CHANGED

@@ -7,7 +7,7 @@ from typing import Any, Optional, Dict, List, Union, Tuple
 
 from ..services import TrainingService, CaptioningService, SplittingService, ImportService
 from ..config import (
-    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, OUTPUT_PATH,
     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
     MODEL_TYPES, SMALL_TRAINING_BUCKETS
 )
@@ -160,7 +160,24 @@ class VideoTrainerUI:
 
         # If we recovered training parameters from the original session
         ui_state = {}
-        for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
+
+        # Handle model_type specifically - could be internal or display name
+        if "model_type" in recovery_ui:
+            model_type_value = recovery_ui["model_type"]
+
+            # If it's an internal name, convert to display name
+            if model_type_value not in MODEL_TYPES:
+                # Find the display name for this internal model type
+                for display_name, internal_name in MODEL_TYPES.items():
+                    if internal_name == model_type_value:
+                        model_type_value = display_name
+                        logger.info(f"Converted internal model type '{recovery_ui['model_type']}' to display name '{model_type_value}'")
+                        break
+
+            ui_state["model_type"] = model_type_value
+
+        # Copy other parameters
+        for param in ["lora_rank", "lora_alpha", "num_epochs",
                      "batch_size", "learning_rate", "save_iterations", "training_preset"]:
             if param in recovery_ui:
                 ui_state[param] = recovery_ui[param]
@@ -175,8 +192,16 @@ class VideoTrainerUI:
         # Load values (potentially with recovery updates applied)
         ui_state = self.load_ui_values()
 
-        training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
+        # Ensure model_type is a display name, not internal name
         model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+        if model_type_val not in MODEL_TYPES:
+            # Convert from internal to display name
+            for display_name, internal_name in MODEL_TYPES.items():
+                if internal_name == model_type_val:
+                    model_type_val = display_name
+                    break
+
+        training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
         lora_rank_val = ui_state.get("lora_rank", "128")
         lora_alpha_val = ui_state.get("lora_alpha", "128")
         num_epochs_val = int(ui_state.get("num_epochs", 70))
@@ -190,9 +215,9 @@ class VideoTrainerUI:
             training_dataset,
             start_btn,
             stop_btn,
-            delete_checkpoints_btn,
+            delete_checkpoints_btn,
             training_preset,
-            model_type_val,
+            model_type_val,
             lora_rank_val,
             lora_alpha_val,
             num_epochs_val,
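
Two closing observations. Small nit: OUTPUT_PATH now appears twice in the import list (on both changed lines of the first hunk); harmless in Python, but worth tidying. More substantively, the internal-to-display model-type conversion now exists in three files (trainer.py, train_tab.py, and this one); if it keeps spreading, a shared helper would cut the duplication. A sketch only, with a single-entry MODEL_TYPES assumed for illustration (to_display_name does not exist in the repo):

    from typing import Dict

    # Illustrative mapping; the real MODEL_TYPES lives in vms/config.py.
    MODEL_TYPES: Dict[str, str] = {"LTX-Video (LoRA)": "ltx_video"}

    def to_display_name(model_type: str, model_types: Dict[str, str]) -> str:
        """Map an internal model type (e.g. 'ltx_video') back to its UI display name.

        Display names are the dict keys, internal names the values; a value
        that is already a display name is returned unchanged.
        """
        if model_type in model_types:  # already a display name
            return model_type
        for display_name, internal_name in model_types.items():
            if internal_name == model_type:
                return display_name
        return model_type  # unknown: fall back to the raw value

    assert to_display_name("ltx_video", MODEL_TYPES) == "LTX-Video (LoRA)"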