import os
import random
from collections import OrderedDict
from typing import Union, Literal, List, Optional

import numpy as np
from diffusers import T2IAdapter, AutoencoderTiny, ControlNetModel
from safetensors.torch import load_file
from torch.utils.data import DataLoader, ConcatDataset

from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GenerateImageConfig
from toolkit.data_loader import get_dataloader_datasets
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss, GuidanceType
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.print import print_acc
from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
    apply_learnable_snr_gos, LearnableSNRGamma
import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
from toolkit.train_tools import precondition_model_outputs_flow_match
from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe
from toolkit.util.wavelet_loss import wavelet_loss
import torch.nn.functional as F
from toolkit.unloader import unload_text_encoder
from PIL import Image
from torchvision.transforms import functional as TF

def flush():
    torch.cuda.empty_cache()
    gc.collect()


adapter_transforms = transforms.Compose([
    transforms.ToTensor(),
])

class SDTrainer(BaseSDTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        super().__init__(process_id, job, config, **kwargs)
        self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None]
        self.do_prior_prediction = False
        self.do_long_prompts = False
        self.do_guided_loss = False
        self.taesd: Optional[AutoencoderTiny] = None

        self._clip_image_embeds_unconditional: Union[List[str], None] = None
        self.negative_prompt_pool: Union[List[str], None] = None
        self.batch_negative_prompt: Union[List[str], None] = None

        self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"

        self.do_grad_scale = True
        if self.is_fine_tuning and self.is_bfloat:
            self.do_grad_scale = False
        if self.adapter_config is not None:
            if self.adapter_config.train:
                self.do_grad_scale = False

        # if self.train_config.dtype in ["fp16", "float16"]:
        #     # patch the scaler to allow fp16 training
        #     org_unscale_grads = self.scaler._unscale_grads_
        #     def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        #         return org_unscale_grads(optimizer, inv_scale, found_inf, True)
        #     self.scaler._unscale_grads_ = _unscale_grads_replacer

        self.cached_blank_embeds: Optional[PromptEmbeds] = None
        self.cached_trigger_embeds: Optional[PromptEmbeds] = None
        self.diff_output_preservation_embeds: Optional[PromptEmbeds] = None
        self.dfe: Optional[DiffusionFeatureExtractor] = None
        self.unconditional_embeds = None

        if self.train_config.diff_output_preservation:
            if self.trigger_word is None:
                raise ValueError("diff_output_preservation requires a trigger_word to be set")
            if self.network_config is None:
                raise ValueError("diff_output_preservation requires a network to be set")
            if self.train_config.train_text_encoder:
                raise ValueError("diff_output_preservation is not supported with train_text_encoder")
            # always do a prior prediction when doing diff output preservation
            self.do_prior_prediction = True

        # store the loss target for a batch so we can use it in a loss
        self._guidance_loss_target_batch: float = 0.0
        if isinstance(self.train_config.guidance_loss_target, (int, float)):
            self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target)
        elif isinstance(self.train_config.guidance_loss_target, list):
            self._guidance_loss_target_batch = float(self.train_config.guidance_loss_target[0])
        else:
            raise ValueError(f"Unknown guidance loss target type {type(self.train_config.guidance_loss_target)}")

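    # Illustrative note (not a config reference): per the normalization above,
    # `guidance_loss_target` may be a single number or a list, and only the first
    # element seeds the per-batch value, e.g.
    #   guidance_loss_target: 3.5        -> _guidance_loss_target_batch == 3.5
    #   guidance_loss_target: [1.0, 4.0] -> _guidance_loss_target_batch == 1.0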
    def before_model_load(self):
        pass

    def cache_sample_prompts(self):
        if self.train_config.disable_sampling:
            return
        if self.sample_config is not None and self.sample_config.samples is not None and len(self.sample_config.samples) > 0:
            # cache all the samples
            self.sd.sample_prompts_cache = []
            sample_folder = os.path.join(self.save_root, 'samples')
            output_path = os.path.join(sample_folder, 'test.jpg')
            for i in range(len(self.sample_config.prompts)):
                sample_item = self.sample_config.samples[i]
                prompt = self.sample_config.prompts[i]
                # needed so we can autoparse the prompt to handle flags
                gen_img_config = GenerateImageConfig(
                    prompt=prompt,  # it will autoparse the prompt
                    negative_prompt=sample_item.neg,
                    output_path=output_path,
                    ctrl_img=sample_item.ctrl_img
                )
                # see if we need to encode the control images
                if self.sd.encode_control_in_text_embeddings and gen_img_config.ctrl_img is not None:
                    ctrl_img = Image.open(gen_img_config.ctrl_img).convert("RGB")
                    # convert to 0 to 1 tensor
                    ctrl_img = (
                        TF.to_tensor(ctrl_img)
                        .unsqueeze(0)
                        .to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    )
                    positive = self.sd.encode_prompt(
                        gen_img_config.prompt,
                        control_images=ctrl_img
                    ).to('cpu')
                    negative = self.sd.encode_prompt(
                        gen_img_config.negative_prompt,
                        control_images=ctrl_img
                    ).to('cpu')
                else:
                    positive = self.sd.encode_prompt(gen_img_config.prompt).to('cpu')
                    negative = self.sd.encode_prompt(gen_img_config.negative_prompt).to('cpu')
                self.sd.sample_prompts_cache.append({
                    'conditional': positive,
                    'unconditional': negative
                })

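    # Rough shape of each cache entry appended above (a sketch of the dict built in
    # this method): {'conditional': <PromptEmbeds on cpu>, 'unconditional': <PromptEmbeds on cpu>}.
    # Caching these up front lets sampling run later even if the text encoder has
    # been unloaded or moved off the GPU.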
    def before_dataset_load(self):
        self.assistant_adapter = None
        # get adapter assistant if one is set
        if self.train_config.adapter_assist_name_or_path is not None:
            adapter_path = self.train_config.adapter_assist_name_or_path
            if self.train_config.adapter_assist_type == "t2i":
                # don't name this adapter since we are not training it
                self.assistant_adapter = T2IAdapter.from_pretrained(
                    adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
                ).to(self.device_torch)
            elif self.train_config.adapter_assist_type == "control_net":
                self.assistant_adapter = ControlNetModel.from_pretrained(
                    adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype)
                ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
            else:
                raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}")
            self.assistant_adapter.eval()
            self.assistant_adapter.requires_grad_(False)
            flush()
        if self.train_config.train_turbo and self.train_config.show_turbo_outputs:
            if self.model_config.is_xl:
                self.taesd = AutoencoderTiny.from_pretrained(
                    "madebyollin/taesdxl", torch_dtype=get_torch_dtype(self.train_config.dtype)
                )
            else:
                self.taesd = AutoencoderTiny.from_pretrained(
                    "madebyollin/taesd", torch_dtype=get_torch_dtype(self.train_config.dtype)
                )
            self.taesd.to(dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch)
            self.taesd.eval()
            self.taesd.requires_grad_(False)

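    # Hypothetical config sketch for the assistant adapter loaded above (key names
    # taken from the train_config reads in this method; the path is a placeholder):
    #   train:
    #     adapter_assist_name_or_path: "path/or/hub-id-of-adapter"
    #     adapter_assist_type: "t2i"   # or "control_net"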
    def hook_before_train_loop(self):
        super().hook_before_train_loop()

        if self.is_caching_text_embeddings:
            # make sure model is on cpu for this part so we don't oom.
            self.sd.unet.to('cpu')

        # cache unconditional embeds (blank prompt)
        with torch.no_grad():
            kwargs = {}
            if self.sd.encode_control_in_text_embeddings:
                # just do a blank image for unconditionals
                control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
                kwargs['control_images'] = control_image
            self.unconditional_embeds = self.sd.encode_prompt(
                [self.train_config.unconditional_prompt],
                long_prompts=self.do_long_prompts,
                **kwargs
            ).to(
                self.device_torch,
                dtype=self.sd.torch_dtype
            ).detach()

        if self.train_config.do_prior_divergence:
            self.do_prior_prediction = True
        # move vae to device if we did not cache latents
        if not self.is_latents_cached:
            self.sd.vae.eval()
            self.sd.vae.to(self.device_torch)
        else:
            # offload it. Already cached
            self.sd.vae.to('cpu')
            flush()
        add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch)
        if self.adapter is not None:
            self.adapter.to(self.device_torch)

            # check if we have regs and using adapter and caching clip embeddings
            has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0
            is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))])

            if has_reg and is_caching_clip_embeddings:
                # we need a list of unconditional clip image embeds from other datasets to handle regs
                unconditional_clip_image_embeds = []
                datasets = get_dataloader_datasets(self.data_loader)
                for i in range(len(datasets)):
                    unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache

                if len(unconditional_clip_image_embeds) == 0:
                    raise ValueError("No unconditional clip image embeds found. This should not happen")

                self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

        if self.train_config.negative_prompt is not None:
            if os.path.exists(self.train_config.negative_prompt):
                with open(self.train_config.negative_prompt, 'r') as f:
                    self.negative_prompt_pool = f.readlines()
                    # remove empty
                    self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
            else:
                # single prompt
                self.negative_prompt_pool = [self.train_config.negative_prompt]

        # handle unload text encoder
        if self.train_config.unload_text_encoder or self.is_caching_text_embeddings:
            with torch.no_grad():
                if self.train_config.train_text_encoder:
                    raise ValueError("Cannot unload text encoder if training text encoder")
                # cache embeddings
                print_acc("\n***** UNLOADING TEXT ENCODER *****")
                if self.is_caching_text_embeddings:
                    print_acc("Embeddings cached to disk. We don't need the text encoder anymore")
                else:
                    print_acc("This will train only with a blank prompt or trigger word, if set")
                    print_acc("If this is not what you want, remove the unload_text_encoder flag")
                print_acc("***********************************")
                print_acc("")
                self.sd.text_encoder_to(self.device_torch)
                encode_kwargs = {}
                if self.sd.encode_control_in_text_embeddings:
                    # just do a blank image for unconditionals
                    control_image = torch.zeros((1, 3, 224, 224), device=self.sd.device_torch, dtype=self.sd.torch_dtype)
                    encode_kwargs['control_images'] = control_image
                self.cached_blank_embeds = self.sd.encode_prompt("", **encode_kwargs)
                if self.trigger_word is not None:
                    self.cached_trigger_embeds = self.sd.encode_prompt(self.trigger_word, **encode_kwargs)
                if self.train_config.diff_output_preservation:
                    self.diff_output_preservation_embeds = self.sd.encode_prompt(self.train_config.diff_output_preservation_class)

                self.cache_sample_prompts()

                # unload the text encoder
                if self.is_caching_text_embeddings:
                    unload_text_encoder(self.sd)
                else:
                    # todo once every model is tested to work, unload properly. Though, this will all be merged into one thing.
                    # keep legacy usage for now.
                    self.sd.text_encoder_to("cpu")
                flush()

        if self.train_config.diffusion_feature_extractor_path is not None:
            vae = self.sd.vae
            # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
            #     vae = self.sd.vae
            self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
            self.dfe.to(self.device_torch)
            if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing:
                # must be set to train for gradient checkpointing to work
                self.dfe.vision_encoder.train()
                self.dfe.vision_encoder.gradient_checkpointing = True
            else:
                self.dfe.eval()
            # enable gradient checkpointing on the vae
            if vae is not None and self.train_config.gradient_checkpointing:
                try:
                    vae.enable_gradient_checkpointing()
                    vae.train()
                except:
                    pass

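    # The turbo pass below takes one large scheduler step from the current sigma to a
    # randomly chosen end sigma, i.e. roughly (a sketch of the Euler-style update the
    # scheduler performs for this single step, not a separate implementation):
    #   stepped = noisy_latents + (end_sigma - sigma_t) * model_output
    # and then subtracts the residual noise that still corresponds to end_sigma.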
    def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
        # to process turbo learning, we make one big step from our current timestep to the end
        # we then denoise the prediction on that remaining step and target our loss to our target latents
        # this currently only works on euler_a (that I know of). Would work on others, but needs to be coded to do so.
        # needs to be done on each item in batch as they may all have different timesteps
        batch_size = pred.shape[0]
        pred_chunks = torch.chunk(pred, batch_size, dim=0)
        noisy_latents_chunks = torch.chunk(noisy_latents, batch_size, dim=0)
        timesteps_chunks = torch.chunk(timesteps, batch_size, dim=0)
        latent_chunks = torch.chunk(batch.latents, batch_size, dim=0)
        noise_chunks = torch.chunk(noise, batch_size, dim=0)

        with torch.no_grad():
            # set the timesteps to 1000 so we can capture them to calculate the sigmas
            self.sd.noise_scheduler.set_timesteps(
                self.sd.noise_scheduler.config.num_train_timesteps,
                device=self.device_torch
            )
            train_timesteps = self.sd.noise_scheduler.timesteps.clone().detach()
            train_sigmas = self.sd.noise_scheduler.sigmas.clone().detach()
            # set the scheduler to one timestep, we build the step and sigmas for each item in batch for the partial step
            self.sd.noise_scheduler.set_timesteps(
                1,
                device=self.device_torch
            )

        denoised_pred_chunks = []
        target_pred_chunks = []

        for i in range(batch_size):
            pred_item = pred_chunks[i]
            noisy_latents_item = noisy_latents_chunks[i]
            timesteps_item = timesteps_chunks[i]
            latents_item = latent_chunks[i]
            noise_item = noise_chunks[i]
            with torch.no_grad():
                timestep_idx = [(train_timesteps == t).nonzero().item() for t in timesteps_item][0]
                single_step_timestep_schedule = [timesteps_item.squeeze().item()]
                # extract the sigma idx for our midpoint timestep
                sigmas = train_sigmas[timestep_idx:timestep_idx + 1].to(self.device_torch)

                end_sigma_idx = random.randint(timestep_idx, len(train_sigmas) - 1)
                end_sigma = train_sigmas[end_sigma_idx:end_sigma_idx + 1].to(self.device_torch)

                # add noise to our target
                # build the big sigma step. The to step will now be to 0 giving it a full remaining denoising half step
                # self.sd.noise_scheduler.sigmas = torch.cat([sigmas, torch.zeros_like(sigmas)]).detach()
                self.sd.noise_scheduler.sigmas = torch.cat([sigmas, end_sigma]).detach()

                # set our single timestep
                self.sd.noise_scheduler.timesteps = torch.from_numpy(
                    np.array(single_step_timestep_schedule, dtype=np.float32)
                ).to(device=self.device_torch)
                # set the step index to None so it will be recalculated on first step
                self.sd.noise_scheduler._step_index = None

            denoised_latent = self.sd.noise_scheduler.step(
                pred_item, timesteps_item, noisy_latents_item.detach(), return_dict=False
            )[0]

            residual_noise = (noise_item * end_sigma.flatten()).detach().to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))
            # remove the residual noise from the denoised latents. Output should be a clean prediction (theoretically)
            denoised_latent = denoised_latent - residual_noise

            denoised_pred_chunks.append(denoised_latent)

        denoised_latents = torch.cat(denoised_pred_chunks, dim=0)
        # set the scheduler back to the original timesteps
        self.sd.noise_scheduler.set_timesteps(
            self.sd.noise_scheduler.config.num_train_timesteps,
            device=self.device_torch
        )

        output = denoised_latents / self.sd.vae.config['scaling_factor']
        output = self.sd.vae.decode(output).sample

        if self.train_config.show_turbo_outputs:
            # since we are completely denoising, we can show them here
            with torch.no_grad():
                show_tensors(output)

        # we return our big partial step denoised latents as our pred and our untouched latents as our target.
        # you can do mse against the two here or run the denoised through the vae for pixel space loss against the
        # input tensor images.
        return output, batch.tensor.to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype))

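    # Summary of how `calculate_loss` below picks its regression target (a reading
    # aid mirroring the branches in the method, not additional behavior):
    #   flow matching                         -> target = noise - latents
    #   v-prediction                          -> target = scheduler.get_velocity(batch.tensor, noise, t)
    #   prior pred given (no prior divergence)-> target = prior_pred
    #   otherwise                             -> target = noise (epsilon objective)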
    # you can expand these in a child class to make customization easier
    def calculate_loss(
            self,
            noise_pred: torch.Tensor,
            noise: torch.Tensor,
            noisy_latents: torch.Tensor,
            timesteps: torch.Tensor,
            batch: 'DataLoaderBatchDTO',
            mask_multiplier: Union[torch.Tensor, float] = 1.0,
            prior_pred: Union[torch.Tensor, None] = None,
            **kwargs
    ):
        loss_target = self.train_config.loss_target
        is_reg = any(batch.get_is_reg_list())

        additional_loss = 0.0

        prior_mask_multiplier = None
        target_mask_multiplier = None
        dtype = get_torch_dtype(self.train_config.dtype)

        has_mask = batch.mask_tensor is not None

        with torch.no_grad():
            loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)

        if self.train_config.match_noise_norm:
            # match the norm of the noise
            noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
            noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
            noise_pred = noise_pred * (noise_norm / noise_pred_norm)

        if self.train_config.pred_scaler != 1.0:
            noise_pred = noise_pred * self.train_config.pred_scaler

        target = None

        if self.train_config.target_noise_multiplier != 1.0:
            noise = noise * self.train_config.target_noise_multiplier

        if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
            if self.train_config.correct_pred_norm and not is_reg:
                with torch.no_grad():
                    # this only works if doing a prior pred
                    if prior_pred is not None:
                        prior_mean = prior_pred.mean([2, 3], keepdim=True)
                        prior_std = prior_pred.std([2, 3], keepdim=True)
                        noise_mean = noise_pred.mean([2, 3], keepdim=True)
                        noise_std = noise_pred.std([2, 3], keepdim=True)

                        mean_adjust = prior_mean - noise_mean
                        std_adjust = prior_std - noise_std

                        mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
                        std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier

                        target_mean = noise_mean + mean_adjust
                        target_std = noise_std + std_adjust

                        eps = 1e-5
                        # match the noise to the prior
                        noise = (noise - noise_mean) / (noise_std + eps)
                        noise = noise * (target_std + eps) + target_mean
                        noise = noise.detach()

            if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
                assert not self.train_config.train_turbo
                with torch.no_grad():
                    prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
                    # resize to size of noise_pred
                    prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
                    # stack first channel to match channels of noise_pred
                    prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
                    prior_mask_multiplier = 1.0 - prior_mask
                    # scale so it is a mean of 1
                    prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
                if self.sd.is_flow_matching:
                    target = (noise - batch.latents).detach()
                else:
                    target = noise
        elif prior_pred is not None and not self.train_config.do_prior_divergence:
            assert not self.train_config.train_turbo
            # matching adapter prediction
            target = prior_pred
        elif self.sd.prediction_type == 'v_prediction':
            # v-parameterization training
            target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
        elif hasattr(self.sd, 'get_loss_target'):
            target = self.sd.get_loss_target(
                noise=noise,
                batch=batch,
                timesteps=timesteps,
            ).detach()
        elif self.sd.is_flow_matching:
            # forward ODE
            target = (noise - batch.latents).detach()
            # reverse ODE
            # target = (batch.latents - noise).detach()
        else:
            target = noise
        if self.dfe is not None:
            if self.dfe.version == 1:
                model = self.sd
                if model is not None and hasattr(model, 'get_stepped_pred'):
                    stepped_latents = model.get_stepped_pred(noise_pred, noise)
                else:
                    # stepped_latents = noise - noise_pred
                    # first we step the scheduler from current timestep to the very end for a full denoise
                    bs = noise_pred.shape[0]
                    noise_pred_chunks = torch.chunk(noise_pred, bs)
                    timestep_chunks = torch.chunk(timesteps, bs)
                    noisy_latent_chunks = torch.chunk(noisy_latents, bs)
                    stepped_chunks = []
                    for idx in range(bs):
                        model_output = noise_pred_chunks[idx]
                        timestep = timestep_chunks[idx]
                        self.sd.noise_scheduler._step_index = None
                        self.sd.noise_scheduler._init_step_index(timestep)
                        sample = noisy_latent_chunks[idx].to(torch.float32)

                        sigma = self.sd.noise_scheduler.sigmas[self.sd.noise_scheduler.step_index]
                        sigma_next = self.sd.noise_scheduler.sigmas[-1]  # use last sigma for final step
                        prev_sample = sample + (sigma_next - sigma) * model_output
                        stepped_chunks.append(prev_sample)
                    stepped_latents = torch.cat(stepped_chunks, dim=0)

                stepped_latents = stepped_latents.to(self.sd.vae.device, dtype=self.sd.vae.dtype)
                # resize to half the size of the latents
                stepped_latents_half = torch.nn.functional.interpolate(
                    stepped_latents,
                    size=(stepped_latents.shape[2] // 2, stepped_latents.shape[3] // 2),
                    mode='bilinear',
                    align_corners=False
                )
                pred_features = self.dfe(stepped_latents.float())
                pred_features_half = self.dfe(stepped_latents_half.float())
                with torch.no_grad():
                    target_features = self.dfe(batch.latents.to(self.device_torch, dtype=torch.float32))
                    batch_latents_half = torch.nn.functional.interpolate(
                        batch.latents.to(self.device_torch, dtype=torch.float32),
                        size=(batch.latents.shape[2] // 2, batch.latents.shape[3] // 2),
                        mode='bilinear',
                        align_corners=False
                    )
                    target_features_half = self.dfe(batch_latents_half)

                # scale dfe so it is weaker at higher noise levels
                dfe_scaler = 1 - (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(self.device_torch)

                dfe_loss = torch.nn.functional.mse_loss(pred_features, target_features, reduction="none") * \
                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
                dfe_loss_half = torch.nn.functional.mse_loss(pred_features_half, target_features_half, reduction="none") * \
                    self.train_config.diffusion_feature_extractor_weight * dfe_scaler
                additional_loss += dfe_loss.mean() + dfe_loss_half.mean()
            elif self.dfe.version == 2:
                # version 2
                # do diffusion feature extraction on target
                with torch.no_grad():
                    rectified_flow_target = noise.float() - batch.latents.float()
                    target_feature_list = self.dfe(torch.cat([rectified_flow_target, noise.float()], dim=1))
                # do diffusion feature extraction on prediction
                pred_feature_list = self.dfe(torch.cat([noise_pred.float(), noise.float()], dim=1))

                dfe_loss = 0.0
                for i in range(len(target_feature_list)):
                    dfe_loss += torch.nn.functional.mse_loss(pred_feature_list[i], target_feature_list[i], reduction="mean")

                additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
            elif self.dfe.version == 3 or self.dfe.version == 4:
                dfe_loss = self.dfe(
                    noise=noise,
                    noise_pred=noise_pred,
                    noisy_latents=noisy_latents,
                    timesteps=timesteps,
                    batch=batch,
                    scheduler=self.sd.noise_scheduler
                )
                additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
            else:
                raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
        if self.train_config.do_guidance_loss:
            with torch.no_grad():
                # we make cached blank prompt embeds that match the batch size
                unconditional_embeds = concat_prompt_embeds(
                    [self.unconditional_embeds] * noisy_latents.shape[0],
                )
                cfm_pred = self.predict_noise(
                    noisy_latents=noisy_latents,
                    timesteps=timesteps,
                    conditional_embeds=unconditional_embeds,
                    unconditional_embeds=None,
                    batch=batch,
                )

                # zero cfg
                # ref https://github.com/WeichenFan/CFG-Zero-star/blob/cdac25559e3f16cb95f0016c04c709ea1ab9452b/wan_pipeline.py#L557
                batch_size = target.shape[0]
                positive_flat = target.view(batch_size, -1)
                negative_flat = cfm_pred.view(batch_size, -1)

                # calculate the dot product
                dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
                # squared norm of the unconditional prediction
                squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
                # st_star = v_cond^T * v_uncond / ||v_uncond||^2
                st_star = dot_product / squared_norm

                alpha = st_star

                is_video = len(target.shape) == 5

                alpha = alpha.view(batch_size, 1, 1, 1) if not is_video else alpha.view(batch_size, 1, 1, 1, 1)

                guidance_scale = self._guidance_loss_target_batch
                if isinstance(guidance_scale, list):
                    guidance_scale = torch.tensor(guidance_scale).to(target.device, dtype=target.dtype)
                    guidance_scale = guidance_scale.view(-1, 1, 1, 1) if not is_video else guidance_scale.view(-1, 1, 1, 1, 1)

                unconditional_target = cfm_pred * alpha
                target = unconditional_target + guidance_scale * (target - unconditional_target)
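                # Worked illustration of the CFG-Zero-style projection above (not part
                # of the training path): if target == 2 * cfm_pred elementwise, then
                # st_star == 2, the rescaled unconditional equals the target, and the
                # guided target reduces to the original target for any guidance scale.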
        if target is None:
            target = noise

        pred = noise_pred

        if self.train_config.train_turbo:
            pred, target = self.process_output_for_turbo(pred, noisy_latents, timesteps, noise, batch)

        ignore_snr = False

        if loss_target == 'source' or loss_target == 'unaugmented':
            assert not self.train_config.train_turbo
            # ignore_snr = True
            if batch.sigmas is None:
                raise ValueError("Batch sigmas is None. This should not happen")

            # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
            denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
            weighing = batch.sigmas ** -2.0
            if loss_target == 'source':
                # denoise the latent and compare to the latent in the batch
                target = batch.latents
            elif loss_target == 'unaugmented':
                # we have to encode images into latents for now
                # we also denoise as the unaugmented tensor is not a noisy differential
                with torch.no_grad():
                    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
                    unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                    target = unaugmented_latents.detach()

            # Get the target for loss depending on the prediction type
            if self.sd.noise_scheduler.config.prediction_type == "epsilon":
                target = target  # we are computing loss against denoised latents
            elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
                target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

            # mse loss without reduction
            loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
            loss = loss_per_element
        else:
            if self.train_config.loss_type == "mae":
                loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
            elif self.train_config.loss_type == "wavelet":
                loss = wavelet_loss(pred, batch.latents, noise)
            else:
                loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

        do_weighted_timesteps = False
        if self.sd.is_flow_matching:
            if self.train_config.linear_timesteps or self.train_config.linear_timesteps2:
                do_weighted_timesteps = True
            if self.train_config.timestep_type == "weighted":
                # use the noise scheduler to get the weights for the timesteps
                do_weighted_timesteps = True

        # handle linear timesteps and only adjust the weight of the timesteps
        if do_weighted_timesteps:
            # calculate the weights for the timesteps
            timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
                timesteps,
                v2=self.train_config.linear_timesteps2,
                timestep_type=self.train_config.timestep_type
            ).to(loss.device, dtype=loss.dtype)
            if len(loss.shape) == 4:
                timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
            elif len(loss.shape) == 5:
                timestep_weight = timestep_weight.view(-1, 1, 1, 1, 1).detach()
            loss = loss * timestep_weight

        if self.train_config.do_prior_divergence and prior_pred is not None:
            loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)

        if self.train_config.train_turbo:
            mask_multiplier = mask_multiplier[:, 3:, :, :]
            # resize to the size of the loss
            mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest')

        # multiply by our mask
        try:
            loss = loss * mask_multiplier
        except:
            # todo handle mask with video models
            pass

        prior_loss = None
        if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
            assert not self.train_config.train_turbo
            if self.train_config.loss_type == "mae":
                prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
            else:
                prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")

            prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
            if torch.isnan(prior_loss).any():
                print_acc("Prior loss is nan")
                prior_loss = None
            else:
                prior_loss = prior_loss.mean([1, 2, 3])
            # loss = loss + prior_loss

        loss = loss.mean([1, 2, 3])
        # apply loss multiplier before prior loss
        # multiply by our mask
        try:
            loss = loss * loss_multiplier
        except:
            # todo handle mask with video models
            pass

        if prior_loss is not None:
            loss = loss + prior_loss

        if not self.train_config.train_turbo:
            if self.train_config.learnable_snr_gos:
                # add snr_gamma
                loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
            elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
                # add snr_gamma
                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
            elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
                # add min_snr_gamma
                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()

        # check for additional losses
        if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
            loss = loss + self.adapter.additional_loss.mean()
            self.adapter.additional_loss = None

        if self.train_config.target_norm_std:
            # separate out the batch and channels
            pred_std = noise_pred.std([2, 3], keepdim=True)
            norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
            loss = loss + norm_std_loss

        return loss + additional_loss

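    # Net effect of `calculate_loss` (a summary of the code above, not new behavior):
    # the returned scalar is roughly
    #   mean(mask * timestep_weight * base_loss * loss_multiplier) + prior_loss + DFE/adapter extras
    # where base_loss is MSE, MAE, or wavelet loss against the selected target.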
    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
        return batch

    def get_guided_loss(
            self,
            noisy_latents: torch.Tensor,
            conditional_embeds: PromptEmbeds,
            match_adapter_assist: bool,
            network_weight_list: list,
            timesteps: torch.Tensor,
            pred_kwargs: dict,
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
            **kwargs
    ):
        loss = get_guidance_loss(
            noisy_latents=noisy_latents,
            conditional_embeds=conditional_embeds,
            match_adapter_assist=match_adapter_assist,
            network_weight_list=network_weight_list,
            timesteps=timesteps,
            pred_kwargs=pred_kwargs,
            batch=batch,
            noise=noise,
            sd=self.sd,
            unconditional_embeds=unconditional_embeds,
            train_config=self.train_config,
            **kwargs
        )
        return loss

    # ------------------------------------------------------------------
    # Mean-Flow loss (Geng et al., “Mean Flows for One-step Generative
    # Modelling”, 2025 – see Alg. 1 + Eq. (6) of the paper)
    # This version avoids jvp / double-back-prop issues with Flash-Attention
    # adapted from the work of lodestonerock
    # ------------------------------------------------------------------
    def get_mean_flow_loss(
            self,
            noisy_latents: torch.Tensor,
            conditional_embeds: PromptEmbeds,
            match_adapter_assist: bool,
            network_weight_list: list,
            timesteps: torch.Tensor,
            pred_kwargs: dict,
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
            **kwargs
    ):
        dtype = get_torch_dtype(self.train_config.dtype)
        total_steps = float(self.sd.noise_scheduler.config.num_train_timesteps)  # e.g. 1000
        base_eps = 1e-3
        min_time_gap = 1e-2

        with torch.no_grad():
            num_train_timesteps = self.sd.noise_scheduler.config.num_train_timesteps
            batch_size = batch.latents.shape[0]

            timestep_t_list = []
            timestep_r_list = []
            for i in range(batch_size):
                t1 = random.randint(0, num_train_timesteps - 1)
                t2 = random.randint(0, num_train_timesteps - 1)
                t_t = self.sd.noise_scheduler.timesteps[min(t1, t2)]
                t_r = self.sd.noise_scheduler.timesteps[max(t1, t2)]
                if (t_t - t_r).item() < min_time_gap * 1000:
                    scaled_time_gap = min_time_gap * 1000
                    if t_t.item() + scaled_time_gap > 1000:
                        t_r = t_r - scaled_time_gap
                    else:
                        t_t = t_t + scaled_time_gap
                timestep_t_list.append(t_t)
                timestep_r_list.append(t_r)

            timesteps_t = torch.stack(timestep_t_list, dim=0).float()
            timesteps_r = torch.stack(timestep_r_list, dim=0).float()

            t_frac = timesteps_t / total_steps  # [0,1]
            r_frac = timesteps_r / total_steps  # [0,1]

            latents_clean = batch.latents.to(dtype)
            noise_sample = noise.to(dtype)

            lerp_vector = latents_clean * (1.0 - t_frac[:, None, None, None]) + noise_sample * t_frac[:, None, None, None]

            eps = base_eps

            # concatenate timesteps as input for u(z, r, t)
            timesteps_cat = torch.cat([t_frac, r_frac], dim=0) * total_steps

        # model predicts u(z, r, t)
        u_pred = self.predict_noise(
            noisy_latents=lerp_vector.to(dtype),
            timesteps=timesteps_cat.to(dtype),
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            batch=batch,
            **pred_kwargs
        )

        with torch.no_grad():
            t_frac_plus_eps = (t_frac + eps).clamp(0.0, 1.0)
            lerp_perturbed = latents_clean * (1.0 - t_frac_plus_eps[:, None, None, None]) + noise_sample * t_frac_plus_eps[:, None, None, None]
            timesteps_cat_perturbed = torch.cat([t_frac_plus_eps, r_frac], dim=0) * total_steps
            u_perturbed = self.predict_noise(
                noisy_latents=lerp_perturbed.to(dtype),
                timesteps=timesteps_cat_perturbed.to(dtype),
                conditional_embeds=conditional_embeds,
                unconditional_embeds=unconditional_embeds,
                batch=batch,
                **pred_kwargs
            )

        # compute du/dt via finite difference (detached)
        du_dt = (u_perturbed - u_pred).detach() / eps
        # du_dt = (u_perturbed - u_pred).detach()
        du_dt = du_dt.to(dtype)

        time_gap = (t_frac - r_frac)[:, None, None, None].to(dtype)
        # clamp returns a new tensor; assign it so tiny gaps cannot blow up the division below
        time_gap = time_gap.clamp(min=1e-4)

        u_shifted = u_pred + time_gap * du_dt
        # u_shifted = u_pred + du_dt / time_gap
        # u_shifted = u_pred

        # a step is done like this:
        # stepped_latent = model_input + (timestep_next - timestep) * model_output

        # flow target velocity
        # v_target = (noise_sample - latents_clean) / time_gap
        # flux predicts the opposite of the velocity, so we need to invert it
        v_target = (latents_clean - noise_sample) / time_gap

        # compute loss
        loss = torch.nn.functional.mse_loss(
            u_shifted.float(),
            v_target.float(),
            reduction='none'
        )

        with torch.no_grad():
            pure_loss = loss.mean().detach()
            pure_loss.requires_grad_(True)

        loss = loss.mean()
        if loss.item() > 1e3:
            pass
        self.accelerator.backward(loss)
        return pure_loss

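    # Finite-difference reading of the mean-flow objective above (a sketch of the
    # quantities already computed in the method, not additional behavior):
    #   du/dt ~= (u(z_{t+eps}, r, t+eps) - u(z_t, r, t)) / eps        (detached)
    #   u_shifted = u + (t - r) * du/dt  is regressed onto
    #   v_target  = (latents - noise) / (t - r), with the sign flipped because this
    #   model family predicts the negative of the usual flow velocity.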
    def get_prior_prediction(
            self,
            noisy_latents: torch.Tensor,
            conditional_embeds: PromptEmbeds,
            match_adapter_assist: bool,
            network_weight_list: list,
            timesteps: torch.Tensor,
            pred_kwargs: dict,
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
            conditioned_prompts=None,
            **kwargs
    ):
        # todo for embeddings, we need to run without trigger words
        was_unet_training = self.sd.unet.training
        was_network_active = False
        if self.network is not None:
            was_network_active = self.network.is_active
            self.network.is_active = False

        can_disable_adapter = False
        was_adapter_active = False
        if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or
                                         isinstance(self.adapter, ReferenceAdapter) or
                                         (isinstance(self.adapter, CustomAdapter))
                                         ):
            can_disable_adapter = True
            was_adapter_active = self.adapter.is_active
            self.adapter.is_active = False

        if self.train_config.unload_text_encoder and self.adapter is not None and not isinstance(self.adapter, CustomAdapter):
            raise ValueError("Prior predictions currently do not support unloading text encoder with adapter")

        # do a prediction here so we can match its output with network multiplier set to 0.0
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)

            embeds_to_use = conditional_embeds.clone().detach()
            # handle clip vision adapter by removing triggers from prompt and replacing with the class name
            if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
                prompt_list = batch.get_caption_list()
                class_name = ''

                triggers = ['[trigger]', '[name]']
                remove_tokens = []

                if self.embed_config is not None:
                    triggers.append(self.embed_config.trigger)
                    for i in range(1, self.embed_config.tokens):
                        remove_tokens.append(f"{self.embed_config.trigger}_{i}")
                    if self.embed_config.trigger_class_name is not None:
                        class_name = self.embed_config.trigger_class_name

                if self.adapter is not None:
                    triggers.append(self.adapter_config.trigger)
                    for i in range(1, self.adapter_config.num_tokens):
                        remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
                    if self.adapter_config.trigger_class_name is not None:
                        class_name = self.adapter_config.trigger_class_name

                for idx, prompt in enumerate(prompt_list):
                    for remove_token in remove_tokens:
                        prompt = prompt.replace(remove_token, '')
                    for trigger in triggers:
                        prompt = prompt.replace(trigger, class_name)
                    prompt_list[idx] = prompt

                if batch.prompt_embeds is not None:
                    embeds_to_use = batch.prompt_embeds.clone().to(self.device_torch, dtype=dtype)
                else:
                    prompt_kwargs = {}
                    if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None:
                        prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype)
                    embeds_to_use = self.sd.encode_prompt(
                        prompt_list,
                        long_prompts=self.do_long_prompts,
                        **prompt_kwargs
                    ).to(
                        self.device_torch,
                        dtype=dtype
                    ).detach()

            # don't use network on this
            # self.network.multiplier = 0.0
            self.sd.unet.eval()

            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2:
                # we need to remove the image embeds from the prompt except for flux
                embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
                end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
                embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
                if unconditional_embeds is not None:
                    unconditional_embeds = unconditional_embeds.clone().detach()
                    unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]

            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()

            guidance_embedding_scale = self.train_config.cfg_scale
            if self.train_config.do_guidance_loss:
                guidance_embedding_scale = self._guidance_loss_target_batch

            prior_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
                unconditional_embeddings=unconditional_embeds,
                timestep=timesteps,
                guidance_scale=self.train_config.cfg_scale,
                guidance_embedding_scale=guidance_embedding_scale,
                rescale_cfg=self.train_config.cfg_rescale,
                batch=batch,
                **pred_kwargs  # adapter residuals in here
            )
            if was_unet_training:
                self.sd.unet.train()
            prior_pred = prior_pred.detach()

            # remove the residuals as we won't use them on prediction when matching control
            if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
                del pred_kwargs['down_intrablock_additional_residuals']
            if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
                del pred_kwargs['down_block_additional_residuals']
            if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
                del pred_kwargs['mid_block_additional_residual']

            if can_disable_adapter:
                self.adapter.is_active = was_adapter_active
            # restore network
            # self.network.multiplier = network_weight_list
            if self.network is not None:
                self.network.is_active = was_network_active
            return prior_pred

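    # The prior prediction above is computed with the trainable network (and most
    # adapters) disabled, so it approximates the frozen base model's output. It is
    # later used, for example, as the regression target when matching an
    # adapter-assisted prediction or for differential output preservation.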
    def before_unet_predict(self):
        pass

    def after_unet_predict(self):
        pass

    def end_of_training_loop(self):
        pass

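    # The three hooks above are intentional no-ops so subclasses can extend the
    # training loop; a hypothetical override might look like:
    #   class MySDTrainer(SDTrainer):
    #       def before_unet_predict(self):
    #           print_acc("about to run the denoiser")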
    def predict_noise(
            self,
            noisy_latents: torch.Tensor,
            timesteps: Union[int, torch.Tensor] = 1,
            conditional_embeds: Union[PromptEmbeds, None] = None,
            unconditional_embeds: Union[PromptEmbeds, None] = None,
            batch: Optional['DataLoaderBatchDTO'] = None,
            is_primary_pred: bool = False,
            **kwargs,
    ):
        dtype = get_torch_dtype(self.train_config.dtype)
        guidance_embedding_scale = self.train_config.cfg_scale
        if self.train_config.do_guidance_loss:
            guidance_embedding_scale = self._guidance_loss_target_batch
        return self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            guidance_embedding_scale=guidance_embedding_scale,
            detach_unconditional=False,
            rescale_cfg=self.train_config.cfg_rescale,
            bypass_guidance_embedding=self.train_config.bypass_guidance_embedding,
            batch=batch,
            **kwargs
        )

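    # `predict_noise` is a thin wrapper around StableDiffusion.predict_noise that
    # injects the configured CFG scale and, when the guidance loss is enabled, the
    # per-batch guidance-embedding target; helpers such as `get_mean_flow_loss` and
    # the guidance-loss branch of `calculate_loss` call it with their own inputs.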
    def train_single_accumulation(self, batch: DataLoaderBatchDTO):
        with torch.no_grad():
            self.timer.start('preprocess_batch')
            if isinstance(self.adapter, CustomAdapter):
                batch = self.adapter.edit_batch_raw(batch)
            batch = self.preprocess_batch(batch)
            if isinstance(self.adapter, CustomAdapter):
                batch = self.adapter.edit_batch_processed(batch)
            dtype = get_torch_dtype(self.train_config.dtype)

            # sanity check
            if self.sd.vae.dtype != self.sd.vae_torch_dtype:
                self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
            if isinstance(self.sd.text_encoder, list):
                for encoder in self.sd.text_encoder:
                    if encoder.dtype != self.sd.te_torch_dtype:
                        encoder.to(self.sd.te_torch_dtype)
            else:
                if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
                    self.sd.text_encoder.to(self.sd.te_torch_dtype)

            noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)

            if self.train_config.do_cfg or self.train_config.do_random_cfg:
                # pick random negative prompts
                if self.negative_prompt_pool is not None:
                    negative_prompts = []
                    for i in range(noisy_latents.shape[0]):
                        num_neg = random.randint(1, self.train_config.max_negative_prompts)
                        this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
                        this_neg_prompt = ', '.join(this_neg_prompts)
                        negative_prompts.append(this_neg_prompt)
                    self.batch_negative_prompt = negative_prompts
                else:
                    self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]

            if self.adapter and isinstance(self.adapter, CustomAdapter):
                # condition the prompt
                # todo handle more than one adapter image
                conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)

            network_weight_list = batch.get_network_weight_list()
            if self.train_config.single_item_batching:
                network_weight_list = network_weight_list + network_weight_list

            has_adapter_img = batch.control_tensor is not None
            has_clip_image = batch.clip_image_tensor is not None
            has_clip_image_embeds = batch.clip_image_embeds is not None
            # force it to be true if doing regs as we handle those differently
            if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
                has_clip_image = True
                if self._clip_image_embeds_unconditional is not None:
                    has_clip_image_embeds = True  # we are caching embeds, handle that differently
                    has_clip_image = False

            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
                raise ValueError(
                    "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config")

            match_adapter_assist = False

            # check if we are matching the adapter assistant
            if self.assistant_adapter:
                if self.train_config.match_adapter_chance == 1.0:
                    match_adapter_assist = True
                elif self.train_config.match_adapter_chance > 0.0:
                    match_adapter_assist = torch.rand(
                        (1,), device=self.device_torch, dtype=dtype
                    ) < self.train_config.match_adapter_chance

            self.timer.stop('preprocess_batch')

            is_reg = False
            loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
            for idx, file_item in enumerate(batch.file_items):
                if file_item.is_reg:
                    loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
                    is_reg = True

            adapter_images = None
            sigmas = None
            if has_adapter_img and (self.adapter or self.assistant_adapter):
                with self.timer('get_adapter_images'):
                    # todo move this to data loader
                    if batch.control_tensor is not None:
                        adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
                        # match in channels
                        if self.assistant_adapter is not None:
                            in_channels = self.assistant_adapter.config.in_channels
                            if adapter_images.shape[1] != in_channels:
                                # we need to match the channels
                                adapter_images = adapter_images[:, :in_channels, :, :]
                    else:
                        raise NotImplementedError("Adapter images now must be loaded with dataloader")

            clip_images = None
            if has_clip_image:
                with self.timer('get_clip_images'):
                    # todo move this to data loader
                    if batch.clip_image_tensor is not None:
                        clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()

            mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
            if batch.mask_tensor is not None:
                with self.timer('get_mask_multiplier'):
                    # upsampling is not supported for bfloat16
                    mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
                    # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height)
                    mask_multiplier = torch.nn.functional.interpolate(
                        mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
                    )
                    # expand to match latents
                    mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
                    mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
                    # make avg 1.0
                    mask_multiplier = mask_multiplier / mask_multiplier.mean()
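            # Illustration of the normalization above: for a mask covering half the
            # image (ones inside, zeros outside), the mean is 0.5, so after dividing
            # by the mean the masked region is weighted 2.0 and the rest 0.0, keeping
            # the average per-pixel loss weight at 1.0.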
        def get_adapter_multiplier():
            if self.adapter and isinstance(self.adapter, T2IAdapter):
                # training a t2i adapter, not using as assistant.
                return 1.0
            elif match_adapter_assist:
                # training a texture. We want it high
                adapter_strength_min = 0.9
                adapter_strength_max = 1.0
            else:
                # training with assistance, we want it low
                # adapter_strength_min = 0.4
                # adapter_strength_max = 0.7
                adapter_strength_min = 0.5
                adapter_strength_max = 1.1

            adapter_conditioning_scale = torch.rand(
                (1,), device=self.device_torch, dtype=dtype
            )

            adapter_conditioning_scale = value_map(
                adapter_conditioning_scale,
                0.0,
                1.0,
                adapter_strength_min,
                adapter_strength_max
            )
            return adapter_conditioning_scale

        # flush()
        with self.timer('grad_setup'):
            # text encoding
            grad_on_text_encoder = False
            if self.train_config.train_text_encoder:
                grad_on_text_encoder = True

            if self.embedding is not None:
                grad_on_text_encoder = True

            if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
                grad_on_text_encoder = True

            if self.adapter_config and self.adapter_config.type == 'te_augmenter':
                grad_on_text_encoder = True

            # have a blank network so we can wrap it in a context and set multipliers without checking every time
            if self.network is not None:
                network = self.network
            else:
                network = BlankNetwork()

            # set the weights
            network.multiplier = network_weight_list

        # activate network if it exists
        prompts_1 = conditioned_prompts
        prompts_2 = None
        if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
            prompts_1 = batch.get_caption_short_list()
            prompts_2 = conditioned_prompts

        # make the batch splits
        if self.train_config.single_item_batching:
            if self.model_config.refiner_name_or_path is not None:
                raise ValueError("Single item batching is not supported when training the refiner")
            batch_size = noisy_latents.shape[0]
            # chunk/split everything
            noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
            noise_list = torch.chunk(noise, batch_size, dim=0)
            timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
            conditioned_prompts_list = [[prompt] for prompt in prompts_1]
            if imgs is not None:
                imgs_list = torch.chunk(imgs, batch_size, dim=0)
            else:
                imgs_list = [None for _ in range(batch_size)]
            if adapter_images is not None:
                adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
            else:
                adapter_images_list = [None for _ in range(batch_size)]
            if clip_images is not None:
                clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
            else:
                clip_images_list = [None for _ in range(batch_size)]
            mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
            if prompts_2 is None:
                prompt_2_list = [None for _ in range(batch_size)]
            else:
                prompt_2_list = [[prompt] for prompt in prompts_2]
        else:
            noisy_latents_list = [noisy_latents]
            noise_list = [noise]
            timesteps_list = [timesteps]
            conditioned_prompts_list = [prompts_1]
            imgs_list = [imgs]
            adapter_images_list = [adapter_images]
            clip_images_list = [clip_images]
            mask_multiplier_list = [mask_multiplier]
            if prompts_2 is None:
                prompt_2_list = [None]
            else:
                prompt_2_list = [prompts_2]
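        # With single_item_batching and a batch of N, each of the lists above holds N
        # one-item chunks so every sample is run through the model separately; in the
        # default path they are single-element lists and the loop below runs once.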
| for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip( | |
| noisy_latents_list, | |
| noise_list, | |
| timesteps_list, | |
| conditioned_prompts_list, | |
| imgs_list, | |
| adapter_images_list, | |
| clip_images_list, | |
| mask_multiplier_list, | |
| prompt_2_list | |
| ): | |
| # if self.train_config.negative_prompt is not None: | |
| # # add negative prompt | |
| # conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in | |
| # range(len(conditioned_prompts))] | |
| # if prompt_2 is not None: | |
| # prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] | |
| with (network): | |
| # encode clip adapter here so embeds are active for tokenizer | |
| if self.adapter and isinstance(self.adapter, ClipVisionAdapter): | |
| with self.timer('encode_clip_vision_embeds'): | |
| if has_clip_image: | |
| conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| clip_images.detach().to(self.device_torch, dtype=dtype), | |
| is_training=True, | |
| has_been_preprocessed=True | |
| ) | |
| else: | |
| # just do a blank one | |
| conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| torch.zeros( | |
| (noisy_latents.shape[0], 3, 512, 512), | |
| device=self.device_torch, dtype=dtype | |
| ), | |
| is_training=True, | |
| has_been_preprocessed=True, | |
| drop=True | |
| ) | |
| # it will be injected into the tokenizer when called | |
| self.adapter(conditional_clip_embeds) | |
| # do the custom adapter after the prior prediction | |
| if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or is_reg): | |
| quad_count = random.randint(1, 4) | |
| self.adapter.train() | |
| self.adapter.trigger_pre_te( | |
| tensors_preprocessed=clip_images if not is_reg else None, # on regs we send none to get random noise | |
| is_training=True, | |
| has_been_preprocessed=True, | |
| quad_count=quad_count, | |
| batch_tensor=batch.tensor if not is_reg else None, | |
| batch_size=noisy_latents.shape[0] | |
| ) | |
| with self.timer('encode_prompt'): | |
| unconditional_embeds = None | |
| prompt_kwargs = {} | |
| if self.sd.encode_control_in_text_embeddings and batch.control_tensor is not None: | |
| prompt_kwargs['control_images'] = batch.control_tensor.to(self.sd.device_torch, dtype=self.sd.torch_dtype) | |
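| # three encoding paths follow: reuse cached embeds (text encoder unloaded / caching), encode with grad when training the text encoder, or encode under no_grad in eval mode | |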
| if self.train_config.unload_text_encoder or self.is_caching_text_embeddings: | |
| with torch.set_grad_enabled(False): | |
| if batch.prompt_embeds is not None: | |
| # use the cached embeds | |
| conditional_embeds = batch.prompt_embeds.clone().detach().to( | |
| self.device_torch, dtype=dtype | |
| ) | |
| else: | |
| embeds_to_use = self.cached_blank_embeds.clone().detach().to( | |
| self.device_torch, dtype=dtype | |
| ) | |
| if self.cached_trigger_embeds is not None and not is_reg: | |
| embeds_to_use = self.cached_trigger_embeds.clone().detach().to( | |
| self.device_torch, dtype=dtype | |
| ) | |
| conditional_embeds = concat_prompt_embeds( | |
| [embeds_to_use] * noisy_latents.shape[0] | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_embeds = self.cached_blank_embeds.clone().detach().to( | |
| self.device_torch, dtype=dtype | |
| ) | |
| unconditional_embeds = concat_prompt_embeds( | |
| [unconditional_embeds] * noisy_latents.shape[0] | |
| ) | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = False | |
| elif grad_on_text_encoder: | |
| with torch.set_grad_enabled(True): | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = False | |
| conditional_embeds = self.sd.encode_prompt( | |
| conditioned_prompts, prompt_2, | |
| dropout_prob=self.train_config.prompt_dropout_prob, | |
| long_prompts=self.do_long_prompts, | |
| **prompt_kwargs | |
| ).to( | |
| self.device_torch, | |
| dtype=dtype) | |
| if self.train_config.do_cfg: | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = True | |
| # todo only do one and repeat it | |
| unconditional_embeds = self.sd.encode_prompt( | |
| self.batch_negative_prompt, | |
| self.batch_negative_prompt, | |
| dropout_prob=self.train_config.prompt_dropout_prob, | |
| long_prompts=self.do_long_prompts, | |
| **prompt_kwargs | |
| ).to( | |
| self.device_torch, | |
| dtype=dtype) | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = False | |
| else: | |
| with torch.set_grad_enabled(False): | |
| # make sure it is in eval mode | |
| if isinstance(self.sd.text_encoder, list): | |
| for te in self.sd.text_encoder: | |
| te.eval() | |
| else: | |
| self.sd.text_encoder.eval() | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = False | |
| conditional_embeds = self.sd.encode_prompt( | |
| conditioned_prompts, prompt_2, | |
| dropout_prob=self.train_config.prompt_dropout_prob, | |
| long_prompts=self.do_long_prompts, | |
| **prompt_kwargs | |
| ).to( | |
| self.device_torch, | |
| dtype=dtype) | |
| if self.train_config.do_cfg: | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = True | |
| unconditional_embeds = self.sd.encode_prompt( | |
| self.batch_negative_prompt, | |
| dropout_prob=self.train_config.prompt_dropout_prob, | |
| long_prompts=self.do_long_prompts, | |
| **prompt_kwargs | |
| ).to( | |
| self.device_torch, | |
| dtype=dtype) | |
| if isinstance(self.adapter, CustomAdapter): | |
| self.adapter.is_unconditional_run = False | |
| if self.train_config.diff_output_preservation: | |
| dop_prompts = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in conditioned_prompts] | |
| dop_prompts_2 = None | |
| if prompt_2 is not None: | |
| dop_prompts_2 = [p.replace(self.trigger_word, self.train_config.diff_output_preservation_class) for p in prompt_2] | |
| self.diff_output_preservation_embeds = self.sd.encode_prompt( | |
| dop_prompts, dop_prompts_2, | |
| dropout_prob=self.train_config.prompt_dropout_prob, | |
| long_prompts=self.do_long_prompts, | |
| **prompt_kwargs | |
| ).to( | |
| self.device_torch, | |
| dtype=dtype) | |
| # detach the embeddings | |
| conditional_embeds = conditional_embeds.detach() | |
| if self.train_config.do_cfg: | |
| unconditional_embeds = unconditional_embeds.detach() | |
| if self.decorator: | |
| conditional_embeds.text_embeds = self.decorator( | |
| conditional_embeds.text_embeds | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_embeds.text_embeds = self.decorator( | |
| unconditional_embeds.text_embeds, | |
| is_unconditional=True | |
| ) | |
| # flush() | |
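| # pred_kwargs collects extra conditioning (t2i adapter / controlnet residuals) that is forwarded to the model in predict_noise | |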
| pred_kwargs = {} | |
| if has_adapter_img: | |
| if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( | |
| self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): | |
| with torch.set_grad_enabled(self.adapter is not None): | |
| adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter | |
| adapter_multiplier = get_adapter_multiplier() | |
| with self.timer('encode_adapter'): | |
| down_block_additional_residuals = adapter(adapter_images) | |
| if self.assistant_adapter: | |
| # not training. detach | |
| down_block_additional_residuals = [ | |
| sample.to(dtype=dtype).detach() * adapter_multiplier for sample in | |
| down_block_additional_residuals | |
| ] | |
| else: | |
| down_block_additional_residuals = [ | |
| sample.to(dtype=dtype) * adapter_multiplier for sample in | |
| down_block_additional_residuals | |
| ] | |
| pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals | |
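| # down_intrablock_additional_residuals is the hook diffusers exposes for T2I-Adapter features | |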
| if self.adapter and isinstance(self.adapter, IPAdapter): | |
| with self.timer('encode_adapter_embeds'): | |
| # number of images to do if doing a quad image | |
| quad_count = random.randint(1, 4) | |
| image_size = self.adapter.input_size | |
| if has_clip_image_embeds: | |
| # todo handle reg images better than this | |
| if is_reg: | |
| # get unconditional image embeds from cache | |
| embeds = [ | |
| load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in | |
| range(noisy_latents.shape[0]) | |
| ] | |
| conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( | |
| embeds, | |
| quad_count=quad_count | |
| ) | |
| if self.train_config.do_cfg: | |
| embeds = [ | |
| load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in | |
| range(noisy_latents.shape[0]) | |
| ] | |
| unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( | |
| embeds, | |
| quad_count=quad_count | |
| ) | |
| else: | |
| conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( | |
| batch.clip_image_embeds, | |
| quad_count=quad_count | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( | |
| batch.clip_image_embeds_unconditional, | |
| quad_count=quad_count | |
| ) | |
| elif is_reg: | |
| # we will zero it out in the img embedder | |
| clip_images = torch.zeros( | |
| (noisy_latents.shape[0], 3, image_size, image_size), | |
| device=self.device_torch, dtype=dtype | |
| ).detach() | |
| # drop will zero it out | |
| conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| clip_images, | |
| drop=True, | |
| is_training=True, | |
| has_been_preprocessed=False, | |
| quad_count=quad_count | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| torch.zeros( | |
| (noisy_latents.shape[0], 3, image_size, image_size), | |
| device=self.device_torch, dtype=dtype | |
| ).detach(), | |
| is_training=True, | |
| drop=True, | |
| has_been_preprocessed=False, | |
| quad_count=quad_count | |
| ) | |
| elif has_clip_image: | |
| conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| clip_images.detach().to(self.device_torch, dtype=dtype), | |
| is_training=True, | |
| has_been_preprocessed=True, | |
| quad_count=quad_count, | |
| # do cfg on clip embeds to normalize the embeddings for when doing cfg | |
| # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( | |
| clip_images.detach().to(self.device_torch, dtype=dtype), | |
| is_training=True, | |
| drop=True, | |
| has_been_preprocessed=True, | |
| quad_count=quad_count | |
| ) | |
| else: | |
| print_acc("No Clip Image") | |
| print_acc([file_item.path for file_item in batch.file_items]) | |
| raise ValueError("Could not find clip image") | |
| if not self.adapter_config.train_image_encoder: | |
| # we are not training the image encoder, so we need to detach the embeds | |
| conditional_clip_embeds = conditional_clip_embeds.detach() | |
| if self.train_config.do_cfg: | |
| unconditional_clip_embeds = unconditional_clip_embeds.detach() | |
| with self.timer('encode_adapter'): | |
| self.adapter.train() | |
| conditional_embeds = self.adapter( | |
| conditional_embeds.detach(), | |
| conditional_clip_embeds, | |
| is_unconditional=False | |
| ) | |
| if self.train_config.do_cfg: | |
| unconditional_embeds = self.adapter( | |
| unconditional_embeds.detach(), | |
| unconditional_clip_embeds, | |
| is_unconditional=True | |
| ) | |
| else: | |
| # wipe out unconditional | |
| self.adapter.last_unconditional = None | |
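| # the reference adapter conditions on full reference images rather than clip embeds: give it the batch images rescaled to -1..1, blank references for reg images, or none at all | |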
| if self.adapter and isinstance(self.adapter, ReferenceAdapter): | |
| # pass the model's noise scheduler into the reference adapter | |
| self.adapter.noise_scheduler = self.sd.noise_scheduler | |
| if has_clip_image or has_adapter_img: | |
| img_to_use = clip_images if has_clip_image else adapter_images | |
| # currently 0-1 needs to be -1 to 1 | |
| reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) | |
| self.adapter.set_reference_images(reference_images) | |
| self.adapter.noise_scheduler = self.sd.noise_scheduler | |
| elif is_reg: | |
| self.adapter.set_blank_reference_images(noisy_latents.shape[0]) | |
| else: | |
| self.adapter.set_reference_images(None) | |
| prior_pred = None | |
| do_reg_prior = False | |
| # if is_reg and (self.network is not None or self.adapter is not None): | |
| # # we are doing a reg image and we have a network or adapter | |
| # do_reg_prior = True | |
| do_inverted_masked_prior = False | |
| if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: | |
| do_inverted_masked_prior = True | |
| do_correct_pred_norm_prior = self.train_config.correct_pred_norm | |
| do_guidance_prior = False | |
| if batch.unconditional_latents is not None: | |
| # for 'this not that' (tnt) guidance, we need a prior pred to normalize against | |
| guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type | |
| if guidance_type == 'tnt': | |
| do_guidance_prior = True | |
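| # a prior prediction (the base model's output, used as a reference) is needed for adapter assistance, guidance normalization, reg priors, inverted masked priors, or correct_pred_norm | |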
| if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): | |
| with self.timer('prior predict'): | |
| prior_embeds_to_use = conditional_embeds | |
| # use the diff_output_preservation embeds when DOP is enabled | |
| if self.train_config.diff_output_preservation: | |
| prior_embeds_to_use = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) | |
| prior_pred = self.get_prior_prediction( | |
| noisy_latents=noisy_latents, | |
| conditional_embeds=prior_embeds_to_use, | |
| match_adapter_assist=match_adapter_assist, | |
| network_weight_list=network_weight_list, | |
| timesteps=timesteps, | |
| pred_kwargs=pred_kwargs, | |
| noise=noise, | |
| batch=batch, | |
| unconditional_embeds=unconditional_embeds, | |
| conditioned_prompts=conditioned_prompts | |
| ) | |
| if prior_pred is not None: | |
| prior_pred = prior_pred.detach() | |
| # do the custom adapter after the prior prediction | |
| if self.adapter and isinstance(self.adapter, CustomAdapter) and (has_clip_image or self.adapter_config.type in ['llm_adapter', 'text_encoder']): | |
| quad_count = random.randint(1, 4) | |
| self.adapter.train() | |
| conditional_embeds = self.adapter.condition_encoded_embeds( | |
| tensors_0_1=clip_images, | |
| prompt_embeds=conditional_embeds, | |
| is_training=True, | |
| has_been_preprocessed=True, | |
| quad_count=quad_count | |
| ) | |
| if self.train_config.do_cfg and unconditional_embeds is not None: | |
| unconditional_embeds = self.adapter.condition_encoded_embeds( | |
| tensors_0_1=clip_images, | |
| prompt_embeds=unconditional_embeds, | |
| is_training=True, | |
| has_been_preprocessed=True, | |
| is_unconditional=True, | |
| quad_count=quad_count | |
| ) | |
| if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: | |
| self.adapter.add_extra_values(batch.extra_values.detach()) | |
| if self.train_config.do_cfg: | |
| self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), | |
| is_unconditional=True) | |
| if has_adapter_img: | |
| if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( | |
| self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): | |
| if self.train_config.do_cfg: | |
| raise ValueError("ControlNetModel is not supported with CFG") | |
| with torch.set_grad_enabled(self.adapter is not None): | |
| adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter | |
| adapter_multiplier = get_adapter_multiplier() | |
| with self.timer('encode_adapter'): | |
| # add_text_embeds is pooled_prompt_embeds for sdxl | |
| added_cond_kwargs = {} | |
| if self.sd.is_xl: | |
| added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds | |
| added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) | |
| down_block_res_samples, mid_block_res_sample = adapter( | |
| noisy_latents, | |
| timesteps, | |
| encoder_hidden_states=conditional_embeds.text_embeds, | |
| controlnet_cond=adapter_images, | |
| conditioning_scale=1.0, | |
| guess_mode=False, | |
| added_cond_kwargs=added_cond_kwargs, | |
| return_dict=False, | |
| ) | |
| pred_kwargs['down_block_additional_residuals'] = down_block_res_samples | |
| pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample | |
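| # the controlnet's down/mid residuals ride along in pred_kwargs so predict_noise can pass them to the unet | |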
| if self.train_config.do_guidance_loss and isinstance(self.train_config.guidance_loss_target, list): | |
| batch_size = noisy_latents.shape[0] | |
| # sample per-item guidance targets: random floats between guidance_loss_target[0] and guidance_loss_target[1] | |
| self._guidance_loss_target_batch = [ | |
| random.uniform( | |
| self.train_config.guidance_loss_target[0], | |
| self.train_config.guidance_loss_target[1] | |
| ) for _ in range(batch_size) | |
| ] | |
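| # e.g. a (hypothetical) guidance_loss_target of [1.0, 4.0] would give each batch item a target drawn uniformly from that range | |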
| self.before_unet_predict() | |
| if unconditional_embeds is not None: | |
| unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() | |
| with self.timer('condition_noisy_latents'): | |
| # do it for the model | |
| noisy_latents = self.sd.condition_noisy_latents(noisy_latents, batch) | |
| if self.adapter and isinstance(self.adapter, CustomAdapter): | |
| noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) | |
| if self.train_config.timestep_type == 'next_sample': | |
| with self.timer('next_sample_step'): | |
| with torch.no_grad(): | |
| stepped_timestep_indices = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] | |
| stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indices] | |
| stepped_timesteps = torch.stack(stepped_timesteps, dim=0) | |
| # do a sample at the current timestep and step it, then determine new noise | |
| next_sample_pred = self.predict_noise( | |
| noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), | |
| timesteps=timesteps, | |
| conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), | |
| unconditional_embeds=unconditional_embeds, | |
| batch=batch, | |
| **pred_kwargs | |
| ) | |
| stepped_latents = self.sd.step_scheduler( | |
| next_sample_pred, | |
| noisy_latents, | |
| timesteps, | |
| self.sd.noise_scheduler | |
| ) | |
| # stepped latents is our new noisy latents. Now we need to determine noise in the current sample | |
| noisy_latents = stepped_latents | |
| original_samples = batch.latents.to(self.device_torch, dtype=dtype) | |
| # todo: calculate the true next timestep; for now this may work as is | |
| t_01 = (stepped_timesteps / 1000).to(original_samples.device) | |
| if len(stepped_latents.shape) == 4: | |
| t_01 = t_01.view(-1, 1, 1, 1) | |
| elif len(stepped_latents.shape) == 5: | |
| t_01 = t_01.view(-1, 1, 1, 1, 1) | |
| else: | |
| raise ValueError("Unknown stepped latents shape", stepped_latents.shape) | |
| next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 | |
| noise = next_sample_noise | |
| timesteps = stepped_timesteps | |
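| # under the flow-matching parameterization x_t = (1 - t) * x_0 + t * noise, so the implied noise at the stepped timestep is (x_t - (1 - t) * x_0) / t, which is what next_sample_noise computes above | |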
| # do a prior pred if we have an unconditional image; we will swap out the guidance later | |
| if batch.unconditional_latents is not None or self.do_guided_loss: | |
| # do guided loss | |
| loss = self.get_guided_loss( | |
| noisy_latents=noisy_latents, | |
| conditional_embeds=conditional_embeds, | |
| match_adapter_assist=match_adapter_assist, | |
| network_weight_list=network_weight_list, | |
| timesteps=timesteps, | |
| pred_kwargs=pred_kwargs, | |
| batch=batch, | |
| noise=noise, | |
| unconditional_embeds=unconditional_embeds, | |
| mask_multiplier=mask_multiplier, | |
| prior_pred=prior_pred, | |
| ) | |
| elif self.train_config.loss_type == 'mean_flow': | |
| loss = self.get_mean_flow_loss( | |
| noisy_latents=noisy_latents, | |
| conditional_embeds=conditional_embeds, | |
| match_adapter_assist=match_adapter_assist, | |
| network_weight_list=network_weight_list, | |
| timesteps=timesteps, | |
| pred_kwargs=pred_kwargs, | |
| batch=batch, | |
| noise=noise, | |
| unconditional_embeds=unconditional_embeds, | |
| prior_pred=prior_pred, | |
| ) | |
| else: | |
| with self.timer('predict_unet'): | |
| noise_pred = self.predict_noise( | |
| noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), | |
| timesteps=timesteps, | |
| conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), | |
| unconditional_embeds=unconditional_embeds, | |
| batch=batch, | |
| is_primary_pred=True, | |
| **pred_kwargs | |
| ) | |
| self.after_unet_predict() | |
| with self.timer('calculate_loss'): | |
| noise = noise.to(self.device_torch, dtype=dtype).detach() | |
| prior_to_calculate_loss = prior_pred | |
| # if we are doing diff_output_preservation and not doing an inverted masked prior, | |
| # pass None here so the loss does not target the prior | |
| if self.train_config.diff_output_preservation and not do_inverted_masked_prior: | |
| prior_to_calculate_loss = None | |
| loss = self.calculate_loss( | |
| noise_pred=noise_pred, | |
| noise=noise, | |
| noisy_latents=noisy_latents, | |
| timesteps=timesteps, | |
| batch=batch, | |
| mask_multiplier=mask_multiplier, | |
| prior_pred=prior_to_calculate_loss, | |
| ) | |
| if self.train_config.diff_output_preservation: | |
| # send the loss backwards otherwise checkpointing will fail | |
| self.accelerator.backward(loss) | |
| normal_loss = loss.detach() # don't send backward again | |
| dop_embeds = self.diff_output_preservation_embeds.expand_to_batch(noisy_latents.shape[0]) | |
| dop_pred = self.predict_noise( | |
| noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), | |
| timesteps=timesteps, | |
| conditional_embeds=dop_embeds.to(self.device_torch, dtype=dtype), | |
| unconditional_embeds=unconditional_embeds, | |
| batch=batch, | |
| **pred_kwargs | |
| ) | |
| dop_loss = torch.nn.functional.mse_loss(dop_pred, prior_pred) * self.train_config.diff_output_preservation_multiplier | |
| self.accelerator.backward(dop_loss) | |
| loss = normal_loss + dop_loss | |
| loss = loss.clone().detach() | |
| # require grad again so the backward wont fail | |
| loss.requires_grad_(True) | |
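| # gradients for the DOP path were already accumulated by the two backward calls above; the detached loss (with requires_grad re-enabled) keeps the shared backward call below from failing and is returned for logging | |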
| # check if nan | |
| if torch.isnan(loss): | |
| print_acc("loss is nan") | |
| loss = torch.zeros_like(loss).requires_grad_(True) | |
| with self.timer('backward'): | |
| # todo: the multiplier is applied separately; works for now as res are not in the same batch, but this needs to change | |
| loss = loss * loss_multiplier.mean() | |
| # IMPORTANT: when gradient checkpointing, do not leave the `with network` context before calling backward. | |
| # The network is a context manager that resets its multipliers to 0.0 on exit, so the recomputed | |
| # activations in the backward pass see a disabled network and all gradients come out as 0.0. | |
| # I spent weeks fighting this. DON'T DO IT | |
| # with fsdp_overlap_step_with_backward(): | |
| # if self.is_bfloat: | |
| # loss.backward() | |
| # else: | |
| self.accelerator.backward(loss) | |
| return loss.detach() | |
| # flush() | |
| def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]): | |
| if isinstance(batch, list): | |
| batch_list = batch | |
| else: | |
| batch_list = [batch] | |
| total_loss = None | |
| self.optimizer.zero_grad() | |
| for batch in batch_list: | |
| if self.sd.is_multistage: | |
| # handle multistage switching | |
| if self.steps_this_boundary >= self.train_config.switch_boundary_every or self.current_boundary_index not in self.sd.trainable_multistage_boundaries: | |
| # iterate to make sure we only train trainable_multistage_boundaries | |
| while True: | |
| self.steps_this_boundary = 0 | |
| self.current_boundary_index += 1 | |
| if self.current_boundary_index >= len(self.sd.multistage_boundaries): | |
| self.current_boundary_index = 0 | |
| if self.current_boundary_index in self.sd.trainable_multistage_boundaries: | |
| # if this boundary is trainable, we can stop looking | |
| break | |
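| # multistage models rotate the trained boundary every switch_boundary_every steps, skipping boundaries that are not marked trainable | |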
| loss = self.train_single_accumulation(batch) | |
| self.steps_this_boundary += 1 | |
| if total_loss is None: | |
| total_loss = loss | |
| else: | |
| total_loss += loss | |
| if len(batch_list) > 1 and self.model_config.low_vram: | |
| torch.cuda.empty_cache() | |
| if not self.is_grad_accumulation_step: | |
| # todo: fix this for multiple param groups | |
| if self.train_config.optimizer != 'adafactor': | |
| if isinstance(self.params[0], dict): | |
| for i in range(len(self.params)): | |
| self.accelerator.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) | |
| else: | |
| self.accelerator.clip_grad_norm_(self.params, self.train_config.max_grad_norm) | |
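| # adafactor is excluded from the clipping above, presumably because it applies its own internal update clipping | |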
| # only step if we are not accumulating | |
| with self.timer('optimizer_step'): | |
| self.optimizer.step() | |
| self.optimizer.zero_grad(set_to_none=True) | |
| if self.adapter and isinstance(self.adapter, CustomAdapter): | |
| self.adapter.post_weight_update() | |
| if self.ema is not None: | |
| with self.timer('ema_update'): | |
| self.ema.update() | |
| else: | |
| # gradient accumulation. Just a place for breakpoint | |
| pass | |
| # TODO Should we only step scheduler on grad step? If so, need to recalculate last step | |
| with self.timer('scheduler_step'): | |
| self.lr_scheduler.step() | |
| if self.embedding is not None: | |
| with self.timer('restore_embeddings'): | |
| # Let's make sure we don't update any embedding weights besides the newly added token | |
| self.embedding.restore_embeddings() | |
| if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): | |
| with self.timer('restore_adapter'): | |
| # Let's make sure we don't update any embedding weights besides the newly added token | |
| self.adapter.restore_embeddings() | |
| loss_dict = OrderedDict( | |
| {'loss': (total_loss / len(batch_list)).item()} | |
| ) | |
| self.end_of_training_loop() | |
| return loss_dict | |