import inspect
import weakref
from functools import partial
from typing import TYPE_CHECKING, Tuple

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.embeddings import (
    CombinedTimestepTextProjEmbeddings,
    CombinedTimestepGuidanceTextProjEmbeddings,
)

from toolkit.lora_special import LoRASpecialNetwork

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.config_modules import AdapterConfig, TrainConfig, ModelConfig
    from toolkit.custom_adapter import CustomAdapter
    from extensions_built_in.diffusion_models.omnigen2.src.models.transformers import (
        OmniGen2Transformer2DModel,
    )


def mean_flow_time_text_embed_forward(
    self: CombinedTimestepTextProjEmbeddings, timestep, pooled_projection
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # If only start timesteps were passed, append an all-zeros end timestep.
    # 0 is the final timestep for Flux, so start == end behaves like the base model.
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        # todo make sure the timesteps are batched correctly; diffusers may
        # expect non-batched timesteps
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = timesteps_emb + pooled_projections

    return conditioning


def mean_flow_time_text_guidance_embed_forward(
    self: CombinedTimestepGuidanceTextProjEmbeddings,
    timestep,
    guidance,
    pooled_projection,
):
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # If only start timesteps were passed, append an all-zeros end timestep.
    # 0 is the final timestep for Flux, so start == end behaves like the base model.
    if mean_flow_adapter.is_active and timestep.shape[0] == pooled_projection.shape[0]:
        timestep = torch.cat([timestep, torch.zeros_like(timestep)], dim=0)

    timesteps_proj = self.time_proj(timestep)
    timesteps_emb = self.timestep_embedder(
        timesteps_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    guidance_proj = self.time_proj(guidance)
    guidance_emb = self.guidance_embedder(
        guidance_proj.to(dtype=pooled_projection.dtype)
    )  # (N, D)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        orig_dtype = timesteps_emb.dtype
        timesteps_emb = timesteps_emb.to(torch.float32)
        timesteps_emb_start, timesteps_emb_end = timesteps_emb.chunk(2, dim=0)
        timesteps_emb = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([timesteps_emb_start, timesteps_emb_end], dim=-1)
        )
        timesteps_emb = timesteps_emb.to(orig_dtype)

    time_guidance_emb = timesteps_emb + guidance_emb

    pooled_projections = self.text_embedder(pooled_projection)

    conditioning = time_guidance_emb + pooled_projections

    return conditioning
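
# Usage note for the patched forwards above (an illustrative sketch, not part
# of the adapter): the trainer may pass either a plain [B] timestep batch, in
# which case an end timestep is appended automatically, or a pre-concatenated
# [2B] batch of (start, end) timesteps. The tensors below are hypothetical.
#
#   t_start = torch.rand(4)                  # start timesteps, one per sample
#   t_end = torch.zeros(4)                   # end timesteps (0 = final step for Flux)
#   timestep = torch.cat([t_start, t_end])   # shape [8]; pooled_projection stays [4, D]
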
def convert_flux_to_mean_flow(
    transformer: FluxTransformer2DModel,
):
    if isinstance(transformer.time_text_embed, CombinedTimestepTextProjEmbeddings):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_embed_forward, transformer.time_text_embed
        )
    elif isinstance(
        transformer.time_text_embed, CombinedTimestepGuidanceTextProjEmbeddings
    ):
        transformer.time_text_embed.forward = partial(
            mean_flow_time_text_guidance_embed_forward, transformer.time_text_embed
        )
    else:
        raise ValueError(
            "Unsupported time_text_embed type: {}".format(
                type(transformer.time_text_embed)
            )
        )


def mean_flow_omnigen2_time_text_embed_forward(
    self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    mean_flow_adapter: "MeanFlowAdapter" = self.mean_flow_adapter_ref()
    # If only start timesteps were passed, append an all-ones end timestep
    # (OmniGen2 runs its timesteps in reverse, so 1 is the final timestep).
    if (
        mean_flow_adapter.is_active
        and timestep.shape[0] == text_hidden_states.shape[0]
    ):
        timestep = torch.cat([timestep, torch.ones_like(timestep)], dim=0)

    timestep_proj = self.time_proj(timestep).to(dtype=dtype)
    time_embed = self.timestep_embedder(timestep_proj)

    # mean flow: fuse the start and end timestep embeddings into a single embedding
    if mean_flow_adapter.is_active:
        orig_dtype = time_embed.dtype
        time_embed = time_embed.to(torch.float32)
        time_embed_start, time_embed_end = time_embed.chunk(2, dim=0)
        time_embed = mean_flow_adapter.mean_flow_timestep_embedder(
            torch.cat([time_embed_start, time_embed_end], dim=-1)
        )
        time_embed = time_embed.to(orig_dtype)

    caption_embed = self.caption_embedder(text_hidden_states)
    return time_embed, caption_embed


def convert_omnigen2_to_mean_flow(
    transformer: "OmniGen2Transformer2DModel",
):
    transformer.time_caption_embed.forward = partial(
        mean_flow_omnigen2_time_text_embed_forward, transformer.time_caption_embed
    )
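
# Note on the patching pattern above (an illustrative sketch): assigning
#     module.forward = partial(new_forward, module)
# overrides forward on that one instance only; `partial` supplies the module
# as `self`, standing in for the usual bound method. Parameters, state dicts,
# and other instances of the class are unaffected. For example:
#
#   relu = torch.nn.ReLU()
#   relu.forward = partial(lambda self, x: x * 2, relu)  # instance-level override
#   assert torch.equal(relu(torch.ones(3)), torch.full((3,), 2.0))
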
class MeanFlowAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: "CustomAdapter",
        sd: "StableDiffusion",
        config: "AdapterConfig",
        train_config: "TrainConfig",
    ):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.model_config: ModelConfig = sd.model_config
        self.network_config = config.lora_config
        self.train_config = train_config
        self.device_torch = sd.device_torch
        self.lora = None

        if self.network_config is not None:
            network_kwargs = (
                {}
                if self.network_config.network_kwargs is None
                else self.network_config.network_kwargs
            )
            if hasattr(sd, "target_lora_modules"):
                network_kwargs["target_lin_modules"] = sd.target_lora_modules

            if "ignore_if_contains" not in network_kwargs:
                network_kwargs["ignore_if_contains"] = []

            self.lora = LoRASpecialNetwork(
                text_encoder=sd.text_encoder,
                unet=sd.unet,
                lora_dim=self.network_config.linear,
                multiplier=1.0,
                alpha=self.network_config.linear_alpha,
                train_unet=self.train_config.train_unet,
                train_text_encoder=self.train_config.train_text_encoder,
                conv_lora_dim=self.network_config.conv,
                conv_alpha=self.network_config.conv_alpha,
                is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                is_v2=self.model_config.is_v2,
                is_v3=self.model_config.is_v3,
                is_pixart=self.model_config.is_pixart,
                is_auraflow=self.model_config.is_auraflow,
                is_flux=self.model_config.is_flux,
                is_lumina2=self.model_config.is_lumina2,
                is_ssd=self.model_config.is_ssd,
                is_vega=self.model_config.is_vega,
                dropout=self.network_config.dropout,
                use_text_encoder_1=self.model_config.use_text_encoder_1,
                use_text_encoder_2=self.model_config.use_text_encoder_2,
                use_bias=False,
                is_lorm=False,
                network_config=self.network_config,
                network_type=self.network_config.type,
                transformer_only=self.network_config.transformer_only,
                is_transformer=sd.is_transformer,
                base_model=sd,
                **network_kwargs,
            )
            self.lora.force_to(self.device_torch, dtype=torch.float32)
            self.lora._update_torch_multiplier()
            self.lora.apply_to(
                sd.text_encoder,
                sd.unet,
                self.train_config.train_text_encoder,
                self.train_config.train_unet,
            )
            self.lora.can_merge_in = False
            self.lora.prepare_grad_etc(sd.text_encoder, sd.unet)
            if self.train_config.gradient_checkpointing:
                self.lora.enable_gradient_checkpointing()

        emb_dim = None
        if self.model_config.arch in ["flux", "flex2"]:
            transformer: FluxTransformer2DModel = sd.unet
            emb_dim = (
                transformer.config.num_attention_heads
                * transformer.config.attention_head_dim
            )
            convert_flux_to_mean_flow(transformer)
        elif self.model_config.arch in ["omnigen2"]:
            transformer: "OmniGen2Transformer2DModel" = sd.unet
            emb_dim = 1024
            convert_omnigen2_to_mean_flow(transformer)
        else:
            raise ValueError(f"Unsupported architecture: {self.model_config.arch}")

        # fuses the (start, end) timestep embeddings: Linear(2 * emb_dim -> emb_dim)
        self.mean_flow_timestep_embedder = torch.nn.Linear(
            emb_dim * 2,
            emb_dim,
        )

        # Initialize as the identity on the start embedding so the model
        # functions exactly as it did before this adapter was added.
        with torch.no_grad():
            self.mean_flow_timestep_embedder.weight.zero_()
            self.mean_flow_timestep_embedder.weight[:, :emb_dim] = torch.eye(emb_dim)
            self.mean_flow_timestep_embedder.bias.zero_()

        self.mean_flow_timestep_embedder.to(self.device_torch)

        # attach this adapter as a weak ref so the patched forwards can find it
        if self.model_config.arch in ["flux", "flex2"]:
            sd.unet.time_text_embed.mean_flow_adapter_ref = weakref.ref(self)
        elif self.model_config.arch in ["omnigen2"]:
            sd.unet.time_caption_embed.mean_flow_adapter_ref = weakref.ref(self)

    def get_params(self):
        if self.lora is not None:
            config = {
                "text_encoder_lr": self.train_config.lr,
                "unet_lr": self.train_config.lr,
            }
            sig = inspect.signature(self.lora.prepare_optimizer_params)
            if "default_lr" in sig.parameters:
                config["default_lr"] = self.train_config.lr
            if "learning_rate" in sig.parameters:
                config["learning_rate"] = self.train_config.lr
            params_net = self.lora.prepare_optimizer_params(**config)

            # we want only tensors here
            params = []
            for p in params_net:
                if isinstance(p, dict):
                    params += p["params"]
                elif isinstance(p, torch.Tensor):
                    params.append(p)
                elif isinstance(p, list):
                    params += p
        else:
            params = []

        # make sure the embedder is trainable and in float32
        self.mean_flow_timestep_embedder.to(torch.float32)
        self.mean_flow_timestep_embedder.requires_grad_(True)
        self.mean_flow_timestep_embedder.train()
        params += list(self.mean_flow_timestep_embedder.parameters())

        # we need to be able to `yield from` this list
        return params

    def load_weights(self, state_dict, strict=True):
        lora_sd = {}
        mean_flow_embedder_sd = {}
        for key, value in state_dict.items():
            if "mean_flow_timestep_embedder" in key:
                new_key = key.replace("transformer.mean_flow_timestep_embedder.", "")
                mean_flow_embedder_sd[new_key] = value
            else:
                lora_sd[key] = value

        # todo process state dict before loading for models that need it
        if self.lora is not None:
            self.lora.load_weights(lora_sd)
        self.mean_flow_timestep_embedder.load_state_dict(
            mean_flow_embedder_sd, strict=False
        )

    def get_state_dict(self):
        if self.lora is not None:
            lora_sd = self.lora.get_state_dict(dtype=torch.float32)
        else:
            lora_sd = {}
        # todo make sure we match loras elsewhere
        mean_flow_embedder_sd = self.mean_flow_timestep_embedder.state_dict()
        for key, value in mean_flow_embedder_sd.items():
            lora_sd[f"transformer.mean_flow_timestep_embedder.{key}"] = value
        return lora_sd

    @property
    def is_active(self):
        return self.adapter_ref().is_active
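

# Quick sanity sketch (hypothetical, not used by the adapter): demonstrates
# that the identity initialization in __init__ makes the fused embedder a
# no-op on the start embedding, so the patched model initially matches the
# base model. `D` is an arbitrary demo dimension.
if __name__ == "__main__":
    D = 8
    embedder = torch.nn.Linear(D * 2, D)
    with torch.no_grad():
        embedder.weight.zero_()
        embedder.weight[:, :D] = torch.eye(D)  # identity on the start half
        embedder.bias.zero_()
    start, end = torch.randn(2, D), torch.randn(2, D)
    fused = embedder(torch.cat([start, end], dim=-1))
    assert torch.allclose(fused, start, atol=1e-6)
    print("identity init: fused embedding equals start embedding")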