# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# https://github.com/bmaltais/kohya_ss
import hashlib
import math
import os
from collections import defaultdict
from io import BytesIO
from typing import List, Optional, Type, Union

import safetensors.torch
import torch
import torch.utils.checkpoint
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from safetensors.torch import load_file
from transformers import T5EncoderModel

class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear (or Conv2d) instead of replacing the module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """If alpha is 0 or None, alpha is set to the rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        self.lora_dim = lora_dim

        if org_module.__class__.__name__ == "Conv2d":
            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        # same initialization as Microsoft's LoRA implementation
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # removed when apply_to() is called
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x, *args, **kwargs):
        weight_dtype = x.dtype
        org_forwarded = self.org_forward(x)

        # module dropout
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        lx = self.lora_down(x.to(self.lora_down.weight.dtype))

        # normal dropout
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        # rank dropout
        if self.rank_dropout is not None and self.training:
            mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if len(lx.size()) == 3:
                mask = mask.unsqueeze(1)  # for Text Encoder
            elif len(lx.size()) == 4:
                mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * mask

            # scaling for rank dropout: treat as if the rank is changed
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
        else:
            scale = self.scale

        lx = self.lora_up(lx)

        return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
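
# Example (illustrative sketch, not executed at import time): wrapping a single
# nn.Linear with LoRAModule. The layer name and sizes below are assumptions made
# up for the demo, not anything this module defines.
#
#   base = torch.nn.Linear(320, 320)
#   lora = LoRAModule("lora_unet_demo", base, multiplier=1.0, lora_dim=4, alpha=1)
#   lora.apply_to()                    # base.forward now routes through the LoRA path
#   y = base(torch.randn(2, 320))      # = original output + up(down(x)) * multiplier * scale
#                                      # (zero delta right after init: lora_up is zero-initialized)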

def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    m = hashlib.sha256()

    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()

def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    serialized = safetensors.torch.save(tensors, metadata)
    b = BytesIO(serialized)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
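
# Example (illustrative sketch): computing the two hashes for a LoRA state dict
# before saving. The tensor name and metadata keys are placeholders; only "ss_"
# training metadata is kept when hashing, mirroring save_weights() below.
#
#   sd = {"lora_unet_demo.lora_down.weight": torch.zeros(4, 320)}
#   meta = {"ss_network_dim": "4", "comment": "dropped before hashing"}
#   model_hash, legacy_hash = precalculate_safetensors_hashes(sd, meta)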

class LoRANetwork(torch.nn.Module):
    TRANSFORMER_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Transformer3DModel", "HunyuanTransformer3DModel", "EasyAnimateTransformer3DModel"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder"]
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        add_lora_in_attn_temporal: bool = False,
        varbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
        print(f"neuron dropout: p={self.dropout}")

        # create module instances
        def create_modules(
            is_unet: bool,
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
        ):
            prefix = (
                self.LORA_PREFIX_TRANSFORMER
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if not add_lora_in_attn_temporal:
                            if "attn_temporal" in child_name:
                                continue

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None

                            if is_linear or is_conv2d_1x1:
                                dim = self.lora_dim
                                alpha = self.alpha

                            if dim is None or dim == 0:
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha,
                                dropout=dropout,
                            )
                            loras.append(lora)
            return loras, skipped

        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        self.text_encoder_loras = []
        skipped_te = []
        for i, text_encoder in enumerate(text_encoders):
            if text_encoder is not None:
                text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
                self.text_encoder_loras.extend(text_encoder_loras)
                skipped_te += skipped
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        # assertion: LoRA module names must be unique
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
        info = self.load_state_dict(weights_sd, False)
        return info

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        pass

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
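
# Example (illustrative sketch): hooking an existing text encoder / transformer
# pair with a LoRANetwork and loading previously trained weights. `text_encoder`,
# `transformer`, and the checkpoint path are placeholders standing in for real
# models and files.
#
#   network = LoRANetwork(text_encoder, transformer, multiplier=1.0, lora_dim=4, alpha=1.0)
#   network.apply_to(text_encoder, transformer, apply_text_encoder=True, apply_unet=True)
#   info = network.load_weights("trained_lora.safetensors")  # strict=False load
#   network.set_multiplier(0.8)                               # rescale the LoRA effect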

def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
    transformer,
    neuron_dropout: Optional[float] = None,
    add_lora_in_attn_temporal: bool = False,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    network = LoRANetwork(
        text_encoder,
        transformer,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        add_lora_in_attn_temporal=add_lora_in_attn_temporal,
        varbose=True,
    )
    return network
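
# Example (illustrative sketch): typical training wiring with create_network().
# The models, learning rates, metadata, and output path are placeholder assumptions.
#
#   network = create_network(1.0, network_dim=64, network_alpha=32.0,
#                            text_encoder=text_encoder, transformer=transformer)
#   network.apply_to(text_encoder, transformer, apply_text_encoder=False, apply_unet=True)
#   params = network.prepare_optimizer_params(text_encoder_lr=None, unet_lr=1e-4, default_lr=1e-4)
#   optimizer = torch.optim.AdamW(params, lr=1e-4)
#   ... training loop ...
#   network.save_weights("output_lora.safetensors", torch.float16, {"ss_network_dim": "64"})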

def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
    """Merge a LoRANetwork state_dict into the pipeline in diffusers."""
    LORA_PREFIX_TRANSFORMER = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    if state_dict is None:
        state_dict = load_file(lora_path, device=device)

    # group keys by layer: "<layer>.<lora_up.weight | lora_down.weight | alpha>"
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    for layer, elems in updates.items():
        if "lora_te" in layer:
            if transformer_only:
                continue
            else:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # walk the module tree to find the target layer; module names may themselves
        # contain underscores, so name segments are re-joined greedily on failure
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if len(layer_infos) > 0:
                        temp_name = layer_infos.pop(0)
                    elif len(layer_infos) == 0:
                        break
                except Exception:
                    if len(layer_infos) == 0:
                        print('Error loading layer')
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)

        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # W <- W + multiplier * (alpha / rank) * up @ down
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)
    return pipeline
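
# Example (illustrative sketch): merging a trained LoRA into a diffusers-style
# pipeline for inference. `pipeline` and the file path are placeholder
# assumptions; the pipeline is expected to expose `.transformer` and
# `.text_encoder` attributes as used above.
#
#   pipeline = merge_lora(pipeline, "trained_lora.safetensors", multiplier=0.8,
#                         device="cuda", dtype=torch.float16)
#   ... run inference ...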

# TODO: Refactor with merge_lora.
def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
    """Unmerge a LoRANetwork state_dict from the pipeline in diffusers."""
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    state_dict = load_file(lora_path, device=device)

    # group keys by layer, as in merge_lora
    updates = defaultdict(dict)
    for key, value in state_dict.items():
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    for layer, elems in updates.items():
        if "lora_te" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.transformer

        # walk the module tree to find the target layer (same strategy as merge_lora)
        try:
            curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
        except Exception:
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if len(layer_infos) > 0:
                        temp_name = layer_infos.pop(0)
                    elif len(layer_infos) == 0:
                        break
                except Exception:
                    if len(layer_infos) == 0:
                        print('Error loading layer')
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)

        origin_dtype = curr_layer.weight.data.dtype
        origin_device = curr_layer.weight.data.device

        curr_layer = curr_layer.to(device, dtype)
        weight_up = elems['lora_up.weight'].to(device, dtype)
        weight_down = elems['lora_down.weight'].to(device, dtype)

        if 'alpha' in elems.keys():
            alpha = elems['alpha'].item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # W <- W - multiplier * (alpha / rank) * up @ down
        if len(weight_up.shape) == 4:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(
                weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
            ).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
        curr_layer = curr_layer.to(origin_device, origin_dtype)
    return pipeline
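
# Example (illustrative sketch): undoing a previous merge with the same file and
# multiplier, restoring the original pipeline weights up to floating-point error.
# The path and multiplier are placeholders and should match the earlier merge_lora call.
#
#   pipeline = unmerge_lora(pipeline, "trained_lora.safetensors", multiplier=0.8,
#                           device="cuda", dtype=torch.float16)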