import os

from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List, Tuple
from pathlib import Path


def parse_bool_env(env_value: Optional[str]) -> bool:
    """Parse an environment variable string into a boolean.

    Handles various true/false string representations:
    - True: "true", "1", "t", "y", "yes" (case-insensitive)
    - False: anything else, including "false", "0", "" and None
    """
    if not env_value:
        return False
    return str(env_value).lower() in ('true', '1', 't', 'y', 'yes')
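
# Illustrative behavior (follows directly from the implementation above):
#   parse_bool_env("TRUE") -> True
#   parse_bool_env("yes")  -> True
#   parse_bool_env("0")    -> False
#   parse_bool_env(None)   -> False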


HF_API_TOKEN = os.getenv("HF_API_TOKEN")
ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE"))

STORAGE_PATH = Path(os.environ.get('STORAGE_PATH', '.data'))

VIDEOS_TO_SPLIT_PATH = STORAGE_PATH / "videos_to_split"
STAGING_PATH = STORAGE_PATH / "staging"
TRAINING_PATH = STORAGE_PATH / "training"
TRAINING_VIDEOS_PATH = TRAINING_PATH / "videos"
MODEL_PATH = STORAGE_PATH / "model"
OUTPUT_PATH = STORAGE_PATH / "output"
LOG_FILE_PATH = OUTPUT_PATH / "last_session.log"

PRELOAD_CAPTIONING_MODEL = parse_bool_env(os.environ.get('PRELOAD_CAPTIONING_MODEL'))
CAPTIONING_MODEL = "lmms-lab/LLaVA-Video-7B-Qwen2"

DEFAULT_PROMPT_PREFIX = "In the style of TOK, "

USE_MOCK_CAPTIONING_MODEL = parse_bool_env(os.environ.get('USE_MOCK_CAPTIONING_MODEL'))

DEFAULT_CAPTIONING_BOT_INSTRUCTIONS = "Please write a full video description. Be concise, don't say things like \"this video features..\" etc. Instead, methodically list camera (close-up shot, medium-shot..), genre (music video, horror movie scene, video game footage, go pro footage, japanese anime, noir film, science-fiction, action movie, documentary..), characters (physical appearance, look, skin, facial features, haircut, clothing), scene (action, positions, movements), location (indoor, outdoor, place, building, country..), time and lighting (natural, golden hour, night time, LED lights, kelvin temperature etc), weather and climate (dusty, rainy, fog, haze, snowing..), era/settings."

# Ensure all storage directories exist
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
VIDEOS_TO_SPLIT_PATH.mkdir(parents=True, exist_ok=True)
STAGING_PATH.mkdir(parents=True, exist_ok=True)
TRAINING_PATH.mkdir(parents=True, exist_ok=True)
TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
MODEL_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

NORMALIZE_IMAGES_TO = os.environ.get('NORMALIZE_IMAGES_TO', 'png').lower()
if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
    raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))

MODEL_TYPES = {
    "HunyuanVideo (LoRA)": "hunyuan_video",
    "LTX-Video (LoRA)": "ltx_video"
}

# Resolution of the small (default) training buckets: 768x512
SMALL_RATIO_WIDTH = 768
SMALL_RATIO_HEIGHT = 512

# Frame counts follow the 8k + 1 pattern used by these video models.
# Note that the names refer to the nominal multiple-of-8 length, so the
# actual values are one frame longer, e.g. NB_FRAMES_32 = 8 * 4 + 1 = 33.
NB_FRAMES_1 = 1
NB_FRAMES_9 = 8 + 1
NB_FRAMES_17 = 8 * 2 + 1
NB_FRAMES_32 = 8 * 4 + 1
NB_FRAMES_48 = 8 * 6 + 1
NB_FRAMES_64 = 8 * 8 + 1
NB_FRAMES_80 = 8 * 10 + 1
NB_FRAMES_96 = 8 * 12 + 1
NB_FRAMES_112 = 8 * 14 + 1
NB_FRAMES_128 = 8 * 16 + 1
NB_FRAMES_144 = 8 * 18 + 1
NB_FRAMES_160 = 8 * 20 + 1
NB_FRAMES_176 = 8 * 22 + 1
NB_FRAMES_192 = 8 * 24 + 1
NB_FRAMES_224 = 8 * 28 + 1
NB_FRAMES_256 = 8 * 32 + 1

# All bucket ladders share the same sequence of frame counts.
BUCKET_FRAME_COUNTS = [
    NB_FRAMES_1, NB_FRAMES_9, NB_FRAMES_17, NB_FRAMES_32, NB_FRAMES_48,
    NB_FRAMES_64, NB_FRAMES_80, NB_FRAMES_96, NB_FRAMES_112, NB_FRAMES_128,
    NB_FRAMES_144, NB_FRAMES_160, NB_FRAMES_176, NB_FRAMES_192, NB_FRAMES_224,
    NB_FRAMES_256,
]

# Buckets are (num_frames, height, width) triplets.
SMALL_TRAINING_BUCKETS = [
    (frames, SMALL_RATIO_HEIGHT, SMALL_RATIO_WIDTH)
    for frames in BUCKET_FRAME_COUNTS
]

# Resolution of the medium training buckets: 928x512
MEDIUM_19_9_RATIO_WIDTH = 928
MEDIUM_19_9_RATIO_HEIGHT = 512

MEDIUM_19_9_RATIO_BUCKETS = [
    (frames, MEDIUM_19_9_RATIO_HEIGHT, MEDIUM_19_9_RATIO_WIDTH)
    for frames in BUCKET_FRAME_COUNTS
]

TRAINING_PRESETS = {
    "HunyuanVideo (normal)": {
        "model_type": "hunyuan_video",
        "lora_rank": "128",
        "lora_alpha": "128",
        "num_epochs": 70,
        "batch_size": 1,
        "learning_rate": 2e-5,
        "save_iterations": 500,
        "training_buckets": SMALL_TRAINING_BUCKETS,
    },
    "LTX-Video (normal)": {
        "model_type": "ltx_video",
        "lora_rank": "128",
        "lora_alpha": "128",
        "num_epochs": 70,
        "batch_size": 1,
        "learning_rate": 3e-5,
        "save_iterations": 500,
        "training_buckets": SMALL_TRAINING_BUCKETS,
    },
    "LTX-Video (16:9, HQ)": {
        "model_type": "ltx_video",
        "lora_rank": "256",
        "lora_alpha": "128",
        "num_epochs": 50,
        "batch_size": 1,
        "learning_rate": 3e-5,
        "save_iterations": 200,
        "training_buckets": MEDIUM_19_9_RATIO_BUCKETS,
    }
}
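
# Sketch (illustrative, not part of this module's API): wiring a preset into a
# TrainingConfig. Note that "lora_rank"/"lora_alpha" are stored as strings in
# the presets and must be cast to int. Kept as a comment because TrainingConfig
# is defined further below.
#
#   preset = TRAINING_PRESETS["LTX-Video (normal)"]
#   config = TrainingConfig.ltx_video_lora(
#       data_path=str(TRAINING_PATH),
#       output_path=str(OUTPUT_PATH),
#       buckets=preset["training_buckets"],
#   )
#   config.lora_rank = int(preset["lora_rank"])
#   config.lora_alpha = int(preset["lora_alpha"])
#   config.train_epochs = preset["num_epochs"]
#   config.lr = preset["learning_rate"]
#   config.checkpointing_steps = preset["save_iterations"]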


@dataclass
class TrainingConfig:
    """Configuration class for finetrainers training"""

    # Required arguments
    model_name: str
    pretrained_model_name_or_path: str
    data_root: str
    output_dir: str

    # Model arguments
    revision: Optional[str] = None
    variant: Optional[str] = None
    cache_dir: Optional[str] = None

    # Dataset arguments
    video_column: str = "videos.txt"
    caption_column: str = "prompts.txt"
    id_token: Optional[str] = None
    video_resolution_buckets: List[Tuple[int, int, int]] = field(default_factory=lambda: SMALL_TRAINING_BUCKETS)
    video_reshape_mode: str = "center"
    caption_dropout_p: float = 0.05
    caption_dropout_technique: str = "empty"
    precompute_conditions: bool = False

    # Flow matching arguments
    flow_resolution_shifting: bool = False
    flow_weighting_scheme: str = "none"
    flow_logit_mean: float = 0.0
    flow_logit_std: float = 1.0
    flow_mode_scale: float = 1.29

    # Training arguments
    training_type: str = "lora"
    seed: int = 42
    mixed_precision: str = "bf16"
    batch_size: int = 1
    train_epochs: int = 70
    lora_rank: int = 128
    lora_alpha: int = 128
    target_modules: List[str] = field(default_factory=lambda: ["to_q", "to_k", "to_v", "to_out.0"])
    gradient_accumulation_steps: int = 1
    gradient_checkpointing: bool = True
    checkpointing_steps: int = 500
    checkpointing_limit: Optional[int] = 2
    resume_from_checkpoint: Optional[str] = None
    enable_slicing: bool = True
    enable_tiling: bool = True

    # Optimizer arguments
    optimizer: str = "adamw"
    lr: float = 3e-5
    scale_lr: bool = False
    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 100
    lr_num_cycles: int = 1
    lr_power: float = 1.0
    beta1: float = 0.9
    beta2: float = 0.95
    weight_decay: float = 1e-4
    epsilon: float = 1e-8
    max_grad_norm: float = 1.0

    # Miscellaneous arguments
    tracker_name: str = "finetrainers"
    report_to: str = "wandb"
    nccl_timeout: int = 1800

    @classmethod
    def hunyuan_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for HunyuanVideo LoRA training"""
        return cls(
            model_name="hunyuan_video",
            pretrained_model_name_or_path="hunyuanvideo-community/HunyuanVideo",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_epochs=70,
            lr=2e-5,
            gradient_checkpointing=True,
            id_token="afkx",
            gradient_accumulation_steps=1,
            lora_rank=128,
            lora_alpha=128,
            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
            caption_dropout_p=0.05,
            flow_weighting_scheme="none"
        )

    @classmethod
    def ltx_video_lora(cls, data_path: str, output_path: str, buckets=None) -> 'TrainingConfig':
        """Configuration for LTX-Video LoRA training"""
        return cls(
            model_name="ltx_video",
            pretrained_model_name_or_path="Lightricks/LTX-Video",
            data_root=data_path,
            output_dir=output_path,
            batch_size=1,
            train_epochs=40,
            lr=3e-5,
            gradient_checkpointing=True,
            id_token="BW_STYLE",
            gradient_accumulation_steps=4,
            lora_rank=128,
            lora_alpha=128,
            video_resolution_buckets=buckets or SMALL_TRAINING_BUCKETS,
            caption_dropout_p=0.05,
            flow_weighting_scheme="logit_normal"
        )
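
    # Illustrative factory usage (the bucket override is just an example):
    #   cfg = TrainingConfig.ltx_video_lora(
    #       data_path=str(TRAINING_PATH),
    #       output_path=str(OUTPUT_PATH),
    #       buckets=MEDIUM_19_9_RATIO_BUCKETS,  # optional override
    #   )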

    def to_args_list(self) -> List[str]:
        """Convert config to command line arguments list"""
        args = []

        # Model arguments
        args.extend(["--model_name", self.model_name])
        args.extend(["--pretrained_model_name_or_path", self.pretrained_model_name_or_path])
        if self.revision:
            args.extend(["--revision", self.revision])
        if self.variant:
            args.extend(["--variant", self.variant])
        if self.cache_dir:
            args.extend(["--cache_dir", self.cache_dir])

        # Dataset arguments
        args.extend(["--data_root", self.data_root])
        args.extend(["--video_column", self.video_column])
        args.extend(["--caption_column", self.caption_column])
        if self.id_token:
            args.extend(["--id_token", self.id_token])
        if self.video_resolution_buckets:
            bucket_strs = [f"{f}x{h}x{w}" for f, h, w in self.video_resolution_buckets]
            args.extend(["--video_resolution_buckets"] + bucket_strs)
        if self.video_reshape_mode:
            args.extend(["--video_reshape_mode", self.video_reshape_mode])
        args.extend(["--caption_dropout_p", str(self.caption_dropout_p)])
        args.extend(["--caption_dropout_technique", self.caption_dropout_technique])
        if self.precompute_conditions:
            args.append("--precompute_conditions")

        # Flow matching arguments
        if self.flow_resolution_shifting:
            args.append("--flow_resolution_shifting")
        args.extend(["--flow_weighting_scheme", self.flow_weighting_scheme])
        args.extend(["--flow_logit_mean", str(self.flow_logit_mean)])
        args.extend(["--flow_logit_std", str(self.flow_logit_std)])
        args.extend(["--flow_mode_scale", str(self.flow_mode_scale)])

        # Training arguments
        args.extend(["--training_type", self.training_type])
        args.extend(["--seed", str(self.seed)])
        args.extend(["--batch_size", str(self.batch_size)])
        args.extend(["--train_epochs", str(self.train_epochs)])
        args.extend(["--rank", str(self.lora_rank)])
        args.extend(["--lora_alpha", str(self.lora_alpha)])
        args.extend(["--target_modules"] + self.target_modules)
        args.extend(["--gradient_accumulation_steps", str(self.gradient_accumulation_steps)])
        if self.gradient_checkpointing:
            args.append("--gradient_checkpointing")
        args.extend(["--checkpointing_steps", str(self.checkpointing_steps)])
        if self.checkpointing_limit:
            args.extend(["--checkpointing_limit", str(self.checkpointing_limit)])
        if self.resume_from_checkpoint:
            args.extend(["--resume_from_checkpoint", self.resume_from_checkpoint])
        if self.enable_slicing:
            args.append("--enable_slicing")
        if self.enable_tiling:
            args.append("--enable_tiling")

        # Optimizer arguments
        args.extend(["--optimizer", self.optimizer])
        args.extend(["--lr", str(self.lr)])
        if self.scale_lr:
            args.append("--scale_lr")
        args.extend(["--lr_scheduler", self.lr_scheduler])
        args.extend(["--lr_warmup_steps", str(self.lr_warmup_steps)])
        args.extend(["--lr_num_cycles", str(self.lr_num_cycles)])
        args.extend(["--lr_power", str(self.lr_power)])
        args.extend(["--beta1", str(self.beta1)])
        args.extend(["--beta2", str(self.beta2)])
        args.extend(["--weight_decay", str(self.weight_decay)])
        args.extend(["--epsilon", str(self.epsilon)])
        args.extend(["--max_grad_norm", str(self.max_grad_norm)])

        # Miscellaneous arguments
        args.extend(["--tracker_name", self.tracker_name])
        args.extend(["--output_dir", self.output_dir])
        args.extend(["--report_to", self.report_to])
        args.extend(["--nccl_timeout", str(self.nccl_timeout)])

        args.append("--remove_common_llm_caption_prefixes")

        return args
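

if __name__ == "__main__":
    # Smoke test (illustrative): build the default LTX-Video LoRA config and
    # print the CLI arguments it would pass to the trainer.
    demo_config = TrainingConfig.ltx_video_lora(
        data_path=str(TRAINING_PATH),
        output_path=str(OUTPUT_PATH),
    )
    print(" ".join(demo_config.to_args_list()))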