Julian Bilcke committed
Commit 24afbfe · Parent(s): d4b556f

fixes

Files changed:
- vms/services/importer.py +1 -2
- vms/tabs/train_tab.py +94 -58
vms/services/importer.py
CHANGED

@@ -10,8 +10,7 @@ from pytubefix import YouTube
 import logging
 
 from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
-from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
-from ..webdataset import webdataset_handler
+from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption, webdataset_handler
 
 logger = logging.getLogger(__name__)
 
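The change above collapses two imports into one: `webdataset_handler` is now imported from `..utils` alongside the other helpers instead of from a separate `..webdataset` package. For that single import line to work, `vms/utils` must expose `webdataset_handler` as an attribute. A minimal sketch of the re-export this implies (the internal layout of `vms/utils` is an assumption; it is not shown in this commit):

    # vms/utils/__init__.py -- hypothetical sketch, not part of this commit
    # Re-export the handler so importer.py can write:
    #   from ..utils import ..., webdataset_handler
    from ..webdataset import webdataset_handler  # assumes the module still lives in vms/webdataset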
vms/tabs/train_tab.py
CHANGED

@@ -4,12 +4,12 @@ Train tab for Video Model Studio UI
 
 import gradio as gr
 import logging
+import os
 from typing import Dict, Any, List, Optional, Tuple
 from pathlib import Path
 
 from .base_tab import BaseTab
-from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS
-from ..utils import TrainingLogParser
+from ..config import TRAINING_PRESETS, OUTPUT_PATH, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE, SMALL_TRAINING_BUCKETS, TRAINING_TYPES
 
 logger = logging.getLogger(__name__)
 
@@ -156,7 +156,7 @@ class TrainTab(BaseTab):
         # Model type change event
         def update_model_info(model, training_type):
             params = self.get_default_params(MODEL_TYPES[model], TRAINING_TYPES[training_type])
-            info = self.get_model_info(…
+            info = self.get_model_info(model, training_type)
             show_lora_params = training_type == list(TRAINING_TYPES.keys())[0] # Show if LoRA Finetune
 
             return {
@@ -313,6 +313,21 @@ class TrainTab(BaseTab):
                 self.components["pause_resume_btn"]
             ]
         )
+
+        # Add an event handler for delete_checkpoints_btn
+        self.components["delete_checkpoints_btn"].click(
+            fn=lambda: self.app.trainer.delete_all_checkpoints(),
+            outputs=[self.components["status_box"]]
+        ).then(
+            fn=self.get_latest_status_message_logs_and_button_labels,
+            outputs=[
+                self.components["status_box"],
+                self.components["log_box"],
+                self.components["start_btn"],
+                self.components["stop_btn"],
+                self.components["delete_checkpoints_btn"]
+            ]
+        )
 
     def handle_training_start(self, preset, model_type, training_type, *args):
         """Handle training start with proper log parser reset and checkpoint detection"""
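The block added above wires the new delete button through Gradio's event chaining: `.click()` runs the deletion, and `.then()` queues a second handler that refreshes the status box, logs, and button labels once the first handler returns. A minimal, self-contained sketch of the same `.click(...).then(...)` pattern (handler names and components here are illustrative stand-ins, not the VMS ones):

    import gradio as gr

    def delete_all_checkpoints():
        # Stand-in for self.app.trainer.delete_all_checkpoints()
        return "All checkpoints deleted."

    def refresh_status():
        # Stand-in for get_latest_status_message_logs_and_button_labels
        return "Idle."

    with gr.Blocks() as demo:
        status_box = gr.Textbox(label="Status")
        delete_btn = gr.Button("Delete checkpoints")
        # The first event writes the deletion result; the chained event
        # runs only after it completes and refreshes the same textbox.
        delete_btn.click(fn=delete_all_checkpoints, outputs=[status_box]).then(
            fn=refresh_status, outputs=[status_box]
        )

    demo.launch()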
@@ -360,86 +375,103 @@ class TrainTab(BaseTab):
         except Exception as e:
             logger.exception("Error starting training")
             return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
-
 
     def get_model_info(self, model_type: str, training_type: str) -> str:
         """Get information about the selected model type and training method"""
-
-
-        if model_type == "hunyuan_video":
+        if model_type == "HunyuanVideo (LoRA)":
             base_info = """### HunyuanVideo
 - Required VRAM: ~48GB minimum
 - Recommended batch size: 1-2
 - Typical training time: 2-4 hours
 - Default resolution: 49x512x768"""
 
-            if training_type == "lora":
+            if training_type == "LoRA Finetune":
                 return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
             else:
-                return base_info + "\n- Required VRAM: ~…
+                return base_info + "\n- Required VRAM: ~48GB minimum\n- **Full finetune not recommended due to VRAM requirements**"
 
-        elif model_type == "ltx_video":
-            base_info = """### LTX-Video
-- Recommended batch size: 1-…
+        elif model_type == "LTX-Video (LoRA)":
+            base_info = """### LTX-Video
+- Recommended batch size: 1-4
 - Typical training time: 1-3 hours
 - Default resolution: 49x512x768"""
 
-            if training_type == "lora":
-                return base_info + "\n- Required VRAM: ~…
-            else:
-                return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
+            if training_type == "LoRA Finetune":
+                return base_info + "\n- Required VRAM: ~18GB minimum\n- Default LoRA rank: 128 (~400 MB)"
             else:
                 return base_info + "\n- Required VRAM: ~21GB minimum\n- Full model size: ~8GB"
 
-        elif model_type == "wan":
+        elif model_type == "Wan-2.1-T2V (LoRA)":
             base_info = """### Wan-2.1-T2V
 - Recommended batch size: 1-2
 - Typical training time: 1-3 hours
 - Default resolution: 49x512x768"""
 
-            if training_type == "lora":
+            if training_type == "LoRA Finetune":
                 return base_info + "\n- Required VRAM: ~16GB minimum\n- Default LoRA rank: 32 (~120 MB)"
-            else:
-                return base_info + "\n- **Full finetune not supported in this UI**" + "\n- Default LoRA rank: 128 (~600 MB)"
             else:
                 return base_info + "\n- **Full finetune not recommended due to VRAM requirements**"
-
-
-
-- Recommended batch size: 1-4
-- Typical training time: 1-3 hours
-- Default resolution: 49x512x768"""
-
-            if training_type == "lora":
-                return base_…
+
+        # Default fallback
+        return f"### {model_type}\nPlease check documentation for VRAM requirements and recommended settings."
 
-    def get_default_params(self, model_type: str) -> Dict[str, Any]:
+    def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
         """Get default training parameters for model type"""
+        # Find preset that matches model type and training type
+        matching_presets = [
+            preset for preset_name, preset in TRAINING_PRESETS.items()
+            if preset["model_type"] == model_type and preset["training_type"] == training_type
+        ]
+
+        if matching_presets:
+            # Use the first matching preset
+            preset = matching_presets[0]
+            return {
+                "num_epochs": preset.get("num_epochs", 70),
+                "batch_size": preset.get("batch_size", 1),
+                "learning_rate": preset.get("learning_rate", 3e-5),
+                "save_iterations": preset.get("save_iterations", 500),
+                "lora_rank": preset.get("lora_rank", "128"),
+                "lora_alpha": preset.get("lora_alpha", "128")
+            }
+
+        # Default fallbacks
         if model_type == "hunyuan_video":
             return {
                 "num_epochs": 70,
                 "batch_size": 1,
                 "learning_rate": 2e-5,
                 "save_iterations": 500,
-                "…
-                "…
-
-
-
-                "…
+                "lora_rank": "128",
+                "lora_alpha": "128"
+            }
+        elif model_type == "ltx_video":
+            return {
+                "num_epochs": 70,
+                "batch_size": 1,
+                "learning_rate": 3e-5,
+                "save_iterations": 500,
+                "lora_rank": "128",
+                "lora_alpha": "128"
             }
-
+        elif model_type == "wan":
+            return {
+                "num_epochs": 70,
+                "batch_size": 1,
+                "learning_rate": 5e-5,
+                "save_iterations": 500,
+                "lora_rank": "32",
+                "lora_alpha": "32"
+            }
+        else:
+            # Generic defaults
             return {
                 "num_epochs": 70,
                 "batch_size": 1,
                 "learning_rate": 3e-5,
                 "save_iterations": 500,
-                "…
-                "…
-                "caption_dropout_p": 0.05,
-                "gradient_accumulation_steps": 4,
-                "rank": 128,
-                "lora_alpha": 128
+                "lora_rank": "128",
+                "lora_alpha": "128"
             }
 
     def update_training_params(self, preset_name: str) -> Tuple:
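The rewritten `get_default_params` first scans `TRAINING_PRESETS` (imported from `..config`) for an entry whose internal `model_type` and `training_type` match, and only falls back to the hard-coded per-model defaults when no preset matches. A toy run of that lookup logic (the preset table below is invented for illustration, not the real `TRAINING_PRESETS`):

    # Toy preset table -- invented for illustration only
    TRAINING_PRESETS = {
        "LTX-Video (normal)": {
            "model_type": "ltx_video",
            "training_type": "lora",
            "num_epochs": 70,
            "lora_rank": "128",
        },
    }

    # Same filtering expression as in the commit
    matching_presets = [
        preset for preset_name, preset in TRAINING_PRESETS.items()
        if preset["model_type"] == "ltx_video" and preset["training_type"] == "lora"
    ]

    if matching_presets:
        preset = matching_presets[0]
        # Missing keys fall back to the same defaults the commit uses
        print(preset.get("learning_rate", 3e-5))  # -> 3e-05 (not in the toy preset)
        print(preset.get("lora_rank", "128"))     # -> 128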
@@ -454,6 +486,12 @@ class TrainTab(BaseTab):
             key for key, value in MODEL_TYPES.items()
             if value == preset["model_type"]
         )
+
+        # Find the display name that maps to our training type
+        training_display_name = next(
+            key for key, value in TRAINING_TYPES.items()
+            if value == preset["training_type"]
+        )
 
         # Get preset description for display
         description = preset.get("description", "")
@@ -467,24 +505,29 @@ class TrainTab(BaseTab):
 
         info_text = f"{description}{bucket_info}"
 
-        # …
+        # Check if LoRA params should be visible
+        show_lora_params = preset["training_type"] == "lora"
+
         # Use preset defaults but preserve user-modified values if they exist
-        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset
-        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset
-        num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset
-        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset
-        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset
-        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset
+        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset.get("lora_rank", "128")
+        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset.get("lora_alpha", "128")
+        num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset.get("num_epochs", 70)
+        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset.get("batch_size", 1)
+        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset.get("learning_rate", 3e-5)
+        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset.get("save_iterations", 500)
 
+        # Return values in the same order as the output components
         return (
             model_display_name,
+            training_display_name,
             lora_rank_val,
             lora_alpha_val,
             num_epochs_val,
             batch_size_val,
             learning_rate_val,
             save_iterations_val,
-            info_text
+            info_text,
+            gr.Row(visible=show_lora_params)
         )
 
     def update_training_ui(self, training_state: Dict[str, Any]):
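`update_training_params` now returns two extra values: the training-type display name and `gr.Row(visible=show_lora_params)`. The latter relies on Gradio 4-style updates, where returning a freshly constructed component from a handler updates the properties of the mapped output component (on Gradio 3.x the equivalent would be `gr.Row.update(visible=...)`). A minimal sketch of that visibility toggle (component names here are illustrative):

    import gradio as gr

    def toggle_lora_row(training_type):
        # Returning a constructor acts as a property update for the output row
        return gr.Row(visible=(training_type == "LoRA Finetune"))

    with gr.Blocks() as demo:
        training_type = gr.Dropdown(
            ["LoRA Finetune", "Full Finetune"],
            value="LoRA Finetune",
            label="Training type",
        )
        with gr.Row() as lora_row:
            gr.Dropdown(["32", "64", "128"], value="128", label="LoRA rank")
        training_type.change(fn=toggle_lora_row, inputs=[training_type], outputs=[lora_row])

    demo.launch()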
@@ -498,13 +541,6 @@ class TrainTab(BaseTab):
             f"Status: {training_state['status']}",
             f"Progress: {training_state['progress']}",
             f"Step: {training_state['current_step']}/{training_state['total_steps']}",
-
-            # Epoch information
-            # there is an issue with how epoch is reported because we display:
-            # Progress: 96.9%, Step: 872/900, Epoch: 12/50
-            # we should probably just show the steps
-            #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
-
             f"Time elapsed: {training_state['elapsed']}",
             f"Estimated remaining: {training_state['remaining']}",
             "",