import os
import uuid
import json
import shutil
import logging
import math
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Tuple
from pathlib import Path

import torch

logger = logging.getLogger(__name__)

def parse_bool_env(env_value: Optional[str]) -> bool:
    """Parse environment variable string to boolean

    Handles various true/false string representations:
    - True: "true", "True", "TRUE", "1", etc
    - False: "false", "False", "FALSE", "0", "", None
    """
    if not env_value:
        return False
    return str(env_value).lower() in ('true', '1', 't', 'y', 'yes')
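
# Illustrative sanity checks for parse_bool_env (a minimal sketch; cheap to run
# at import time and safe to delete). Anything outside the accepted truthy
# strings, including None and "", parses to False.
assert parse_bool_env("TRUE") is True
assert parse_bool_env("1") is True
assert parse_bool_env("0") is False
assert parse_bool_env(None) is False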

HF_API_TOKEN = os.getenv("HF_API_TOKEN")
ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE"))

USE_LARGE_DATASET = parse_bool_env(os.getenv("USE_LARGE_DATASET"))

# Base storage directory; override with the STORAGE_PATH environment variable
STORAGE_PATH = Path(os.environ.get('STORAGE_PATH', '.data'))

VIDEOS_TO_SPLIT_PATH = STORAGE_PATH / "videos_to_split"
STAGING_PATH = STORAGE_PATH / "staging"

# Captioning settings
PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
CAPTIONING_MODEL = "lmms-lab/LLaVA-Video-7B-Qwen2"
DEFAULT_PROMPT_PREFIX = "In the style of TOK, "

USE_MOCK_CAPTIONING_MODEL = parse_bool_env(os.environ.get('USE_MOCK_CAPTIONING_MODEL'))

DEFAULT_CAPTIONING_BOT_INSTRUCTIONS = "Please write a full video description. Be concise and methodically list camera (close-up shot, medium-shot..), genre (music video, horror movie scene, video game footage, go pro footage, japanese anime, noir film, science-fiction, action movie, documentary..), characters (physical appearance, look, skin, facial features, haircut, clothing), scene (action, positions, movements), location (indoor, outdoor, place, building, country..), time and lighting (natural, golden hour, night time, LED lights, kelvin temperature etc), weather and climate (dusty, rainy, fog, haze, snowing..), era/settings."

def generate_model_project_id() -> str:
    """Generate a new UUID for a model project"""
    return str(uuid.uuid4())


def get_global_config_path() -> Path:
    """Get the path to the global config file"""
    return STORAGE_PATH / "config.json"


def load_global_config() -> dict:
    """Load the global configuration file

    Returns:
        Dict containing global configuration
    """
    config_path = get_global_config_path()
    if not config_path.exists():
        default_config = {
            "latest_model_project_id": None
        }
        save_global_config(default_config)
        return default_config

    try:
        with open(config_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error loading global config: {e}")
        return {"latest_model_project_id": None}


def save_global_config(config: dict) -> bool:
    """Save the global configuration file

    Args:
        config: Dictionary containing configuration to save

    Returns:
        True if successful, False otherwise
    """
    config_path = get_global_config_path()
    try:
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)
        return True
    except Exception as e:
        logger.error(f"Error saving global config: {e}")
        return False


def update_latest_project_id(project_id: str) -> bool:
    """Update the latest project ID in global config

    Args:
        project_id: The project ID to save

    Returns:
        True if successful, False otherwise
    """
    config = load_global_config()
    config["latest_model_project_id"] = project_id
    return save_global_config(config)
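
# Usage sketch (assuming STORAGE_PATH is writable): record a new project in the
# global config and read it back.
#
#   pid = generate_model_project_id()
#   update_latest_project_id(pid)
#   assert load_global_config()["latest_model_project_id"] == pid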

def get_project_paths(project_id: str) -> Tuple[Path, Path, Path, Path]:
    """Get paths for a specific project

    Args:
        project_id: The model project UUID

    Returns:
        Tuple of (training_path, training_videos_path, output_path, log_file_path)
    """
    project_base = STORAGE_PATH / "models" / project_id
    training_path = project_base / "training"
    training_videos_path = training_path / "videos"
    output_path = project_base / "output"
    log_file_path = output_path / "last_session.log"

    # Create the directories if they don't exist yet
    training_path.mkdir(parents=True, exist_ok=True)
    training_videos_path.mkdir(parents=True, exist_ok=True)
    output_path.mkdir(parents=True, exist_ok=True)

    return training_path, training_videos_path, output_path, log_file_path
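
# Resulting layout under STORAGE_PATH / "models" for a given project id:
#
#   <project_id>/training/                 <- training_path
#   <project_id>/training/videos/          <- training_videos_path
#   <project_id>/output/                   <- output_path
#   <project_id>/output/last_session.log   <- log_file_path (not created here)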

def migrate_legacy_project() -> Optional[str]:
    """Migrate legacy project structure to new UUID-based structure

    Returns:
        New project UUID if migration was performed, None otherwise
    """
    legacy_training = STORAGE_PATH / "training"
    legacy_output = STORAGE_PATH / "output"

    # Only migrate if the legacy directories actually contain something
    has_training_data = legacy_training.exists() and any(legacy_training.iterdir())
    has_output_data = legacy_output.exists() and any(legacy_output.iterdir())

    if not (has_training_data or has_output_data):
        return None

    # Create a new project and its directory structure
    project_id = generate_model_project_id()
    training_path, training_videos_path, output_path, log_file_path = get_project_paths(project_id)

    if has_training_data:
        # Copy top-level training files
        for file in legacy_training.glob("*"):
            if file.is_file():
                shutil.copy2(file, training_path)

        # Copy the videos subdirectory if present
        legacy_videos = legacy_training / "videos"
        if legacy_videos.exists():
            for file in legacy_videos.glob("*"):
                if file.is_file():
                    shutil.copy2(file, training_videos_path)

    if has_output_data:
        for file in legacy_output.glob("*"):
            if file.is_file():
                shutil.copy2(file, output_path)
            elif file.is_dir():
                # Copy one level of subdirectories (e.g. checkpoint folders)
                target_dir = output_path / file.name
                target_dir.mkdir(exist_ok=True)
                for subfile in file.glob("*"):
                    if subfile.is_file():
                        shutil.copy2(subfile, target_dir)

    return project_id
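
# Typical startup flow (a sketch; not enforced here): migrate any legacy flat
# layout once, then fall back to the last-used project from the global config,
# and finally to a fresh project id.
#
#   project_id = (migrate_legacy_project()
#                 or load_global_config().get("latest_model_project_id")
#                 or generate_model_project_id())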

# Create the base storage folders
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
VIDEOS_TO_SPLIT_PATH.mkdir(parents=True, exist_ok=True)
STAGING_PATH.mkdir(parents=True, exist_ok=True)

# Base directory for all model projects
MODELS_PATH = STORAGE_PATH / "models"
MODELS_PATH.mkdir(parents=True, exist_ok=True)

VMS_ADMIN_PASSWORD = os.environ.get('VMS_ADMIN_PASSWORD', '')

# Image normalization settings
NORMALIZE_IMAGES_TO = os.environ.get('NORMALIZE_IMAGES_TO', 'png').lower()
if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
    raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))

# Mapping of UI labels to internal model family names
MODEL_TYPES = {
    "HunyuanVideo": "hunyuan_video",
    "LTX-Video": "ltx_video",
    "Wan": "wan"
}

# Mapping of UI labels to internal training types
TRAINING_TYPES = {
    "LoRA Finetune": "lora",
    "Full Finetune": "full-finetune",
    "Control LoRA": "control-lora",
    "Control Full Finetune": "control-full-finetune"
}

# Available model versions for each model family
MODEL_VERSIONS = {
    "wan": {
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {
            "name": "Wan 2.1 T2V 1.3B (text-only, smaller)",
            "type": "text-to-video",
            "description": "Faster, smaller model (1.3B parameters)"
        },
        "Wan-AI/Wan2.1-T2V-14B-Diffusers": {
            "name": "Wan 2.1 T2V 14B (text-only, larger)",
            "type": "text-to-video",
            "description": "Higher quality but slower (14B parameters)"
        },
        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": {
            "name": "Wan 2.1 I2V 480p (image+text)",
            "type": "image-to-video",
            "description": "Image conditioning at 480p resolution"
        },
        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": {
            "name": "Wan 2.1 I2V 720p (image+text)",
            "type": "image-to-video",
            "description": "Image conditioning at 720p resolution"
        },
        "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers": {
            "name": "Wan 2.1 FLF2V 720p (frame conditioning)",
            "type": "frame-to-video",
            "description": "Frame conditioning (first/last frame to video) at 720p resolution"
        }
    },
    "ltx_video": {
        "Lightricks/LTX-Video": {
            "name": "LTX Video (official)",
            "type": "text-to-video",
            "description": "Official LTX Video model"
        }
    },
    "hunyuan_video": {
        "hunyuanvideo-community/HunyuanVideo": {
            "name": "Hunyuan Video (official)",
            "type": "text-to-video",
            "description": "Official Hunyuan Video model"
        }
    }
}

DEFAULT_SEED = 42

DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES = True

DEFAULT_DATASET_TYPE = "video"
DEFAULT_TRAINING_TYPE = "lora"

DEFAULT_RESHAPE_MODE = "bicubic"

DEFAULT_MIXED_PRECISION = "bf16"

DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS = 200

DEFAULT_LORA_RANK = 128
DEFAULT_LORA_RANK_STR = str(DEFAULT_LORA_RANK)

DEFAULT_LORA_ALPHA = 128
DEFAULT_LORA_ALPHA_STR = str(DEFAULT_LORA_ALPHA)

DEFAULT_CAPTION_DROPOUT_P = 0.05

DEFAULT_BATCH_SIZE = 1

DEFAULT_LEARNING_RATE = 3e-5

# GPU settings
DEFAULT_NUM_GPUS = 1
DEFAULT_MAX_GPUS = min(8, torch.cuda.device_count() if torch.cuda.is_available() else 1)
DEFAULT_PRECOMPUTATION_ITEMS = 512

DEFAULT_NB_TRAINING_STEPS = 1000

# Warm up the learning rate over the first 20% of training steps
DEFAULT_NB_LR_WARMUP_STEPS = math.ceil(0.20 * DEFAULT_NB_TRAINING_STEPS)

DEFAULT_AUTO_RESUME = False

# Control training defaults
DEFAULT_CONTROL_TYPE = "canny"
DEFAULT_TRAIN_QK_NORM = False
DEFAULT_FRAME_CONDITIONING_TYPE = "full"
DEFAULT_FRAME_CONDITIONING_INDEX = 0
DEFAULT_FRAME_CONDITIONING_CONCATENATE_MASK = False

# Validation defaults
DEFAULT_VALIDATION_NB_STEPS = 50
DEFAULT_VALIDATION_HEIGHT = 512
DEFAULT_VALIDATION_WIDTH = 768
DEFAULT_VALIDATION_NB_FRAMES = 49
DEFAULT_VALIDATION_FRAMERATE = 8

# Resolution presets (width x height), 16:9 and 9:16 orientations

# SD (1024x576)
SD_16_9_W = 1024
SD_16_9_H = 576
SD_9_16_W = 576
SD_9_16_H = 1024

# HD (1280x720)
HD_16_9_W = 1280
HD_16_9_H = 720
HD_9_16_W = 720
HD_9_16_H = 1280

# Full HD (1920x1080)
FHD_16_9_W = 1920
FHD_16_9_H = 1080
FHD_9_16_W = 1080
FHD_9_16_H = 1920

# QHD (2160x1440)
QHD_16_9_W = 2160
QHD_16_9_H = 1440
QHD_9_16_W = 1440
QHD_9_16_H = 2160

# UHD / 4K (3840x2160)
UHD_16_9_W = 3840
UHD_16_9_H = 2160
UHD_9_16_W = 2160
UHD_9_16_H = 3840

# Frame-count buckets. Each value follows the "8 * k + 1" pattern expected by
# these video models' temporal compression.
NB_FRAMES_1 = 1
NB_FRAMES_9 = 8 + 1
NB_FRAMES_17 = 8 * 2 + 1
NB_FRAMES_33 = 8 * 4 + 1
NB_FRAMES_49 = 8 * 6 + 1
NB_FRAMES_65 = 8 * 8 + 1
NB_FRAMES_73 = 8 * 9 + 1
NB_FRAMES_81 = 8 * 10 + 1
NB_FRAMES_89 = 8 * 11 + 1
NB_FRAMES_97 = 8 * 12 + 1
NB_FRAMES_105 = 8 * 13 + 1
NB_FRAMES_113 = 8 * 14 + 1
NB_FRAMES_121 = 8 * 15 + 1
NB_FRAMES_129 = 8 * 16 + 1
NB_FRAMES_137 = 8 * 17 + 1
NB_FRAMES_145 = 8 * 18 + 1
NB_FRAMES_161 = 8 * 20 + 1
NB_FRAMES_177 = 8 * 22 + 1
NB_FRAMES_193 = 8 * 24 + 1
NB_FRAMES_201 = 8 * 25 + 1
NB_FRAMES_209 = 8 * 26 + 1
NB_FRAMES_217 = 8 * 27 + 1
NB_FRAMES_225 = 8 * 28 + 1
NB_FRAMES_233 = 8 * 29 + 1
NB_FRAMES_241 = 8 * 30 + 1
NB_FRAMES_249 = 8 * 31 + 1
NB_FRAMES_257 = 8 * 32 + 1
NB_FRAMES_265 = 8 * 33 + 1
NB_FRAMES_273 = 8 * 34 + 1
NB_FRAMES_289 = 8 * 36 + 1
NB_FRAMES_305 = 8 * 38 + 1
NB_FRAMES_321 = 8 * 40 + 1
NB_FRAMES_337 = 8 * 42 + 1
NB_FRAMES_353 = 8 * 44 + 1
NB_FRAMES_369 = 8 * 46 + 1
NB_FRAMES_385 = 8 * 48 + 1
NB_FRAMES_401 = 8 * 50 + 1
NB_FRAMES_417 = 8 * 52 + 1
NB_FRAMES_433 = 8 * 54 + 1
NB_FRAMES_449 = 8 * 56 + 1
NB_FRAMES_465 = 8 * 58 + 1
NB_FRAMES_481 = 8 * 60 + 1
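
# Import-time sanity check (a minimal sketch, safe to remove): every
# NB_FRAMES_<n> constant must equal <n>, i.e. follow the 8*k + 1 pattern.
for _name, _value in [(k, v) for k, v in globals().items() if k.startswith("NB_FRAMES_")]:
    assert _value == int(_name.rsplit("_", 1)[-1]), f"{_name} has wrong value {_value}"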

# Shared frame counts used by every training-bucket tier. Each tier pairs these
# with its fixed 16:9 resolution; bucket tuples are (frames, height, width).
TRAINING_BUCKET_FRAME_COUNTS = [
    NB_FRAMES_1, NB_FRAMES_9, NB_FRAMES_17, NB_FRAMES_33, NB_FRAMES_49,
    NB_FRAMES_65, NB_FRAMES_73, NB_FRAMES_81, NB_FRAMES_89, NB_FRAMES_97,
    NB_FRAMES_105, NB_FRAMES_113, NB_FRAMES_121, NB_FRAMES_129, NB_FRAMES_137,
    NB_FRAMES_145, NB_FRAMES_161, NB_FRAMES_177, NB_FRAMES_193, NB_FRAMES_201,
    NB_FRAMES_209, NB_FRAMES_217, NB_FRAMES_225, NB_FRAMES_233, NB_FRAMES_241,
    NB_FRAMES_249, NB_FRAMES_257, NB_FRAMES_265, NB_FRAMES_273, NB_FRAMES_289,
    NB_FRAMES_305, NB_FRAMES_321, NB_FRAMES_337, NB_FRAMES_353, NB_FRAMES_369,
    NB_FRAMES_385, NB_FRAMES_401, NB_FRAMES_417, NB_FRAMES_433, NB_FRAMES_449,
    NB_FRAMES_465, NB_FRAMES_481,
]

SD_TRAINING_BUCKETS = [(f, SD_16_9_H, SD_16_9_W) for f in TRAINING_BUCKET_FRAME_COUNTS]

HD_TRAINING_BUCKETS = [(f, HD_16_9_H, HD_16_9_W) for f in TRAINING_BUCKET_FRAME_COUNTS]

FHD_TRAINING_BUCKETS = [(f, FHD_16_9_H, FHD_16_9_W) for f in TRAINING_BUCKET_FRAME_COUNTS]

# UI resolution options; each label maps to the *name* of a bucket list above
RESOLUTION_OPTIONS = {
    "SD (1024x576)": "SD_TRAINING_BUCKETS",
    "HD (1280x720)": "HD_TRAINING_BUCKETS",
    "FHD (1920x1080)": "FHD_TRAINING_BUCKETS"
}
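
# A hypothetical resolver (sketch, not part of the original API) turning a
# RESOLUTION_OPTIONS label into the actual bucket list it names:
def resolve_training_buckets(label: str) -> List[Tuple[int, int, int]]:
    """Return the bucket list referenced by a RESOLUTION_OPTIONS label."""
    return globals()[RESOLUTION_OPTIONS[label]]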

# Per-model-family training defaults, keyed by training type
HUNYUAN_VIDEO_DEFAULTS = {
    "lora": {
        "learning_rate": 2e-5,
        "flow_weighting_scheme": "none",
        "lora_rank": DEFAULT_LORA_RANK_STR,
        "lora_alpha": DEFAULT_LORA_ALPHA_STR
    },
    "control-lora": {
        "learning_rate": 2e-5,
        "flow_weighting_scheme": "none",
        "lora_rank": "128",
        "lora_alpha": "128",
        "control_type": "custom",
        "train_qk_norm": True,
        "frame_conditioning_type": "index",
        "frame_conditioning_index": 0,
        "frame_conditioning_concatenate_mask": True
    }
}

LTX_VIDEO_DEFAULTS = {
    "lora": {
        "learning_rate": DEFAULT_LEARNING_RATE,
        "flow_weighting_scheme": "none",
        "lora_rank": DEFAULT_LORA_RANK_STR,
        "lora_alpha": DEFAULT_LORA_ALPHA_STR
    },
    "full-finetune": {
        "learning_rate": DEFAULT_LEARNING_RATE,
        "flow_weighting_scheme": "logit_normal"
    },
    "control-lora": {
        "learning_rate": DEFAULT_LEARNING_RATE,
        "flow_weighting_scheme": "logit_normal",
        "lora_rank": "128",
        "lora_alpha": "128",
        "control_type": "custom",
        "train_qk_norm": True,
        "frame_conditioning_type": "index",
        "frame_conditioning_index": 0,
        "frame_conditioning_concatenate_mask": True
    }
}

WAN_DEFAULTS = {
    "lora": {
        "learning_rate": 5e-5,
        "flow_weighting_scheme": "logit_normal",
        "lora_rank": "32",
        "lora_alpha": "32"
    },
    "control-lora": {
        "learning_rate": 5e-5,
        "flow_weighting_scheme": "logit_normal",
        "lora_rank": "32",
        "lora_alpha": "32",
        "control_type": "custom",
        "train_qk_norm": True,
        "frame_conditioning_type": "index",
        "frame_conditioning_index": 0,
        "frame_conditioning_concatenate_mask": True
    }
}
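
# A hypothetical lookup helper (sketch, not part of the original API) showing
# how the per-family defaults above are meant to be selected:
def get_model_defaults(model_family: str, training_type: str) -> Dict[str, Any]:
    """Return default hyperparameters for a (model family, training type) pair."""
    table = {
        "hunyuan_video": HUNYUAN_VIDEO_DEFAULTS,
        "ltx_video": LTX_VIDEO_DEFAULTS,
        "wan": WAN_DEFAULTS,
    }
    return table.get(model_family, {}).get(training_type, {})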

@dataclass
class TrainingConfig:
    """Configuration class for finetrainers training"""

    # Required arguments
    model_name: str
    pretrained_model_name_or_path: str
    data_root: str
    output_dir: str

    # Model arguments
    revision: Optional[str] = None
    version: Optional[str] = None
    cache_dir: Optional[str] = None

    # Dataset arguments
    video_column: str = "videos.txt"
    caption_column: str = "prompts.txt"

    id_token: Optional[str] = None
    video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SD_TRAINING_BUCKETS)
    video_reshape_mode: str = "center"
    caption_dropout_p: float = DEFAULT_CAPTION_DROPOUT_P
    caption_dropout_technique: str = "empty"
    precompute_conditions: bool = False

    # Diffusion arguments
    flow_resolution_shifting: bool = False
    flow_weighting_scheme: str = "none"
    flow_logit_mean: float = 0.0
    flow_logit_std: float = 1.0
    flow_mode_scale: float = 1.29

    # Training arguments
    training_type: str = "lora"
    seed: int = DEFAULT_SEED
    mixed_precision: str = "bf16"
    batch_size: int = 1
    train_steps: int = DEFAULT_NB_TRAINING_STEPS
    lora_rank: int = DEFAULT_LORA_RANK
    lora_alpha: int = DEFAULT_LORA_ALPHA
    target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    checkpointing_steps: int = DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS
    checkpointing_limit: Optional[int] = 2
    resume_from_checkpoint: Optional[str] = None
    enable_slicing: bool = True
    enable_tiling: bool = True

    # Optimizer arguments
    optimizer: str = "adamw"
    lr: float = DEFAULT_LEARNING_RATE
    scale_lr: bool = False
    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS
    lr_num_cycles: int = 1
    lr_power: float = 1.0
    beta1: float = 0.9
    beta2: float = 0.95
    weight_decay: float = 1e-4
    epsilon: float = 1e-8
    max_grad_norm: float = 1.0

    # Miscellaneous arguments
    tracker_name: str = "finetrainers"
    report_to: str = "wandb"
    nccl_timeout: int = 1800
    @classmethod
    def hunyuan_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for Hunyuan video-to-video LoRA training"""
        return cls(
            model_name="hunyuan_video",
            pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=2e-5,
            gradient_checkpointing=True,
            id_token=None,
            gradient_accumulation_steps=1,
            lora_rank=DEFAULT_LORA_RANK,
            lora_alpha=DEFAULT_LORA_ALPHA,
            video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
            flow_weighting_scheme="none",
            training_type="lora"
        )

    @classmethod
    def ltx_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for LTX-Video LoRA training"""
        return cls(
            model_name="ltx_video",
            pretrained_model_name_or_path="Lightricks/LTX-Video",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=DEFAULT_LEARNING_RATE,
            gradient_checkpointing=True,
            id_token=None,
            gradient_accumulation_steps=4,
            lora_rank=DEFAULT_LORA_RANK,
            lora_alpha=DEFAULT_LORA_ALPHA,
            video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
            flow_weighting_scheme="logit_normal",
            training_type="lora"
        )

    @classmethod
    def ltx_video_full_finetune(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for LTX-Video full finetune training"""
        return cls(
            model_name="ltx_video",
            pretrained_model_name_or_path="Lightricks/LTX-Video",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=1e-5,
            gradient_checkpointing=True,
            id_token=None,
            gradient_accumulation_steps=1,
            video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
            flow_weighting_scheme="logit_normal",
            training_type="full-finetune"
        )

    @classmethod
    def wan_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for Wan T2V LoRA training"""
        return cls(
            model_name="wan",
            pretrained_model_name_or_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_steps=DEFAULT_NB_TRAINING_STEPS,
            lr=5e-5,
            gradient_checkpointing=True,
            id_token=None,
            gradient_accumulation_steps=1,
            lora_rank=32,
            lora_alpha=32,
            target_modules=["blocks.*(to_q|to_k|to_v|to_out.0)"],
            video_resolution_buckets=buckets or SD_TRAINING_BUCKETS,
            caption_dropout_p=DEFAULT_CAPTION_DROPOUT_P,
            flow_weighting_scheme="logit_normal",
            training_type="lora"
        )
    def to_args_list(self) -> List[str]:
        """Convert config to command line arguments list"""
        args = []

        # Model arguments
        args.extend(["--model_name", self.model_name])
        args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path])
        if self.revision:
            args.extend(["--revision", self.revision])
        if self.version:
            args.extend(["--variant", self.version])
        if self.cache_dir:
            args.extend(["--cache_dir", self.cache_dir])

        # Dataset arguments
        args.extend(["--dataset_config", self.data_root])

        if self.id_token:
            args.extend(["--id_token", self.id_token])

        if self.video_resolution_buckets:
            bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
            args.extend(["--video_resolution_buckets"] + bucket_strs)

        args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
        args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
        if self.precompute_conditions:
            args.append("--precompute_conditions")

        if hasattr(self, 'precomputation_items') and self.precomputation_items:
            args.extend(["--precomputation_items", str(self.precomputation_items)])

        # Diffusion arguments
        if self.flow_resolution_shifting:
            args.append("--flow_resolution_shifting")
        args.extend(["--flow_weighting_scheme", self.flow_weighting_scheme])
        args.extend(["--flow_logit_mean", str(self.flow_logit_mean)])
        args.extend(["--flow_logit_std", str(self.flow_logit_std)])
        args.extend(["--flow_mode_scale", str(self.flow_mode_scale)])

        # Training arguments
        args.extend(["--training_type", self.training_type])
        args.extend(["--seed", str(self.seed)])

        args.extend(["--batch_size", str(self.batch_size)])
        args.extend(["--train_steps", str(self.train_steps)])

        # LoRA-specific arguments
        if self.training_type == "lora":
            args.extend(["--rank", str(self.lora_rank)])
            args.extend(["--lora_alpha", str(self.lora_alpha)])
            args.extend(["--target_modules"] + self.target_modules)

        args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
        if self.gradient_checkpointing:
            args.append("--gradient_checkpointing")
        args.extend(["--checkpointing_steps", str(self.checkpointing_steps)])
        if self.checkpointing_limit:
            args.extend(["--checkpointing_limit", str(self.checkpointing_limit)])
        if self.resume_from_checkpoint:
            args.extend(["--resume_from_checkpoint", self.resume_from_checkpoint])
        if self.enable_slicing:
            args.append("--enable_slicing")
        if self.enable_tiling:
            args.append("--enable_tiling")

        # Optimizer arguments
        args.extend(["--optimizer", self.optimizer])
        args.extend(["--lr", str(self.lr)])
        if self.scale_lr:
            args.append("--scale_lr")
        args.extend(["--lr_scheduler", self.lr_scheduler])
        args.extend(["--lr_warmup_steps", str(self.lr_warmup_steps)])
        args.extend(["--lr_num_cycles", str(self.lr_num_cycles)])
        args.extend(["--lr_power", str(self.lr_power)])
        args.extend(["--beta1", str(self.beta1)])
        args.extend(["--beta2", str(self.beta2)])
        args.extend(["--weight_decay", str(self.weight_decay)])
        args.extend(["--epsilon", str(self.epsilon)])
        args.extend(["--max_grad_norm", str(self.max_grad_norm)])

        # Miscellaneous arguments
        args.extend(["--tracker_name", self.tracker_name])
        args.extend(["--output_dir", self.output_dir])
        args.extend(["--report_to", self.report_to])
        args.extend(["--nccl_timeout", str(self.nccl_timeout)])

        # Caption handling: always ask the trainer to strip common LLM caption prefixes
        args.append("--remove_common_llm_caption_prefixes")

        return args
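
# Usage sketch: build a preset config and turn it into a trainer invocation.
# The "train.py" entry point shown here is an assumption; the argument names
# come from to_args_list above.
#
#   training_path, _, output_path, _ = get_project_paths(project_id)
#   config = TrainingConfig.ltx_video_lora(
#       data_path=str(training_path),
#       output_path=str(output_path),
#       buckets=SD_TRAINING_BUCKETS,
#   )
#   cmd = ["accelerate", "launch", "train.py", *config.to_args_list()]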