LibreFlux-ControlNet / backup_pipeline.py
neuralvfx's picture
Initial upload of LibreFlux ControlNet pipeline
addfdf3 verified
raw
history blame
121 kB
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This was modied from the control net repo
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
import numpy as np
import torch
from transformers import (
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from diffusers.models.autoencoders import AutoencoderKL
### MERGEING THESE ###
# from src.models.transformer import FluxTransformer2DModel
# from src.models.controlnet_flux import FluxControlNetModel
#############
##########################################
########### ATTENTION MERGE ##############
##########################################
import torch
from torch import Tensor, FloatTensor
from torch.nn import functional as F
from einops import rearrange
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import apply_rotary_emb
#try:
# from flash_attn_interface import flash_attn_func, flash_attn_qkvpacked_func
#except:
# pass
"""def fa3_sdpa(
q,
k,
v,
):
# flash attention 3 sdpa drop-in replacement
q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]]
out = flash_attn_func(q, k, v)[0]
return out.permute(0, 2, 1, 3)"""
"""
class FluxSingleAttnProcessor3_0:
r""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn,
hidden_states: Tensor,
encoder_hidden_states: Tensor = None,
attention_mask: FloatTensor = None,
image_rotary_emb: Tensor = None,
) -> Tensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, _, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
# hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = fa3_sdpa(query, key, value)
hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states
class FluxAttnProcessor3_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn,
hidden_states: FloatTensor,
encoder_hidden_states: FloatTensor = None,
attention_mask: FloatTensor = None,
image_rotary_emb: Tensor = None,
) -> FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = fa3_sdpa(query, key, value)
hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states, encoder_hidden_states
class FluxFusedFlashAttnProcessor3(object):
"""
True fused QKV Flash Attention 3 processor for Flux models.
Keeps QKV tensors packed through the entire attention computation.
"""
def __init__(self):
self.flash_attn_qkvpacked_func = None
try:
from flash_attn_interface import flash_attn_qkvpacked_func
self.flash_attn_qkvpacked_func = flash_attn_qkvpacked_func
except ImportError:
raise ImportError(
"FluxFusedFlashAttnProcessor3 requires flash-attn library. "
"Please see this link for Hopper and Blackwell instructions: https://github.com/bghira/SimpleTuner/blob/main/INSTALL.md#nvidia-hopper--blackwell-follow-up-steps"
)
def __call__(
self,
attn,
hidden_states: FloatTensor,
encoder_hidden_states: FloatTensor = None,
attention_mask: FloatTensor = None,
image_rotary_emb: Tensor = None,
) -> FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
context_input_ndim = (
encoder_hidden_states.ndim if encoder_hidden_states is not None else None
)
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size = (
encoder_hidden_states.shape[0]
if encoder_hidden_states is not None
else hidden_states.shape[0]
)
seq_len = hidden_states.shape[1]
# Fused QKV projection
qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
inner_dim = qkv.shape[-1] // 3
head_dim = inner_dim // attn.heads
# Reshape to packed format: (batch, seq_len, 3, heads, head_dim)
qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
# Apply norms if needed (requires temporary unpacking)
if attn.norm_q is not None or attn.norm_k is not None:
q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim)
q = q.transpose(1, 2) # (batch, heads, seq_len, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn.norm_q is not None:
q = attn.norm_q(q)
if attn.norm_k is not None:
k = attn.norm_k(k)
# Repack: back to (batch, seq_len, 3, heads, head_dim)
qkv = torch.stack(
[q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2
)
# Handle encoder states if present
if encoder_hidden_states is not None:
encoder_seq_len = encoder_hidden_states.shape[1]
# Fused encoder QKV
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
encoder_qkv = encoder_qkv.view(
batch_size, encoder_seq_len, 3, attn.heads, head_dim
)
# Apply norms if needed
if attn.norm_added_q is not None or attn.norm_added_k is not None:
enc_q, enc_k, enc_v = encoder_qkv.unbind(dim=2)
enc_q = enc_q.transpose(1, 2)
enc_k = enc_k.transpose(1, 2)
enc_v = enc_v.transpose(1, 2)
if attn.norm_added_q is not None:
enc_q = attn.norm_added_q(enc_q)
if attn.norm_added_k is not None:
enc_k = attn.norm_added_k(enc_k)
encoder_qkv = torch.stack(
[
enc_q.transpose(1, 2),
enc_k.transpose(1, 2),
enc_v.transpose(1, 2),
],
dim=2,
)
# Concatenate along sequence dimension
qkv = torch.cat(
[encoder_qkv, qkv], dim=1
) # (batch, encoder_seq + seq, 3, heads, head_dim)
# Apply RoPE if needed
if image_rotary_emb is not None:
q, k, v = qkv.unbind(dim=2) # Each is (batch, seq_len, heads, head_dim)
# Transpose to (batch, heads, seq_len, head_dim) for RoPE
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Apply RoPE to q and k
q = apply_rotary_emb(q, image_rotary_emb)
k = apply_rotary_emb(k, image_rotary_emb)
# Transpose back and repack
qkv = torch.stack(
[q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)], dim=2
)
# Flash Attention 3 with packed QKV
# Input shape: (batch, seq_len, 3, heads, head_dim)
# Output shape: (batch, seq_len, heads, head_dim)
hidden_states = self.flash_attn_qkvpacked_func(
qkv,
causal=False,
# Don't pass num_heads_q for standard MHA
)
# Reshape output: (batch, seq_len, heads, head_dim) -> (batch, seq_len, heads * head_dim)
hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(qkv.dtype)
# Split and process outputs
if encoder_hidden_states is not None:
encoder_seq_len = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :encoder_seq_len]
hidden_states = hidden_states[:, encoder_seq_len:]
# Output projections
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states) # dropout
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# Reshape if needed
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states, encoder_hidden_states
else:
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states
"""
class FluxFusedSDPAProcessor:
"""
Fused QKV processor using PyTorch's scaled_dot_product_attention.
Uses fused projections but splits for attention computation.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
)
def __call__(
self,
attn,
hidden_states: FloatTensor,
encoder_hidden_states: FloatTensor = None,
attention_mask: FloatTensor = None,
image_rotary_emb: Tensor = None,
) -> FloatTensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
context_input_ndim = (
encoder_hidden_states.ndim if encoder_hidden_states is not None else None
)
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size = (
encoder_hidden_states.shape[0]
if encoder_hidden_states is not None
else hidden_states.shape[0]
)
# Single attention case (no encoder states)
if encoder_hidden_states is None:
# Use fused QKV projection
qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
inner_dim = qkv.shape[-1] // 3
head_dim = inner_dim // attn.heads
seq_len = hidden_states.shape[1]
# Split and reshape
qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
query, key, value = qkv.unbind(
dim=2
) # Each is (batch, seq_len, heads, head_dim)
# Transpose to (batch, heads, seq_len, head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Apply norms if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# SDPA
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states
# Joint attention case (with encoder states)
else:
# Process self-attention QKV
qkv = attn.to_qkv(hidden_states)
inner_dim = qkv.shape[-1] // 3
head_dim = inner_dim // attn.heads
seq_len = hidden_states.shape[1]
qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
query, key, value = qkv.unbind(dim=2)
# Transpose to (batch, heads, seq_len, head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Apply norms if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Process encoder QKV
encoder_seq_len = encoder_hidden_states.shape[1]
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
encoder_qkv = encoder_qkv.view(
batch_size, encoder_seq_len, 3, attn.heads, head_dim
)
encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2)
# Transpose to (batch, heads, seq_len, head_dim)
encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)
# Apply encoder norms if needed
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
# Concatenate encoder and self-attention
query = torch.cat([encoder_query, query], dim=2)
key = torch.cat([encoder_key, key], dim=2)
value = torch.cat([encoder_value, value], dim=2)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# SDPA
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
# Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# Split encoder and self outputs
encoder_hidden_states = hidden_states[:, :encoder_seq_len]
hidden_states = hidden_states[:, encoder_seq_len:]
# Output projections
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states) # dropout
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# Reshape if needed
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states, encoder_hidden_states
class FluxSingleFusedSDPAProcessor:
"""
Fused QKV processor for single attention (no encoder states).
Simpler version for self-attention only blocks.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
)
def __call__(
self,
attn,
hidden_states: Tensor,
encoder_hidden_states: Tensor = None,
attention_mask: FloatTensor = None,
image_rotary_emb: Tensor = None,
) -> Tensor:
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, seq_len, _ = hidden_states.shape
# Use fused QKV projection
qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
inner_dim = qkv.shape[-1] // 3
head_dim = inner_dim // attn.heads
# Split and reshape in one go
qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided
query, key, value = [
t.contiguous() for t in qkv.unbind(0) # make each view dense
]
# Now each is (batch, heads, seq_len, head_dim)
# Apply norms if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
# SDPA
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
hidden_states = hidden_states.to(query.dtype)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
return hidden_states
#################################
##### TRANSFORMER MERGE #########
#################################
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (
AdaLayerNormContinuous,
AdaLayerNormZero,
AdaLayerNormZeroSingle,
)
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_version,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
FluxPosEmbed,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
is_flash_attn_available = False
"""try:
from flash_attn_interface import flash_attn_func
is_flash_attn_available = True
except:
pass"""
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
if attention_mask is not None:
#print ('Attention Used')
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (attention_mask > 0).bool()
# Edit 17 - match attn dtype to query d-type
attention_mask = attention_mask.to(
device=hidden_states.device, dtype=query.dtype
)
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
dropout_p=0.0,
is_causal=False,
attn_mask=attention_mask,
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
return hidden_states
def expand_flux_attention_mask(
hidden_states: torch.Tensor,
attn_mask: torch.Tensor,
) -> torch.Tensor:
"""
Expand a mask so that the image is included.
"""
bsz = attn_mask.shape[0]
assert bsz == hidden_states.shape[0]
residual_seq_len = hidden_states.shape[1]
mask_seq_len = attn_mask.shape[1]
expanded_mask = torch.ones(bsz, residual_seq_len)
expanded_mask[:, :mask_seq_len] = attn_mask
return expanded_mask
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
attention_mask: Optional[torch.Tensor] = None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
if attention_mask is not None:
attention_mask = expand_flux_attention_mask(
hidden_states,
attention_mask,
)
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(
self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = FluxAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
attention_mask: Optional[torch.Tensor] = None,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, emb=temb
)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
self.norm1_context(encoder_hidden_states, emb=temb)
)
if attention_mask is not None:
attention_mask = expand_flux_attention_mask(
torch.cat([encoder_hidden_states, hidden_states], dim=1),
attention_mask,
)
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = (
norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = (
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+ c_shift_mlp[:, None]
)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = (
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
)
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class LibreFluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = (
self.config.num_attention_heads * self.config.attention_head_dim
)
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings ### 3 input forward (timestep, guidance, pooled_projection)
if guidance_embeds
else CombinedTimestepTextProjEmbeddings #### 2 input forward (timestep, pooled_projection)
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim,
pooled_projection_dim=self.config.pooled_projection_dim,
)
self.context_embedder = nn.Linear(
self.config.joint_attention_dim, self.inner_dim
)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = nn.Linear(
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
)
self.gradient_checkpointing = False
# added for users to disable checkpointing every nth step
self.gradient_checkpointing_interval = None
def set_gradient_checkpointing_interval(self, value: int):
self.gradient_checkpointing_interval = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if (
joint_attention_kwargs is not None
and joint_attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
#print( self.time_text_embed)
temb = (
self.time_text_embed(timestep,pooled_projections)
# Edit 1 # Charlie NOT NEEDED - UNDONE
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
# IP adapter
if (
joint_attention_kwargs is not None
and "ip_adapter_image_embeds" in joint_attention_kwargs
):
ip_adapter_image_embeds = joint_attention_kwargs.pop(
"ip_adapter_image_embeds"
)
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if (
self.training
and self.gradient_checkpointing
and (
self.gradient_checkpointing_interval is None
or index_block % self.gradient_checkpointing_interval == 0
)
):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
encoder_hidden_states, hidden_states = (
torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
**ckpt_kwargs,
)
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(
controlnet_block_samples
)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states
+ controlnet_block_samples[
index_block % len(controlnet_block_samples)
]
)
else:
hidden_states = (
hidden_states
+ controlnet_block_samples[index_block // interval_control]
)
# Flux places the text tokens in front of the image tokens in the
# sequence.
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if (
self.training
and self.gradient_checkpointing
or (
self.gradient_checkpointing_interval is not None
and index_block % self.gradient_checkpointing_interval == 0
)
):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(
controlnet_single_block_samples
)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
####################################
##### CONTROL NET MODEL MERGE ######
####################################
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FluxControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
controlnet_single_block_samples: Tuple[torch.Tensor]
class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
# edit 19
#text_time_guidance_cls = (
# CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
#)
text_time_guidance_cls = CombinedTimestepGuidanceTextProjEmbeddings
text_time_cls = CombinedTimestepTextProjEmbeddings
self.time_text_embed = text_time_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.time_text_guidance_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
]
)
# controlnet_blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
if conditioning_embedding_channels is not None:
self.input_hint_block = ControlNetConditioningEmbedding(
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
)
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
else:
self.input_hint_block = None
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self):
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@classmethod
def from_transformer(
cls,
transformer,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
):
config = dict(transformer.config)
config["num_layers"] = num_layers
config["num_single_layers"] = num_single_layers
config["attention_head_dim"] = attention_head_dim
config["num_attention_heads"] = num_attention_heads
controlnet = cls.from_config(config)
if load_weights_from_transformer:
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.single_transformer_blocks.load_state_dict(
transformer.single_transformer_blocks.state_dict(), strict=False
)
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
return controlnet
# Edit 13 Adding attention masking to forward
def forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None, # <-- 1. ADD ARGUMENT HERE
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
controlnet_cond = self.input_hint_block(controlnet_cond)
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
height = height_pw // self.config.patch_size
width = width_pw // self.config.patch_size
controlnet_cond = controlnet_cond.reshape(
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
)
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
#print ('Guidance:', guidance)
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
# edit 19
else self.time_text_guidance_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask, # Edit 13
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask, # Edit 13
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
attention_mask, # <-- 2. PASS MASK TO GRADIENT CHECKPOINTING # Edit 13
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask, # <-- 2. PASS MASK TO BLOCK Edit 13
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
return FluxControlNetOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
####################################
##### ACTUAL PIPELINE STUFF ########
####################################
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
# TODO(Chris): why won't this emit messages at the INFO level???
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.utils import load_image
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
>>> image = pipe(
... prompt,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
>>> image.save("flux.png")
```
"""
def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
if device is None and dtype is None:
return x
need_dev = device is not None and str(getattr(x, "device", None)) != str(device)
need_dt = dtype is not None and getattr(x, "dtype", None) != dtype
return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
r"""
The Flux pipeline for text-to-image generation.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: LibreFluxTransformer2DModel,
controlnet: Union[
LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel],
],
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0]
#prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# ADD THIS: Get the attention mask and repeat it for each image
prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# ADD THIS: Return the attention mask
return prompt_embeds, prompt_attention_mask
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False)
#prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
device = device or self._execution_device
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
# ADD THIS: Initialize mask and capture it from the T5 embedder
prompt_attention_mask = None
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
unscale_lora_layers(self.text_encoder_2, lora_scale)
# FIX: Get batch_size and create text_ids with the correct shape
batch_size = prompt_embeds.shape[0]
dtype = self.transformer.dtype
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
# FIX: Correctly creates batched image IDs
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1)
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:]
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
return latents, latent_image_ids
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
control_image_undo_centering: bool = False,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
negative_prompt: Optional[Union[str, List[str]]] = "",
negative_prompt_2: Optional[Union[str, List[str]]] = "",
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
control_mode (`int` or `List[int]`,, *optional*, defaults to None):
The control mode when applying ControlNet-Union.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
dtype = self.transformer.dtype
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
# 💡 ADD THIS: Capture the attention_mask from encode_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
attention_mask,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# ✨ FIX: Encode negative prompts for CFG
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
(negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt(
prompt=negative_prompt, prompt_2=negative_prompt_2, device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length, lora_scale=lora_scale,
)
# 3. Prepare control image
num_channels_latents = self.transformer.config.in_channels // 4
if type(self.controlnet) == FullyShardedDataParallel:
inner_module = self.controlnet._fsdp_wrapped_module
else:
inner_module = self.controlnet
if isinstance(inner_module, LibreFluxControlNetModel):
control_image = self.prepare_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
)
if control_image_undo_centering:
if not self.image_processor.do_normalize:
raise ValueError(
"`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
)
control_image = control_image*0.5 + 0.5
height, width = control_image.shape[-2:]
#logger.warning(
# f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}"
#)
# vae encode
control_image = _maybe_to(control_image, device=self.vae.device)
control_image = self.vae.encode(control_image).latent_dist.sample()
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
control_image = _maybe_to(control_image, device=device)
# pack
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
# set control mode
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
target_device = self.transformer.device
self.controlnet.to(target_device)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# FIX: BATCH INPUTS FOR CFG
if do_classifier_free_guidance:
latent_model_input = torch.cat([latents] * 2)
current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
current_attention_mask = torch.cat([negative_attention_mask, attention_mask])
current_text_ids = torch.cat([negative_text_ids, text_ids])
current_img_ids = torch.cat([latent_image_ids] * 2)
current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image]
else:
latent_model_input = latents
current_prompt_embeds = prompt_embeds
current_pooled_embeds = pooled_prompt_embeds
current_attention_mask = attention_mask
current_text_ids = text_ids
current_img_ids = latent_image_ids
current_control_image = control_image
# FIX: Integrate with device handling
target_device = self.transformer.device
# Move all inputs to the target device
latent_model_input = _maybe_to(latent_model_input, device=target_device)
current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device)
current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device)
current_attention_mask = _maybe_to(current_attention_mask, device=target_device)
current_text_ids = _maybe_to(current_text_ids, device=target_device)
current_img_ids = _maybe_to(current_img_ids, device=target_device)
if isinstance(current_control_image, torch.Tensor):
current_control_image = _maybe_to(current_control_image, device=target_device)
else:
current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ]
control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None
t_model = t.expand(latent_model_input.shape[0]).to(target_device)
# Model calls
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latent_model_input,
controlnet_cond=current_control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
timestep=(t_model / 1000),
guidance=None,
pooled_projections=current_pooled_embeds,
encoder_hidden_states=current_prompt_embeds,
attention_mask=current_attention_mask,
txt_ids=current_text_ids,
img_ids=current_img_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False
)
controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples]
controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples]
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=(t_model / 1000),
guidance=None,
pooled_projections=current_pooled_embeds,
encoder_hidden_states=current_prompt_embeds,
attention_mask=current_attention_mask,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
txt_ids=current_text_ids,
img_ids=current_img_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False
)[0]
# FIX: Apply CFG formula
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
## Probably not needed
#noise_pred = noise_pred.to(latents.device)
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
latents = _maybe_to(latents, device=self.vae.device)
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)