import inspect
import weakref
from functools import partial
from typing import TYPE_CHECKING, Tuple

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
    CombinedTimestepTextProjEmbeddings,
    CombinedTimestepGuidanceTextProjEmbeddings,
)

from toolkit.lora_special import LoRASpecialNetwork

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter
    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import (
        OmniGen2Transformer2DModel,
    )

def mean_flow_time_text_embed_forward(
    self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # if only the start timesteps were passed, append a zero end timestep for each
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.zeros_like(timestep)], dim=0
        )  # an end timestep of 0 (the final timestep) behaves the same as the start timestep alone
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        # todo: confirm the timesteps are batched correctly; diffusers may expect unbatched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = timesteps_emb + pooled_projections
    return conditioning

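# Illustrative sketch of the assumed calling convention (the trainer code is not
# part of this file): when mean flow is active, the caller may concatenate the
# start and end timesteps along the batch dimension, and the patched forward
# above chunks their embeddings back apart before fusing them.
#
#   timestep = torch.cat([t_start, t_end], dim=0)  # shape (2 * batch,)
#   conditioning = transformer.time_text_embed(timestep, pooled_projection)
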
def mean_flow_time_text_guidance_embed_forward(
    self: CombinedTimestepGuidanceTextProjEmbeddings,
    timestep,
    guidance,
    pooled_projection,
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # if only the start timesteps were passed, append an end timestep for each
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat(
            [timestep, torch.ones_like(timestep)], dim=0
        )
    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(
        guidance_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        # todo: confirm the timesteps are batched correctly; diffusers may expect unbatched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    time_guidance_emb = timesteps_emb + guidance_emb
    pooled_projections = self.text_embedder(pooled_projection)
    conditioning = time_guidance_emb + pooled_projections
    return conditioning

def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(
        transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
    ):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )

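# Usage sketch (illustrative only; the surrounding objects are assumptions, and
# MeanFlowAdapter below normally performs these steps itself): patch a loaded
# Flux transformer and point its time embedder back at an adapter instance.
#
#   transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer")
#   convert_flux_to_mean_flow(transformer)
#   transformer.time_text_embed.mean_flow_adapter_ref = weakref.ref(mean_flow_adapter)
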
def mean_flow_omnigen2_time_text_embed_forward(
    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # if only the start timesteps were passed, append an end timestep for each
    if mean_flow_adapter.is_active and timestep.shape[0] == text_hidden_states.shape[0]:
        timestep = torch.cat(
            [timestep, torch.ones_like(timestep)], dim=0  # OmniGen2 runs its timesteps in reverse
        )
    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
    time_embed = self.timestep_embedder(timestep_proj)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        # todo: confirm the timesteps are batched correctly; diffusers may expect unbatched timesteps
        orig_dtype = time_embed.dtype
        time_embed = time_embed.to(torch.float32)
        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([time_embed_start, time_embed_end], dim=-1)
        )
        time_embed = time_embed.to(orig_dtype)

    caption_embed = self.caption_embedder(text_hidden_states)
    return time_embed, caption_embed

def convert_omnigen2_to_mean_flow(
    transformer: "OmniGen2Transformer2DModel",
):
    transformer.time_caption_embed.forward = partial(
        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
    )

class MeanFlowAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: "CustomAdapter",
        sd: "StableDiffusion",
        config: "AdapterConfig",
        train_config: "TrainConfig",
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch
        self.lora = None

        if self.network_config is not None:
            network_kwargs = (
                {}
                if self.network_config.network_kwargs is None
                else self.network_config.network_kwargs
            )
            if hasattr(sd, "target_lora_modules"):
                network_kwargs["target_lin_modules"] = sd.target_lora_modules
            if "ignore_if_contains" not in network_kwargs:
                network_kwargs["ignore_if_contains"] = []

            self.lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs,
            )
            self.lora.force_to(self.device_torch, dtype=torch.float32)
            self.lora._update_torch_multiplier()
            self.lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet,
            )
            self.lora.can_merge_in = False
            self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.lora.enable_gradient_checkpointing()
        emb_dim = None
        if self.model_config.arch in ["flux", "flex2"]:
            transformer: FluxTransformer2DModel = sd.unet
            emb_dim = (
                transformer.config.num_attention_heads
                * transformer.config.attention_head_dim
            )
            convert_flux_to_mean_flow(transformer)
        elif self.model_config.arch in ["omnigen2"]:
            transformer: "OmniGen2Transformer2DModel" = sd.unet
            emb_dim = 1024
            convert_omnigen2_to_mean_flow(transformer)
        else:
            raise ValueError(f"Unsupported architecture: {self.model_config.arch}")

        # projects the concatenated (start, end) timestep embeddings back down
        # to a single timestep embedding
        self.mean_flow_timestep_embedder = torch.nn.Linear(
            emb_dim * 2,
            emb_dim,
        )

        # initialize as an identity on the start embedding so the model behaves
        # exactly as it did before this adapter was added
        with torch.no_grad():
            self.mean_flow_timestep_embedder.weight.zero_()
            self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
            self.mean_flow_timestep_embedder.bias.zero_()
        self.mean_flow_timestep_embedder.to(self.device_torch)

        # give the patched time embedder a weak reference back to this adapter
        if self.model_config.arch in ["flux", "flex2"]:
            sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
        elif self.model_config.arch in ["omnigen2"]:
            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)
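    # Sanity-check sketch (illustrative, not executed): with the identity
    # initialization above, the embedder returns the start embedding unchanged,
    # so the patched model initially matches the unpatched one.
    #
    #   start = torch.randn(1, emb_dim, device=self.device_torch)
    #   end = torch.randn(1, emb_dim, device=self.device_torch)
    #   fused = self.mean_flow_timestep_embedder(torch.cat([start, end], dim=-1))
    #   assert torch.allclose(fused, start)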
    def get_params(self):
        if self.lora is not None:
            config = {
                "text_encoder_lr": self.train_config.lr,
                "unet_lr": self.train_config.lr,
            }
            sig = inspect.signature(self.lora.prepare_optimizer_params)
            if "default_lr" in sig.parameters:
                config["default_lr"] = self.train_config.lr
            if "learning_rate" in sig.parameters:
                config["learning_rate"] = self.train_config.lr
            params_net = self.lora.prepare_optimizer_params(**config)

            # flatten the param groups down to a plain list of tensors
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is trained in float32
        self.mean_flow_timestep_embedder.to(torch.float32)
        self.mean_flow_timestep_embedder.requires_grad_(True)
        self.mean_flow_timestep_embedder.train()
        params += list(self.mean_flow_timestep_embedder.parameters())

        # return a plain list so callers can also do `yield from params`
        return params
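    # Usage sketch (illustrative; the optimizer choice is an assumption):
    #
    #   optimizer = torch.optim.AdamW(adapter.get_params(), lr=train_config.lr)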
    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        mean_flow_embedder_sd = {}
        for key, value in state_dict.items():
            if "mean_flow_timestep_embedder" in key:
                new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
                mean_flow_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo: process the state dict before loading for models that need it
        if self.lora is not None:
            self.lora.load_weights(lora_sd)
        self.mean_flow_timestep_embedder.load_state_dict(
            mean_flow_embedder_sd, strict=False
        )

    def get_state_dict(self):
        if self.lora is not None:
            lora_sd = self.lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo: make sure this matches how LoRA keys are handled elsewhere
        mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
        for key, value in mean_flow_embedder_sd.items():
            lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
        return lora_sd
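    # Saving sketch (illustrative; safetensors is an assumption, any serializer
    # that accepts a flat tensor dict works):
    #
    #   from safetensors.torch import save_file
    #   save_file(adapter.get_state_dict(), "mean_flow_adapter.safetensors")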
    @property
    def is_active(self):
        return self.adapter_ref().is_active