# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This was modified from the ControlNet repo.

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

import numpy as np
import torch
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL

### MERGING THESE ###
# from src.models.transformer import FluxTransformer2DModel
# from src.models.controlnet_flux import FluxControlNetModel
#############

##########################################
########### ATTENTION MERGE ##############
##########################################
import torch
from torch import Tensor, FloatTensor
from torch.nn import functional as F
from einops import rearrange

from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import apply_rotary_emb

#try:
#    from flash_attn_interface import flash_attn_func, flash_attn_qkvpacked_func
#except:
#    pass

"""def fa3_sdpa(
    q,
    k,
    v,
):
    # flash attention 3 sdpa drop-in replacement
    q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]]
    out = flash_attn_func(q, k, v)[0]
    return out.permute(0, 2, 1, 3)"""


"""
class FluxSingleAttnProcessor3_0:
    r""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    ""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
) def __call__( self, attn, hidden_states: Tensor, encoder_hidden_states: Tensor = None, attention_mask: FloatTensor = None, image_rotary_emb: Tensor = None, ) -> Tensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size, _, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = fa3_sdpa(query, key, value) hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.to(query.dtype) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states class FluxAttnProcessor3_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( self, attn, hidden_states: FloatTensor, encoder_hidden_states: FloatTensor = None, attention_mask: FloatTensor = None, image_rotary_emb: Tensor = None, ) -> FloatTensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) context_input_ndim = encoder_hidden_states.ndim if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape encoder_hidden_states = encoder_hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size = encoder_hidden_states.shape[0] # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # `context` projections. 
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q( encoder_hidden_states_query_proj ) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k( encoder_hidden_states_key_proj ) # attention query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = fa3_sdpa(query, key, value) hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.to(query.dtype) encoder_hidden_states, hidden_states = ( hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :], ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) if context_input_ndim == 4: encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states, encoder_hidden_states class FluxFusedFlashAttnProcessor3(object): """ True fused QKV Flash Attention 3 processor for Flux models. Keeps QKV tensors packed through the entire attention computation. """ def __init__(self): self.flash_attn_qkvpacked_func = None try: from flash_attn_interface import flash_attn_qkvpacked_func self.flash_attn_qkvpacked_func = flash_attn_qkvpacked_func except ImportError: raise ImportError( "FluxFusedFlashAttnProcessor3 requires flash-attn library. 
" "Please see this link for Hopper and Blackwell instructions: https://github.com/bghira/SimpleTuner/blob/main/INSTALL.md#nvidia-hopper--blackwell-follow-up-steps" ) def __call__( self, attn, hidden_states: FloatTensor, encoder_hidden_states: FloatTensor = None, attention_mask: FloatTensor = None, image_rotary_emb: Tensor = None, ) -> FloatTensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) context_input_ndim = ( encoder_hidden_states.ndim if encoder_hidden_states is not None else None ) if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape encoder_hidden_states = encoder_hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size = ( encoder_hidden_states.shape[0] if encoder_hidden_states is not None else hidden_states.shape[0] ) seq_len = hidden_states.shape[1] # Fused QKV projection qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) inner_dim = qkv.shape[-1] // 3 head_dim = inner_dim // attn.heads # Reshape to packed format: (batch, seq_len, 3, heads, head_dim) qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) # Apply norms if needed (requires temporary unpacking) if attn.norm_q is not None or attn.norm_k is not None: q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim) q = q.transpose(1, 2) # (batch, heads, seq_len, head_dim) k = k.transpose(1, 2) v = v.transpose(1, 2) if attn.norm_q is not None: q = attn.norm_q(q) if attn.norm_k is not None: k = attn.norm_k(k) # Repack: back to (batch, seq_len, 3, heads, head_dim) qkv = torch.stack( [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2 ) # Handle encoder states if present if encoder_hidden_states is not None: encoder_seq_len = encoder_hidden_states.shape[1] # Fused encoder QKV encoder_qkv = attn.to_added_qkv(encoder_hidden_states) encoder_qkv = encoder_qkv.view( batch_size, encoder_seq_len, 3, attn.heads, head_dim ) # Apply norms if needed if attn.norm_added_q is not None or attn.norm_added_k is not None: enc_q, enc_k, enc_v = encoder_qkv.unbind(dim=2) enc_q = enc_q.transpose(1, 2) enc_k = enc_k.transpose(1, 2) enc_v = enc_v.transpose(1, 2) if attn.norm_added_q is not None: enc_q = attn.norm_added_q(enc_q) if attn.norm_added_k is not None: enc_k = attn.norm_added_k(enc_k) encoder_qkv = torch.stack( [ enc_q.transpose(1, 2), enc_k.transpose(1, 2), enc_v.transpose(1, 2), ], dim=2, ) # Concatenate along sequence dimension qkv = torch.cat( [encoder_qkv, qkv], dim=1 ) # (batch, encoder_seq + seq, 3, heads, head_dim) # Apply RoPE if needed if image_rotary_emb is not None: q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim) # Transpose to (batch, heads, seq_len, head_dim) for RoPE q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # Apply RoPE to q and k q = apply_rotary_emb(q, image_rotary_emb) k = apply_rotary_emb(k, image_rotary_emb) # Transpose back and repack qkv = torch.stack( [q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2 ) # Flash Attention 3 with packed QKV # Input shape: (batch, seq_len, 3, heads, head_dim) # Output shape: (batch, seq_len, heads, head_dim) hidden_states = self.flash_attn_qkvpacked_func( qkv, causal=False, # Don't pass num_heads_q for standard MHA ) # Reshape output: (batch, seq_len, heads, head_dim) -> (batch, seq_len, heads * head_dim) hidden_states = hidden_states.reshape(batch_size, -1, 
attn.heads * head_dim) hidden_states = hidden_states.to(qkv.dtype) # Split and process outputs if encoder_hidden_states is not None: encoder_seq_len = encoder_hidden_states.shape[1] encoder_hidden_states = hidden_states[:, :encoder_seq_len] hidden_states = hidden_states[:, encoder_seq_len:] # Output projections hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) # dropout encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # Reshape if needed if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) if context_input_ndim == 4: encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states, encoder_hidden_states else: if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states """ class FluxFusedSDPAProcessor: """ Fused QKV processor using PyTorch's scaled_dot_product_attention. Uses fused projections but splits for attention computation. """ def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" ) def __call__( self, attn, hidden_states: FloatTensor, encoder_hidden_states: FloatTensor = None, attention_mask: FloatTensor = None, image_rotary_emb: Tensor = None, ) -> FloatTensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) context_input_ndim = ( encoder_hidden_states.ndim if encoder_hidden_states is not None else None ) if context_input_ndim == 4: batch_size, channel, height, width = encoder_hidden_states.shape encoder_hidden_states = encoder_hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size = ( encoder_hidden_states.shape[0] if encoder_hidden_states is not None else hidden_states.shape[0] ) # Single attention case (no encoder states) if encoder_hidden_states is None: # Use fused QKV projection qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) inner_dim = qkv.shape[-1] // 3 head_dim = inner_dim // attn.heads seq_len = hidden_states.shape[1] # Split and reshape qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) query, key, value = qkv.unbind( dim=2 ) # Each is (batch, seq_len, heads, head_dim) # Transpose to (batch, heads, seq_len, head_dim) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # Apply norms if needed if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # SDPA hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, ) # Reshape back hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.to(query.dtype) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states # Joint attention case (with encoder states) else: # Process self-attention QKV qkv = attn.to_qkv(hidden_states) inner_dim = qkv.shape[-1] // 3 head_dim = inner_dim // attn.heads seq_len = 
hidden_states.shape[1] qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) query, key, value = qkv.unbind(dim=2) # Transpose to (batch, heads, seq_len, head_dim) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # Apply norms if needed if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Process encoder QKV encoder_seq_len = encoder_hidden_states.shape[1] encoder_qkv = attn.to_added_qkv(encoder_hidden_states) encoder_qkv = encoder_qkv.view( batch_size, encoder_seq_len, 3, attn.heads, head_dim ) encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2) # Transpose to (batch, heads, seq_len, head_dim) encoder_query = encoder_query.transpose(1, 2) encoder_key = encoder_key.transpose(1, 2) encoder_value = encoder_value.transpose(1, 2) # Apply encoder norms if needed if attn.norm_added_q is not None: encoder_query = attn.norm_added_q(encoder_query) if attn.norm_added_k is not None: encoder_key = attn.norm_added_k(encoder_key) # Concatenate encoder and self-attention query = torch.cat([encoder_query, query], dim=2) key = torch.cat([encoder_key, key], dim=2) value = torch.cat([encoder_value, value], dim=2) # Apply RoPE if needed if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # SDPA hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, ) # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim) hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.to(query.dtype) # Split encoder and self outputs encoder_hidden_states = hidden_states[:, :encoder_seq_len] hidden_states = hidden_states[:, encoder_seq_len:] # Output projections hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) # dropout encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # Reshape if needed if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) if context_input_ndim == 4: encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states, encoder_hidden_states class FluxSingleFusedSDPAProcessor: """ Fused QKV processor for single attention (no encoder states). Simpler version for self-attention only blocks. 
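    Usage sketch (illustrative only; it assumes the attention module already exposes a
    fused `to_qkv` projection, e.g. after calling diffusers' `Attention.fuse_projections()`,
    and that `transformer` is a `LibreFluxTransformer2DModel` instance):

        for block in transformer.single_transformer_blocks:
            block.attn.fuse_projections()
            block.attn.set_processor(FluxSingleFusedSDPAProcessor())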
""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention" ) def __call__( self, attn, hidden_states: Tensor, encoder_hidden_states: Tensor = None, attention_mask: FloatTensor = None, image_rotary_emb: Tensor = None, ) -> Tensor: input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view( batch_size, channel, height * width ).transpose(1, 2) batch_size, seq_len, _ = hidden_states.shape # Use fused QKV projection qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim) inner_dim = qkv.shape[-1] // 3 head_dim = inner_dim // attn.heads # Split and reshape in one go qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided query, key, value = [ t.contiguous() for t in qkv.unbind(0) # make each view dense ] # Now each is (batch, heads, seq_len, head_dim) # Apply norms if needed if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply RoPE if needed if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) # SDPA hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) # Reshape back hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)") hidden_states = hidden_states.to(query.dtype) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape( batch_size, channel, height, width ) return hidden_states ################################# ##### TRANSFORMER MERGE ######### ################################# from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import ( Attention, AttentionProcessor, ) from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import ( AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, ) from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.models.embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed, ) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name is_flash_attn_available = False """try: from flash_attn_interface import flash_attn_func is_flash_attn_available = True except: pass""" class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) def __call__( self, attn: Attention, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, attention_mask: Optional[torch.FloatTensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: batch_size, _, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) # `sample` projections. query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` if encoder_hidden_states is not None: # `context` projections. encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( batch_size, -1, attn.heads, head_dim ).transpose(1, 2) if attn.norm_added_q is not None: encoder_hidden_states_query_proj = attn.norm_added_q( encoder_hidden_states_query_proj ) if attn.norm_added_k is not None: encoder_hidden_states_key_proj = attn.norm_added_k( encoder_hidden_states_key_proj ) # attention query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) if image_rotary_emb is not None: from diffusers.models.embeddings import apply_rotary_emb query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) if attention_mask is not None: #print ('Attention Used') attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = (attention_mask > 0).bool() # Edit 17 - match attn dtype to query d-type attention_mask = attention_mask.to( device=hidden_states.device, dtype=query.dtype ) hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask, ) hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim ) hidden_states = hidden_states.to(query.dtype) if encoder_hidden_states is not None: encoder_hidden_states, hidden_states = ( hidden_states[:, : encoder_hidden_states.shape[1]], hidden_states[:, encoder_hidden_states.shape[1] :], ) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn.to_add_out(encoder_hidden_states) return hidden_states, encoder_hidden_states return hidden_states def expand_flux_attention_mask( hidden_states: torch.Tensor, attn_mask: torch.Tensor, ) -> torch.Tensor: """ Expand a mask so that the image is included. 
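    For illustration (hypothetical shapes): a text mask of shape `(batch, text_seq_len)` is
    right-padded with ones up to the joint sequence length, so the image tokens appended
    after the text remain fully attendable:

        attn_mask = torch.tensor([[1.0, 1.0, 0.0]])  # 3 text tokens, last one is padding
        hidden_states = torch.randn(1, 5, 3072)      # 3 text tokens + 2 image tokens
        expand_flux_attention_mask(hidden_states, attn_mask)
        # -> tensor([[1., 1., 0., 1., 1.]])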
""" bsz = attn_mask.shape[0] assert bsz == hidden_states.shape[0] residual_seq_len = hidden_states.shape[1] mask_seq_len = attn_mask.shape[1] expanded_mask = torch.ones(bsz, residual_seq_len) expanded_mask[:, :mask_seq_len] = attn_mask return expanded_mask @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. """ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) self.norm = AdaLayerNormZeroSingle(dim) self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) processor = FluxAttnProcessor2_0() self.attn = Attention( query_dim=dim, cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, bias=True, processor=processor, qk_norm="rms_norm", eps=1e-6, pre_only=True, ) def forward( self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, attention_mask: Optional[torch.Tensor] = None, ): residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) if attention_mask is not None: attention_mask = expand_flux_attention_mask( hidden_states, attention_mask, ) attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, ) hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) gate = gate.unsqueeze(1) hidden_states = gate * self.proj_out(hidden_states) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) return hidden_states @maybe_allow_in_graph class FluxTransformerBlock(nn.Module): r""" A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. Reference: https://arxiv.org/abs/2403.03206 Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the processing of `context` conditions. """ def __init__( self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6 ): super().__init__() self.norm1 = AdaLayerNormZero(dim) self.norm1_context = AdaLayerNormZero(dim) if hasattr(F, "scaled_dot_product_attention"): processor = FluxAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
) self.attn = Attention( query_dim=dim, cross_attention_dim=None, added_kv_proj_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, out_dim=dim, context_pre_only=False, bias=True, processor=processor, qk_norm=qk_norm, eps=eps, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff_context = FeedForward( dim=dim, dim_out=dim, activation_fn="gelu-approximate" ) # let chunk size default to None self._chunk_size = None self._chunk_dim = 0 def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, image_rotary_emb=None, attention_mask: Optional[torch.Tensor] = None, ): norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, emb=temb ) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( self.norm1_context(encoder_hidden_states, emb=temb) ) if attention_mask is not None: attention_mask = expand_flux_attention_mask( torch.cat([encoder_hidden_states, hidden_states], dim=1), attention_mask, ) # Attention. attention_outputs = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, ) if len(attention_outputs) == 2: attn_output, context_attn_output = attention_outputs elif len(attention_outputs) == 3: attn_output, context_attn_output, ip_attn_output = attention_outputs # Process attention outputs for the `hidden_states`. attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = ( norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ) ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output hidden_states = hidden_states + ff_output if len(attention_outputs) == 3: hidden_states = hidden_states + ip_attn_output # Process attention outputs for the `encoder_hidden_states`. context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) norm_encoder_hidden_states = ( norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] ) context_ff_output = self.ff_context(norm_encoder_hidden_states) encoder_hidden_states = ( encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output ) if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) return encoder_hidden_states, hidden_states class LibreFluxTransformer2DModel( ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin ): """ The Transformer model introduced in Flux. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Parameters: patch_size (`int`): Patch size to turn the input data into small patches. in_channels (`int`, *optional*, defaults to 16): The number of channels in the input. num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use. num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. 
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention. joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`. guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings. """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() self.out_channels = in_channels self.inner_dim = ( self.config.num_attention_heads * self.config.attention_head_dim ) self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection) if guidance_embeds else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection) ) self.time_text_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim, ) self.context_embedder = nn.Linear( self.config.joint_attention_dim, self.inner_dim ) self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ FluxSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, ) for i in range(self.config.num_single_layers) ] ) self.norm_out = AdaLayerNormContinuous( self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 ) self.proj_out = nn.Linear( self.inner_dim, patch_size * patch_size * self.out_channels, bias=True ) self.gradient_checkpointing = False # added for users to disable checkpointing every nth step self.gradient_checkpointing_interval = None def set_gradient_checkpointing_interval(self, value: int): self.gradient_checkpointing_interval = value @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors( name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor], ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] ): r""" Sets the attention processor to use to compute attention. 
Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, controlnet_blocks_repeat: bool = False, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. 
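        Note:
            `attention_mask` (`torch.Tensor`, *optional*) is a mask over the text tokens of
            shape `(batch_size, text_seq_len)` (1 for tokens to keep, 0 for padding); it is
            expanded internally via `expand_flux_attention_mask` so that image tokens are
            always attended.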
""" if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if ( joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None ): logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None #print( self.time_text_embed) temb = ( self.time_text_embed(timestep,pooled_projections) # Edit 1 # Charlie NOT NEEDED - UNDONE if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: txt_ids = txt_ids[0] if img_ids.ndim == 3: img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) # IP adapter if ( joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs ): ip_adapter_image_embeds = joint_attention_kwargs.pop( "ip_adapter_image_embeds" ) ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) for index_block, block in enumerate(self.transformer_blocks): if ( self.training and self.gradient_checkpointing and ( self.gradient_checkpointing_interval is None or index_block % self.gradient_checkpointing_interval == 0 ) ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = ( {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ) encoder_hidden_states, hidden_states = ( torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **ckpt_kwargs, ) ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, ) # controlnet residual if controlnet_block_samples is not None: interval_control = len(self.transformer_blocks) / len( controlnet_block_samples ) interval_control = int(np.ceil(interval_control)) # For Xlabs ControlNet. if controlnet_blocks_repeat: hidden_states = ( hidden_states + controlnet_block_samples[ index_block % len(controlnet_block_samples) ] ) else: hidden_states = ( hidden_states + controlnet_block_samples[index_block // interval_control] ) # Flux places the text tokens in front of the image tokens in the # sequence. 
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if ( self.training and self.gradient_checkpointing or ( self.gradient_checkpointing_interval is not None and index_block % self.gradient_checkpointing_interval == 0 ) ): def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = ( {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, temb, image_rotary_emb, attention_mask, **ckpt_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, ) # controlnet residual if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len( controlnet_single_block_samples ) interval_control = int(np.ceil(interval_control)) hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( hidden_states[:, encoder_hidden_states.shape[1] :, ...] + controlnet_single_block_samples[index_block // interval_control] ) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) #################################### ##### CONTROL NET MODEL MERGE ###### #################################### from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed from diffusers.models.modeling_outputs import Transformer2DModelOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class FluxControlNetOutput(BaseOutput): controlnet_block_samples: Tuple[torch.Tensor] controlnet_single_block_samples: Tuple[torch.Tensor] class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, conditioning_embedding_channels: int = None, ): super().__init__() self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) # edit 19 #text_time_guidance_cls = ( # 
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings #) text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings text_time_cls = CombinedTimestepTextProjEmbeddings self.time_text_embed = text_time_cls( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) self.time_text_guidance_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ) for i in range(num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ FluxSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ) for i in range(num_single_layers) ] ) # controlnet_blocks self.controlnet_blocks = nn.ModuleList([]) for _ in range(len(self.transformer_blocks)): self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) self.controlnet_single_blocks = nn.ModuleList([]) for _ in range(len(self.single_transformer_blocks)): self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) self.union = num_mode is not None if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) if conditioning_embedding_channels is not None: self.input_hint_block = ControlNetConditioningEmbedding( conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16) ) self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim) else: self.input_hint_block = None self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self): r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. 
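        Example (sketch; `controlnet` is an instantiated `LibreFluxControlNetModel`):

            controlnet.set_attn_processor(FluxAttnProcessor2_0())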
""" count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @classmethod def from_transformer( cls, transformer, num_layers: int = 4, num_single_layers: int = 10, attention_head_dim: int = 128, num_attention_heads: int = 24, load_weights_from_transformer=True, ): config = dict(transformer.config) config["num_layers"] = num_layers config["num_single_layers"] = num_single_layers config["attention_head_dim"] = attention_head_dim config["num_attention_heads"] = num_attention_heads controlnet = cls.from_config(config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) controlnet.single_transformer_blocks.load_state_dict( transformer.single_transformer_blocks.state_dict(), strict=False ) controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) return controlnet # Edit 13 Adding attention masking to forward def forward( self, hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, controlnet_mode: torch.Tensor = None, conditioning_scale: float = 1.0, encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. controlnet_cond (`torch.Tensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. controlnet_mode (`torch.Tensor`): The mode tensor of shape `(batch_size, 1)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. 
block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) hidden_states = self.x_embedder(hidden_states) if self.input_hint_block is not None: controlnet_cond = self.input_hint_block(controlnet_cond) batch_size, channels, height_pw, width_pw = controlnet_cond.shape height = height_pw // self.config.patch_size width = width_pw // self.config.patch_size controlnet_cond = controlnet_cond.reshape( batch_size, channels, height, self.config.patch_size, width, self.config.patch_size ) controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5) controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None #print ('Guidance:', guidance) temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None # edit 19 else self.time_text_guidance_embed(timestep, guidance, pooled_projections) ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if self.union: # union mode if controlnet_mode is None: raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") # union mode emb controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) if txt_ids.ndim == 3: logger.warning( "Passing `txt_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) txt_ids = txt_ids[0] if img_ids.ndim == 3: logger.warning( "Passing `img_ids` 3d torch.Tensor is deprecated." 
"Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) block_samples = () for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, # Edit 13 **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, # Edit 13 ) block_samples = block_samples + (hidden_states,) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, temb, image_rotary_emb, attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13 **ckpt_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, # <-- 2. 
PASS MASK TO BLOCK Edit 13 ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) # controlnet block controlnet_block_samples = () for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) controlnet_single_block_samples = () for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): single_block_sample = controlnet_block(single_block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) # scaling controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples controlnet_single_block_samples = ( None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples ) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (controlnet_block_samples, controlnet_single_block_samples) return FluxControlNetOutput( controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, ) #################################### ##### ACTUAL PIPELINE STUFF ######## #################################### from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False # TODO(Chris): why won't this emit messages at the INFO level??? logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers.utils import load_image >>> from diffusers import FluxControlNetPipeline >>> from diffusers import FluxControlNetModel >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny" >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) >>> pipe = FluxControlNetPipeline.from_pretrained( ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16 ... ) >>> pipe.to("cuda") >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( ... prompt, ... control_image=control_image, ... controlnet_conditioning_scale=0.6, ... num_inference_steps=28, ... guidance_scale=3.5, ... 
).images[0] >>> image.save("flux.png") ``` """ def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): if device is None and dtype is None: return x need_dev = device is not None and str(getattr(x, "device", None)) != str(device) need_dt = dtype is not None and getattr(x, "dtype", None) != dtype return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.16, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): """ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." 
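# `calculate_shift` above linearly interpolates the flow-matching shift `mu` between
# `base_shift` (at 256 image tokens) and `max_shift` (at 4096 image tokens). A quick
# arithmetic check with the default arguments:
#
#     m = (1.16 - 0.5) / (4096 - 256)   # ~= 0.000172
#     b = 0.5 - m * 256                 # ~= 0.456
#     calculate_shift(256)              # 256 * m + b  = 0.5   (base_shift)
#     calculate_shift(4096)             # 4096 * m + b = 1.16  (max_shift)
#
# A 1024x1024 image packs to a 64x64 token grid (image_seq_len = 4096), so it uses the
# maximum shift with the default configuration.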
) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): r""" The Flux pipeline for text-to-image generation. Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: transformer ([`FluxTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: LibreFluxTransformer2DModel, controlnet: Union[ LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel], ], ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, controlnet=controlnet, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) self.default_sample_size = 64 def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer_2( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not 
torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0] #prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # ADD THIS: Get the attention mask and repeat it for each image prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype) prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) # ADD THIS: Return the attention mask return prompt_embeds, prompt_attention_mask def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, ): device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False) #prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): device = device or self._execution_device if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = 
[prompt_2] if isinstance(prompt_2, str) else prompt_2 pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, ) # ADD THIS: Initialize mask and capture it from the T5 embedder prompt_attention_mask = None prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: unscale_lora_layers(self.text_encoder_2, lora_scale) # FIX: Get batch_size and create text_ids with the correct shape batch_size = prompt_embeds.shape[0] dtype = self.transformer.dtype text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask def check_inputs( self, prompt, prompt_2, height, width, prompt_embeds=None, pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
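# Shape sketch for `encode_prompt` above (illustrative; `pipe` is assumed to be an
# instance of this pipeline with the standard CLIP-L / T5-XXL text encoders and
# max_sequence_length=512):
#
#     prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask = pipe.encode_prompt(
#         prompt="a photo of a cat", prompt_2=None, device="cuda", num_images_per_prompt=1
#     )
#     # prompt_embeds:          [1, 512, 4096]  T5 hidden states (zero-padded to 512 tokens)
#     # pooled_prompt_embeds:   [1, 768]        CLIP pooled output
#     # text_ids:               [1, 512, 3]     all zeros, batched (see FIX comment above)
#     # prompt_attention_mask:  [1, 512]        1 for real tokens, 0 for padding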
) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids # FIX: Correctly creates batched image IDs def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height // 2, width // 2, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1) latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:] latent_image_ids = latent_image_ids.reshape( batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) return latents @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape height = height // vae_scale_factor width = width // vae_scale_factor latents = latents.view(batch_size, height, width, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, ): height = 2 * (int(height) // self.vae_scale_factor) width = 2 * (int(width) // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height, width) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) return latents.to(device=device, dtype=dtype), latent_image_ids if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
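# Shape walk-through for the packing helpers defined above (a sketch; the numbers assume
# a 1024x1024 image, vae_scale_factor=16 and 16 latent channels):
#
#     latents = torch.randn(1, 16, 128, 128)          # prepare_latents' 2x-upsampled grid
#     packed = LibreFluxControlNetPipeline._pack_latents(latents, 1, 16, 128, 128)
#     # packed: [1, 4096, 64] -- every 2x2 latent patch becomes one token of 16*4 channels
#     unpacked = LibreFluxControlNetPipeline._unpack_latents(packed, 1024, 1024, 16)
#     # unpacked: [1, 16, 128, 128] -- exact inverse of _pack_latents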
) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) return latents, latent_image_ids # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): if isinstance(image, torch.Tensor): pass else: image = self.image_processor.preprocess(image, height=height, width=width) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image @property def guidance_scale(self): return self._guidance_scale @property def joint_attention_kwargs(self): return self._joint_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, control_image: PipelineImageInput = None, control_mode: Optional[Union[int, List[int]]] = None, control_image_undo_centering: bool = False, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, negative_prompt: Optional[Union[str, List[str]]] = "", negative_prompt_2: Optional[Union[str, List[str]]] = "", negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is will be used instead height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. This is set to 1024 by default for the best results. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
                More denoising steps usually lead to a higher quality image at the expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2 of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
                Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages images
                that are closely linked to the text `prompt`, usually at the expense of lower image quality.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
                `List[np.ndarray]`, `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet.
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you
                can set the corresponding scale as a list.
            control_mode (`int` or `List[int]`, *optional*, defaults to None):
                The control mode when applying ControlNet-Union.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. Examples: Returns: [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, height, width, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device dtype = self.transformer.dtype lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) # 💡 ADD THIS: Capture the attention_mask from encode_prompt ( prompt_embeds, pooled_prompt_embeds, text_ids, attention_mask, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) # ✨ FIX: Encode negative prompts for CFG do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt_2 = negative_prompt_2 or negative_prompt (negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt( prompt=negative_prompt, prompt_2=negative_prompt_2, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) # 3. 
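# Note on the classifier-free guidance wired in above: unlike the upstream Flux ControlNet
# pipeline, which relies on the distilled `guidance` embedding, this pipeline encodes a real
# negative prompt and runs the conditional and unconditional branches in one doubled batch.
# The combination applied later in the denoising loop looks like this (sketch only):
#
#     latent_model_input = torch.cat([latents] * 2)                    # [2B, seq, ch]
#     embeds = torch.cat([negative_prompt_embeds, prompt_embeds])      # unconditional first
#     ...
#     noise_uncond, noise_cond = noise_pred.chunk(2)
#     noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
#
# so per-step memory and compute roughly double whenever guidance_scale > 1.0.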
Prepare control image num_channels_latents = self.transformer.config.in_channels // 4 if type(self.controlnet) == FullyShardedDataParallel: inner_module = self.controlnet._fsdp_wrapped_module else: inner_module = self.controlnet if isinstance(inner_module, LibreFluxControlNetModel): control_image = self.prepare_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=dtype, ) if control_image_undo_centering: if not self.image_processor.do_normalize: raise ValueError( "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor" ) control_image = control_image*0.5 + 0.5 height, width = control_image.shape[-2:] #logger.warning( # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}" #) # vae encode control_image = _maybe_to(control_image, device=self.vae.device) control_image = self.vae.encode(control_image).latent_dist.sample() control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image = _maybe_to(control_image, device=device) # pack height_control_image, width_control_image = control_image.shape[2:] control_image = self._pack_latents( control_image, batch_size * num_images_per_prompt, num_channels_latents, height_control_image, width_control_image, ) # set control mode if control_mode is not None: control_mode = torch.tensor(control_mode).to(device, dtype=torch.long) control_mode = control_mode.reshape([-1, 1]) # set control mode control_mode_ = [] if isinstance(control_mode, list): for cmode in control_mode: if cmode is None: control_mode_.append(-1) else: control_mode_.append(cmode) control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long) control_mode = control_mode.reshape([-1, 1]) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, self.scheduler.config.max_image_seq_len, self.scheduler.config.base_shift, self.scheduler.config.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # 6. 
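# Standalone sketch of the timestep preparation above (values are illustrative: 28 steps,
# a 1024x1024 image, and the scheduler/helpers defined in this file):
#
#     num_inference_steps = 28
#     sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
#     image_seq_len = 4096                                      # 64x64 packed latent tokens
#     mu = calculate_shift(image_seq_len, 256, 4096, 0.5, 1.16)
#     timesteps, num_inference_steps = retrieve_timesteps(
#         scheduler, num_inference_steps, "cuda", None, sigmas, mu=mu
#     )
#     # len(timesteps) == 28; num_warmup_steps == 0 for this first-order scheduler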
Denoising loop target_device = self.transformer.device self.controlnet.to(target_device) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # FIX: BATCH INPUTS FOR CFG if do_classifier_free_guidance: latent_model_input = torch.cat([latents] * 2) current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds]) current_attention_mask = torch.cat([negative_attention_mask, attention_mask]) current_text_ids = torch.cat([negative_text_ids, text_ids]) current_img_ids = torch.cat([latent_image_ids] * 2) current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image] else: latent_model_input = latents current_prompt_embeds = prompt_embeds current_pooled_embeds = pooled_prompt_embeds current_attention_mask = attention_mask current_text_ids = text_ids current_img_ids = latent_image_ids current_control_image = control_image # FIX: Integrate with device handling target_device = self.transformer.device # Move all inputs to the target device latent_model_input = _maybe_to(latent_model_input, device=target_device) current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device) current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device) current_attention_mask = _maybe_to(current_attention_mask, device=target_device) current_text_ids = _maybe_to(current_text_ids, device=target_device) current_img_ids = _maybe_to(current_img_ids, device=target_device) if isinstance(current_control_image, torch.Tensor): current_control_image = _maybe_to(current_control_image, device=target_device) else: current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ] control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None t_model = t.expand(latent_model_input.shape[0]).to(target_device) # Model calls controlnet_block_samples, controlnet_single_block_samples = self.controlnet( hidden_states=latent_model_input, controlnet_cond=current_control_image, controlnet_mode=control_mode, conditioning_scale=controlnet_conditioning_scale, timestep=(t_model / 1000), guidance=None, pooled_projections=current_pooled_embeds, encoder_hidden_states=current_prompt_embeds, attention_mask=current_attention_mask, txt_ids=current_text_ids, img_ids=current_img_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False ) controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples] controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples] noise_pred = self.transformer( hidden_states=latent_model_input, timestep=(t_model / 1000), guidance=None, pooled_projections=current_pooled_embeds, encoder_hidden_states=current_prompt_embeds, attention_mask=current_attention_mask, controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, txt_ids=current_text_ids, img_ids=current_img_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False )[0] # FIX: Apply CFG formula if do_classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) ## Probably not needed #noise_pred = 
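# The scheduler step just below applies the flow-matching Euler update. Roughly, what
# FlowMatchEulerDiscreteScheduler.step computes internally (a sketch, not code from this
# file):
#
#     sigma, sigma_next = scheduler.sigmas[step_index], scheduler.sigmas[step_index + 1]
#     latents = latents + (sigma_next - sigma) * noise_pred
#
# i.e. the transformer predicts a velocity and the latents move by (sigma_next - sigma)
# along it at every step.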
                # noise_pred.to(latents.device)

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            latents = _maybe_to(latents, device=self.vae.device)
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
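

# End-to-end usage sketch for this pipeline. Everything here is illustrative: the helper
# name `_example_usage`, the negative prompt, and the assumption that the InstantX canny
# checkpoint loads into `LibreFluxControlNetModel` are not guaranteed by this file. The
# function is never called at import time.
def _example_usage():
    from diffusers.utils import load_image

    base_model = "black-forest-labs/FLUX.1-dev"
    controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"

    controlnet = LibreFluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
    pipe = LibreFluxControlNetPipeline.from_pretrained(
        base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
    ).to("cuda")

    control_image = load_image(
        "https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg"
    )
    image = pipe(
        prompt="A girl in city, 25 years old, cool, futuristic",
        negative_prompt="blurry, low quality",  # the real CFG branch is used since guidance_scale > 1
        control_image=control_image,
        controlnet_conditioning_scale=0.6,
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]
    image.save("libre_flux_controlnet.png")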