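"""Fine-tune a text-to-video diffusion UNet with decoupled spatial and temporal LoRA
adapters. Driven by an OmegaConf YAML config (see the ``__main__`` block at the bottom);
relies on the project-local ``models/`` and ``utils/`` packages imported below.
"""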
import argparse
import datetime
import logging
import inspect
import math
import os
import random
import gc
import copy

from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import diffusers
import transformers

from torchvision import transforms
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from models.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL
from diffusers import DDIMScheduler, TextToVideoSDPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.attention import BasicTransformerBlock

from transformers import CLIPTextModel, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder

from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
    ImageDataset, VideoFolderDataset, CachedDataset
from einops import rearrange, repeat
from utils.lora_handler import LoraHandler
from utils.lora import extract_lora_child_module
from utils.ddim_utils import ddim_inversion

import imageio
import numpy as np
already_printed_trainables = False

logger = get_logger(__name__, log_level="INFO")


def create_logging(logging, logger, accelerator):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)


def accelerate_set_verbose(accelerator):
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
def get_train_dataset(dataset_types, train_data, tokenizer):
    train_datasets = []

    # Loop through all available datasets, get the name, then add to list of data to process.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in dataset_types:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))

    if len(train_datasets) > 0:
        return train_datasets
    else:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")
def extend_datasets(datasets, dataset_items, extend=False):
    biggest_data_len = max(x.__len__() for x in datasets)
    extended = []
    for dataset in datasets:
        if dataset.__len__() == 0:
            del dataset
            continue
        if dataset.__len__() < biggest_data_len:
            for item in dataset_items:
                if extend and item not in extended and hasattr(dataset, item):
                    print(f"Extending {item}")

                    value = getattr(dataset, item)
                    value *= biggest_data_len
                    value = value[:biggest_data_len]

                    setattr(dataset, item, value)

                    print(f"New {item} dataset length: {dataset.__len__()}")
                    extended.append(item)
def export_to_video(video_frames, output_video_path, fps):
    video_writer = imageio.get_writer(output_video_path, fps=fps)
    for img in video_frames:
        video_writer.append_data(np.array(img))
    video_writer.close()


def create_output_folders(output_dir, config):
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    out_dir = os.path.join(output_dir, f"train_{now}")

    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(f"{out_dir}/samples", exist_ok=True)
    # OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))

    return out_dir
def load_primary_models(pretrained_model_path):
    noise_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    return noise_scheduler, tokenizer, text_encoder, vae, unet


def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
    unet._set_gradient_checkpointing(value=unet_enable)
    text_encoder._set_gradient_checkpointing(CLIPEncoder, value=text_enable)


def freeze_models(models_to_freeze):
    for model in models_to_freeze:
        if model is not None:
            model.requires_grad_(False)
def is_attn(name):
    # Note: `('attn1' or 'attn2' == ...)` short-circuits to the truthy string 'attn1',
    # so this returns a truthy value for every module name. The real filtering happens
    # in set_torch_2_attn via the isinstance checks below, so the behavior is kept as-is.
    return ('attn1' or 'attn2' == name.split('.')[-1])


def set_processors(attentions):
    for attn in attentions:
        attn.set_processor(AttnProcessor2_0())


def set_torch_2_attn(unet):
    optim_count = 0

    for name, module in unet.named_modules():
        if is_attn(name):
            if isinstance(module, torch.nn.ModuleList):
                for m in module:
                    if isinstance(m, BasicTransformerBlock):
                        set_processors([m.attn1, m.attn2])
                        optim_count += 1
    if optim_count > 0:
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")
def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet):
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
        enable_torch_2 = is_torch_2 and enable_torch_2_attn

        if enable_xformers_memory_efficient_attention and not enable_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")

        if enable_torch_2:
            set_torch_2_attn(unet)

    except Exception as e:
        print(f"Could not enable memory efficient attention for xformers or Torch 2.0: {e}")
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    # Guard against the default of None before checking whether any extra params were given.
    extra_params = extra_params if extra_params is not None and len(extra_params) > 0 else None
    return {
        "model": model,
        "condition": condition,
        'extra_params': extra_params,
        'is_lora': is_lora,
        "negation": negation
    }
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    params = {
        "name": name,
        "params": params,
        "lr": lr
    }
    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v

    return params


def negate_params(name, negation):
    # We have to do this if we are co-training with LoRA.
    # This ensures that parameter groups aren't duplicated.
    if negation is None:
        return False
    for n in negation:
        if n in name and 'temp' not in name:
            return True
    return False
def create_optimizer_params(model_list, lr):
    import itertools
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()
        # Check if we are doing LoRA training.
        if is_lora and condition and isinstance(model, list):
            params = create_optim_params(
                params=itertools.chain(*model),
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue

        if is_lora and condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n and not is_lora
                if should_negate:
                    continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)

    return optimizer_params
def get_optimizer(use_8bit_adam):
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        return bnb.optim.AdamW8bit
    else:
        return torch.optim.AdamW


def is_mixed_precision(accelerator):
    weight_dtype = torch.float32

    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    return weight_dtype


def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None:
            model.to(accelerator.device, dtype=weight_dtype)
def inverse_video(pipe, latents, num_steps):
    ddim_inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    ddim_inv_scheduler.set_timesteps(num_steps)

    ddim_inv_latent = ddim_inversion(
        pipe, ddim_inv_scheduler, video_latent=latents.to(pipe.device),
        num_inv_steps=num_steps, prompt="")[-1]
    return ddim_inv_latent
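# The inverted latents returned above are cached per clip by handle_cache_latents below
# (as 'inversion_noise') and later blended with fresh Gaussian noise during validation
# sampling when `noise_prior > 0`, so samples start from noise correlated with the
# training video rather than from purely random noise.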
def handle_cache_latents(
        should_cache,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        unet,
        pretrained_model_path,
        noise_prior,
        cached_latent_dir=None,
):
    # Cache latents by encoding them once up front and saving them to disk.
    # Speeds up training and saves memory by not encoding during the train loop.
    if not should_cache:
        return None
    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()

    pipe = TextToVideoSDPipeline.from_pretrained(
        pretrained_model_path,
        vae=vae,
        unet=copy.deepcopy(unet).to('cuda', dtype=torch.float16)
    )
    pipe.text_encoder.to('cuda', dtype=torch.float16)

    cached_latent_dir = (
        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None
    )

    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):
            save_name = f"cached_{i}"
            full_out_path = f"{cache_save_dir}/{save_name}.pt"

            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            batch['latents'] = tensor_to_vae_latent(pixel_values, vae)
            if noise_prior > 0.:
                batch['inversion_noise'] = inverse_video(pipe, batch['latents'], 50)
            for k, v in batch.items():
                batch[k] = v[0]

            torch.save(batch, full_out_path)
            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()
    else:
        cache_save_dir = cached_latent_dir

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir),
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=0
    )
def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
    global already_printed_trainables

    # This can most definitely be refactored :-)
    unfrozen_params = 0
    if trainable_modules is not None:
        for name, module in model.named_modules():
            for tm in tuple(trainable_modules):
                if tm == 'all':
                    model.requires_grad_(is_enabled)
                    unfrozen_params = len(list(model.parameters()))
                    break

                if tm in name and 'lora' not in name:
                    for m in module.parameters():
                        m.requires_grad_(is_enabled)
                        if is_enabled:
                            unfrozen_params += 1

    if unfrozen_params > 0 and not already_printed_trainables:
        already_printed_trainables = True
        print(f"{unfrozen_params} params have been unfrozen for training.")
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    # Scale by the Stable Diffusion VAE latent scaling factor.
    latents = latents * 0.18215

    return latents


def sample_noise(latents, noise_strength, use_offset_noise=False):
    b, c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)

    if use_offset_noise:
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents
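# Offset noise perturbs each (batch, channel, frame) slice with a shared constant shift,
# which helps the model learn global brightness changes that plain i.i.d. noise under-trains.
# It is treated as mutually exclusive with the zero-terminal-SNR rescale below: the training
# step disables offset noise whenever `rescale_schedule` is set.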
def enforce_zero_terminal_snr(betas):
    """
    Corrects noise in diffusion schedulers.
    From: Common Diffusion Noise Schedules and Sample Steps are Flawed
    https://arxiv.org/pdf/2305.08891.pdf
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
            alphas_bar_sqrt_0 - alphas_bar_sqrt_T
    )

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas
    return betas
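# After this rescale, sqrt(alphas_bar[-1]) == 0, i.e. the terminal timestep has zero SNR
# (pure noise), while sqrt(alphas_bar[0]) keeps its original value. The linear map
# x -> (x - sqrt_T) * sqrt_0 / (sqrt_0 - sqrt_T) pins both endpoints, and the betas are
# then recovered from the ratio of consecutive alphas_bar values.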
def should_sample(global_step, validation_steps, validation_data):
    return global_step % validation_steps == 0 and validation_data.sample_preview


def save_pipe(
        path,
        global_step,
        accelerator,
        unet,
        text_encoder,
        vae,
        output_dir,
        lora_manager_spatial: LoraHandler,
        lora_manager_temporal: LoraHandler,
        unet_target_replace_module=None,
        text_target_replace_module=None,
        is_checkpoint=False,
        save_pretrained_model=True
):
    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype

    # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
    unet_out = copy.deepcopy(accelerator.unwrap_model(unet.cpu(), keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder.cpu(), keep_fp32_wrapper=False))

    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float32)

    lora_manager_spatial.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path + '/spatial', step=global_step)
    lora_manager_temporal.save_lora_weights(model=copy.deepcopy(pipeline), save_path=save_path + '/temporal', step=global_step)

    if save_pretrained_model:
        pipeline.save_pretrained(save_path)

    if is_checkpoint:
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]

    logger.info(f"Saved model at {save_path} on step {global_step}")

    del pipeline
    del unet_out
    del text_encoder_out

    torch.cuda.empty_cache()
    gc.collect()
def main(
        pretrained_model_path: str,
        output_dir: str,
        train_data: Dict,
        validation_data: Dict,
        extra_train_data: list = [],
        dataset_types: Tuple[str] = ('json',),
        validation_steps: int = 100,
        trainable_modules: Tuple[str] = None,  # Eg: ("attn1", "attn2")
        extra_unet_params=None,
        train_batch_size: int = 1,
        max_train_steps: int = 500,
        learning_rate: float = 5e-5,
        lr_scheduler: str = "constant",
        lr_warmup_steps: int = 0,
        adam_beta1: float = 0.9,
        adam_beta2: float = 0.999,
        adam_weight_decay: float = 1e-2,
        adam_epsilon: float = 1e-08,
        gradient_accumulation_steps: int = 1,
        gradient_checkpointing: bool = False,
        text_encoder_gradient_checkpointing: bool = False,
        checkpointing_steps: int = 500,
        resume_from_checkpoint: Optional[str] = None,
        resume_step: Optional[int] = None,
        mixed_precision: Optional[str] = "fp16",
        use_8bit_adam: bool = False,
        enable_xformers_memory_efficient_attention: bool = True,
        enable_torch_2_attn: bool = False,
        seed: Optional[int] = None,
        use_offset_noise: bool = False,
        rescale_schedule: bool = False,
        offset_noise_strength: float = 0.1,
        extend_dataset: bool = False,
        cache_latents: bool = False,
        cached_latent_dir=None,
        use_unet_lora: bool = False,
        unet_lora_modules: Tuple[str] = [],
        text_encoder_lora_modules: Tuple[str] = [],
        save_pretrained_model: bool = True,
        lora_rank: int = 16,
        lora_path: str = '',
        lora_unet_dropout: float = 0.1,
        logger_type: str = 'tensorboard',
        **kwargs
):
    *_, config = inspect.getargvalues(inspect.currentframe())

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=logger_type,
        project_dir=output_dir
    )

    # Make one log on every process with the configuration for debugging.
    create_logging(logging, logger, accelerator)

    # Initialize accelerate, transformers, and diffusers warnings
    accelerate_set_verbose(accelerator)

    # Handle the output folder creation
    if accelerator.is_main_process:
        output_dir = create_output_folders(output_dir, config)

    # Load scheduler, tokenizer and models.
    noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)

    # Freeze any necessary models
    freeze_models([vae, text_encoder, unet])

    # Enable xformers if available
    handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet)

    # Initialize the optimizer
    optimizer_cls = get_optimizer(use_8bit_adam)

    # Get the training dataset based on types (json, single_video, image)
    train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)

    # If you have extra train data, you can add a list of however many you would like.
    # Eg: extra_train_data: [{'dataset_types': [...], 'train_data': {...}}, ...]
    try:
        if extra_train_data is not None and len(extra_train_data) > 0:
            for dataset in extra_train_data:
                d_t, t_d = dataset['dataset_types'], dataset['train_data']
                train_datasets += get_train_dataset(d_t, t_d, tokenizer)

    except Exception as e:
        print(f"Could not process extra train datasets due to an error : {e}")

    # Extend datasets that are less than the greatest one. This allows for more balanced training.
    attrs = ['train_data', 'frames', 'image_dir', 'video_files']
    extend_datasets(train_datasets, attrs, extend=extend_dataset)

    # Process one dataset
    if len(train_datasets) == 1:
        train_dataset = train_datasets[0]
    # Process many datasets
    else:
        train_dataset = torch.utils.data.ConcatDataset(train_datasets)
    # Create parameters to optimize over with a condition (if "condition" is true, optimize it)
    extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
    extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}

    # Use LoRA if enabled.
    # one temporal lora
    lora_manager_temporal = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["TransformerTemporalModel"])

    unet_lora_params_temporal, unet_negation_temporal = lora_manager_temporal.add_lora_to_model(
        use_unet_lora, unet, lora_manager_temporal.unet_replace_modules, lora_unet_dropout,
        lora_path + '/temporal/lora/', r=lora_rank)

    optimizer_temporal = optimizer_cls(
        create_optimizer_params([param_optim(unet_lora_params_temporal, use_unet_lora, is_lora=True,
                                             extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
                                             )], learning_rate),
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    lr_scheduler_temporal = get_scheduler(
        lr_scheduler,
        optimizer=optimizer_temporal,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # one spatial lora for each video
    if 'folder' in dataset_types:
        spatial_lora_num = train_dataset.__len__()
    else:
        spatial_lora_num = 1

    lora_manager_spatials = []
    unet_lora_params_spatial_list = []
    optimizer_spatial_list = []
    lr_scheduler_spatial_list = []
    for i in range(spatial_lora_num):
        lora_manager_spatial = LoraHandler(use_unet_lora=use_unet_lora, unet_replace_modules=["Transformer2DModel"])
        lora_manager_spatials.append(lora_manager_spatial)

        unet_lora_params_spatial, unet_negation_spatial = lora_manager_spatial.add_lora_to_model(
            use_unet_lora, unet, lora_manager_spatial.unet_replace_modules, lora_unet_dropout,
            lora_path + '/spatial/lora/', r=lora_rank)
        unet_lora_params_spatial_list.append(unet_lora_params_spatial)

        optimizer_spatial = optimizer_cls(
            create_optimizer_params([param_optim(unet_lora_params_spatial, use_unet_lora, is_lora=True,
                                                 extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
                                                 )], learning_rate),
            lr=learning_rate,
            betas=(adam_beta1, adam_beta2),
            weight_decay=adam_weight_decay,
            eps=adam_epsilon,
        )
        optimizer_spatial_list.append(optimizer_spatial)

        # Scheduler
        lr_scheduler_spatial = get_scheduler(
            lr_scheduler,
            optimizer=optimizer_spatial,
            num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
            num_training_steps=max_train_steps * gradient_accumulation_steps,
        )
        lr_scheduler_spatial_list.append(lr_scheduler_spatial)

    unet_negation_all = unet_negation_spatial + unet_negation_temporal
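    # Design note: a single temporal LoRA (injected into the TransformerTemporalModel
    # blocks) is shared across all training videos, while one spatial LoRA (Transformer2DModel
    # blocks) is created per video when the 'folder' dataset type is used. Each spatial LoRA
    # gets its own optimizer and LR scheduler so per-clip appearance can be fit without
    # leaking into the shared motion adapter.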
    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True
    )

    # Latents caching
    cached_data_loader = handle_cache_latents(
        cache_latents,
        output_dir,
        train_dataloader,
        train_batch_size,
        vae,
        unet,
        pretrained_model_path,
        validation_data.noise_prior,
        cached_latent_dir,
    )
    if cached_data_loader is not None:
        train_dataloader = cached_data_loader

    # Prepare everything with our `accelerator`.
    unet, optimizer_spatial_list, optimizer_temporal, train_dataloader, lr_scheduler_spatial_list, lr_scheduler_temporal, text_encoder = accelerator.prepare(
        unet,
        optimizer_spatial_list, optimizer_temporal,
        train_dataloader,
        lr_scheduler_spatial_list, lr_scheduler_temporal,
        text_encoder
    )

    # Use Gradient Checkpointing if enabled.
    unet_and_text_g_c(
        unet,
        text_encoder,
        gradient_checkpointing,
        text_encoder_gradient_checkpointing
    )

    # Enable VAE slicing to save memory.
    vae.enable_slicing()

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference; keeping weights in full precision is not required.
    weight_dtype = is_mixed_precision(accelerator)

    # Move text encoder and VAE to GPU.
    models_to_cast = [text_encoder, vae]
    cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)

    # Fix noise schedules to predict light and dark areas if available.
    if not use_offset_noise and rescale_schedule:
        noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

    # Afterwards we recalculate our number of training epochs.
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("text2video-fine-tune")

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    def finetune_unet(batch, step, mask_spatial_lora=False, mask_temporal_lora=False):
        nonlocal use_offset_noise
        nonlocal rescale_schedule
        global already_printed_trainables

        # Unfreeze UNET Layers
        if global_step == 0:
            already_printed_trainables = False
            unet.train()
            handle_trainable_modules(
                unet,
                trainable_modules,
                is_enabled=True,
                negation=unet_negation_all
            )

        # Convert videos to latent space
        if not cache_latents:
            latents = tensor_to_vae_latent(batch["pixel_values"], vae)
        else:
            latents = batch["latents"]

        # Sample noise that we'll add to the latents
        use_offset_noise = use_offset_noise and not rescale_schedule
        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
        bsz = latents.shape[0]

        # Sample a random timestep for each video
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # *Potentially* Fixes gradient checkpointing training.
        # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
        if kwargs.get('eval_train', False):
            unet.eval()
            text_encoder.eval()

        # Encode text embeddings
        token_ids = batch['prompt_ids']
        encoder_hidden_states = text_encoder(token_ids)[0]
        detached_encoder_state = encoder_hidden_states.clone().detach()

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        encoder_hidden_states = detached_encoder_state

        if mask_spatial_lora:
            loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
            for lora_i in loras:
                lora_i.scale = 0.
            loss_spatial = None
        else:
            loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
            for lora_i in loras:
                lora_i.scale = 1.

            for lora_idx in range(0, len(loras), spatial_lora_num):
                loras[lora_idx + step].scale = 1.

            loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
            for lora_i in loras:
                lora_i.scale = 0.

            ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item()

            # Note: `random.uniform(0, 1) < -0.5` can never be true, so the horizontal-flip
            # branch below is effectively disabled and the else branch always runs.
            if random.uniform(0, 1) < -0.5:
                pixel_values_spatial = transforms.functional.hflip(
                    batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1)
                latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae)
                noise_spatial = sample_noise(latents_spatial, offset_noise_strength, use_offset_noise)
                noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps)

                target_spatial = noise_spatial
                model_pred_spatial = unet(noisy_latents_input, timesteps,
                                          encoder_hidden_states=encoder_hidden_states).sample
                loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                          target_spatial[:, :, 0, :, :].float(), reduction="mean")
            else:
                noisy_latents_input = noisy_latents[:, :, ran_idx, :, :]
                target_spatial = target[:, :, ran_idx, :, :]
                model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps,
                                          encoder_hidden_states=encoder_hidden_states).sample
                loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                          target_spatial.float(), reduction="mean")

        if mask_temporal_lora:
            loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
            for lora_i in loras:
                lora_i.scale = 0.
            loss_temporal = None
        else:
            loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
            for lora_i in loras:
                lora_i.scale = 1.

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            beta = 1
            alpha = (beta ** 2 + 1) ** 0.5
            ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item()
            model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2)
            target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2)
            loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
            loss_temporal = loss_temporal + loss_ad_temporal

        return loss_spatial, loss_temporal, latents, noise
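    # In finetune_unet above, the temporal branch adds a second MSE on "debiased" tensors:
    # a randomly chosen anchor frame, scaled by beta, is subtracted from every frame of both
    # the prediction and the target, with the result scaled by alpha = sqrt(beta**2 + 1).
    # Content shared with the anchor frame largely cancels, so this extra term weights
    # frame-to-frame differences (motion) more heavily than per-frame appearance.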
    for epoch in range(first_epoch, num_train_epochs):
        train_loss_spatial = 0.0
        train_loss_temporal = 0.0

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                text_prompt = batch['text_prompt'][0]

                for optimizer_spatial in optimizer_spatial_list:
                    optimizer_spatial.zero_grad(set_to_none=True)

                optimizer_temporal.zero_grad(set_to_none=True)

                mask_temporal_lora = False
                # mask_spatial_lora = False
                mask_spatial_lora = random.uniform(0, 1) < 0.1 and not mask_temporal_lora

                with accelerator.autocast():
                    loss_spatial, loss_temporal, latents, init_noise = finetune_unet(
                        batch, step, mask_spatial_lora=mask_spatial_lora, mask_temporal_lora=mask_temporal_lora)

                # Gather the losses across all processes for logging (if we use distributed training).
                if not mask_spatial_lora:
                    avg_loss_spatial = accelerator.gather(loss_spatial.repeat(train_batch_size)).mean()
                    train_loss_spatial += avg_loss_spatial.item() / gradient_accumulation_steps

                if not mask_temporal_lora:
                    avg_loss_temporal = accelerator.gather(loss_temporal.repeat(train_batch_size)).mean()
                    train_loss_temporal += avg_loss_temporal.item() / gradient_accumulation_steps

                # Backpropagate
                if not mask_spatial_lora:
                    accelerator.backward(loss_spatial, retain_graph=True)
                    optimizer_spatial_list[step].step()

                if not mask_temporal_lora:
                    accelerator.backward(loss_temporal)
                    optimizer_temporal.step()

                lr_scheduler_spatial_list[step].step()
                lr_scheduler_temporal.step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss_temporal}, step=global_step)
                train_loss_temporal = 0.0

                if global_step % checkpointing_steps == 0 and global_step > 0:
                    save_pipe(
                        pretrained_model_path,
                        global_step,
                        accelerator,
                        unet,
                        text_encoder,
                        vae,
                        output_dir,
                        lora_manager_spatial,
                        lora_manager_temporal,
                        unet_lora_modules,
                        text_encoder_lora_modules,
                        is_checkpoint=True,
                        save_pretrained_model=save_pretrained_model
                    )

                if should_sample(global_step, validation_steps, validation_data):
                    if accelerator.is_main_process:
                        with accelerator.autocast():
                            unet.eval()
                            text_encoder.eval()
                            unet_and_text_g_c(unet, text_encoder, False, False)

                            loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
                            for lora_i in loras:
                                lora_i.scale = validation_data.spatial_scale

                            if validation_data.noise_prior > 0:
                                preset_noise = (validation_data.noise_prior) ** 0.5 * batch['inversion_noise'] + (
                                        1 - validation_data.noise_prior) ** 0.5 * torch.randn_like(batch['inversion_noise'])
                            else:
                                preset_noise = None

                            pipeline = TextToVideoSDPipeline.from_pretrained(
                                pretrained_model_path,
                                text_encoder=text_encoder,
                                vae=vae,
                                unet=unet
                            )

                            diffusion_scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
                            pipeline.scheduler = diffusion_scheduler

                            # Fall back to the batch prompt (wrapped in a list) when no validation prompts are given.
                            prompt_list = [text_prompt] if len(validation_data.prompt) <= 0 else validation_data.prompt
                            for prompt in prompt_list:
                                save_filename = f"{global_step}_{prompt.replace('.', '')}"
                                out_file = f"{output_dir}/samples/{save_filename}.mp4"

                                with torch.no_grad():
                                    video_frames = pipeline(
                                        prompt,
                                        width=validation_data.width,
                                        height=validation_data.height,
                                        num_frames=validation_data.num_frames,
                                        num_inference_steps=validation_data.num_inference_steps,
                                        guidance_scale=validation_data.guidance_scale,
                                        latents=preset_noise
                                    ).frames
                                export_to_video(video_frames, out_file, train_data.get('fps', 8))
                                logger.info(f"Saved a new sample to {out_file}")

                            del pipeline
                            torch.cuda.empty_cache()

                    unet_and_text_g_c(
                        unet,
                        text_encoder,
                        gradient_checkpointing,
                        text_encoder_gradient_checkpointing
                    )

            accelerator.log({"loss_temporal": loss_temporal.detach().item()}, step=step)

            if global_step >= max_train_steps:
                break
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_pipe(
            pretrained_model_path,
            global_step,
            accelerator,
            unet,
            text_encoder,
            vae,
            output_dir,
            lora_manager_spatial,
            lora_manager_temporal,
            unet_lora_modules,
            text_encoder_lora_modules,
            is_checkpoint=False,
            save_pretrained_model=save_pretrained_model
        )
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='./configs/config_multi_videos.yaml')
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
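# Typical invocation (script name assumed), driven entirely by the YAML config:
#   python train.py --config ./configs/config_multi_videos.yaml
# Every top-level key in the YAML is passed to main() as a keyword argument via OmegaConf.load.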