neuralvfx committed (verified)
Commit 87644bf · 1 Parent(s): 081b83c

Initial upload of LibreFlux ControlNet pipeline

Files changed (4):
  1. __init__.py +2 -1
  2. controlnet/net.py +227 -0
  3. pipeline.py +51 -52
  4. transformer/trans.py +510 -1
__init__.py CHANGED
@@ -2,4 +2,5 @@ from .pipeline import (
2
  LibreFluxControlNetPipeline,
3
  LibreFluxTransformer2DModel,
4
  LibreFluxControlNetModel,
5
- )
5
+ )
6
+ from .transformer.trans import *
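For reference, a minimal import sketch of the names this __init__.py now exposes; the top-level package name `libreflux` is illustrative and depends on how the repository is installed or vendored:

# Hypothetical package name -- adjust to wherever this repo sits on PYTHONPATH.
from libreflux import (
    LibreFluxControlNetPipeline,   # end-to-end ControlNet pipeline (pipeline.py)
    LibreFluxTransformer2DModel,   # Flux transformer with the merged attention processors
    LibreFluxControlNetModel,      # ControlNet branch
)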
controlnet/net.py CHANGED
@@ -1,3 +1,4 @@
1
+
2
  # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -50,6 +51,210 @@ from diffusers.models.embeddings import apply_rotary_emb
51
 
52
 
53
 
54
+ def fa3_sdpa(
55
+ q,
56
+ k,
57
+ v,
58
+ ):
59
+ # flash attention 3 sdpa drop-in replacement
60
+ q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]]
61
+ out = flash_attn_func(q, k, v)[0]
62
+ return out.permute(0, 2, 1, 3)
63
+
64
+
65
+ class FluxSingleAttnProcessor3_0:
66
+ r"""
67
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
68
+ """
69
+
70
+ def __init__(self):
71
+ if not hasattr(F, "scaled_dot_product_attention"):
72
+ raise ImportError(
73
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
74
+ )
75
+
76
+ def __call__(
77
+ self,
78
+ attn,
79
+ hidden_states: Tensor,
80
+ encoder_hidden_states: Tensor = None,
81
+ attention_mask: FloatTensor = None,
82
+ image_rotary_emb: Tensor = None,
83
+ ) -> Tensor:
84
+ input_ndim = hidden_states.ndim
85
+
86
+ if input_ndim == 4:
87
+ batch_size, channel, height, width = hidden_states.shape
88
+ hidden_states = hidden_states.view(
89
+ batch_size, channel, height * width
90
+ ).transpose(1, 2)
91
+
92
+ batch_size, _, _ = (
93
+ hidden_states.shape
94
+ if encoder_hidden_states is None
95
+ else encoder_hidden_states.shape
96
+ )
97
+
98
+ query = attn.to_q(hidden_states)
99
+ if encoder_hidden_states is None:
100
+ encoder_hidden_states = hidden_states
101
+
102
+ key = attn.to_k(encoder_hidden_states)
103
+ value = attn.to_v(encoder_hidden_states)
104
+
105
+ inner_dim = key.shape[-1]
106
+ head_dim = inner_dim // attn.heads
107
+
108
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
109
+
110
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
111
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
112
+
113
+ if attn.norm_q is not None:
114
+ query = attn.norm_q(query)
115
+ if attn.norm_k is not None:
116
+ key = attn.norm_k(key)
117
+
118
+ # Apply RoPE if needed
119
+ if image_rotary_emb is not None:
120
+ query = apply_rotary_emb(query, image_rotary_emb)
121
+ key = apply_rotary_emb(key, image_rotary_emb)
122
+
123
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
124
+ # TODO: add support for attn.scale when we move to Torch 2.1
125
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
126
+ hidden_states = fa3_sdpa(query, key, value)
127
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
128
+
129
+ hidden_states = hidden_states.transpose(1, 2).reshape(
130
+ batch_size, -1, attn.heads * head_dim
131
+ )
132
+ hidden_states = hidden_states.to(query.dtype)
133
+
134
+ if input_ndim == 4:
135
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
136
+ batch_size, channel, height, width
137
+ )
138
+
139
+ return hidden_states
140
+
141
+
142
+ class FluxAttnProcessor3_0:
143
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
144
+
145
+ def __init__(self):
146
+ if not hasattr(F, "scaled_dot_product_attention"):
147
+ raise ImportError(
148
+ "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
149
+ )
150
+
151
+ def __call__(
152
+ self,
153
+ attn,
154
+ hidden_states: FloatTensor,
155
+ encoder_hidden_states: FloatTensor = None,
156
+ attention_mask: FloatTensor = None,
157
+ image_rotary_emb: Tensor = None,
158
+ ) -> FloatTensor:
159
+ input_ndim = hidden_states.ndim
160
+ if input_ndim == 4:
161
+ batch_size, channel, height, width = hidden_states.shape
162
+ hidden_states = hidden_states.view(
163
+ batch_size, channel, height * width
164
+ ).transpose(1, 2)
165
+ context_input_ndim = encoder_hidden_states.ndim
166
+ if context_input_ndim == 4:
167
+ batch_size, channel, height, width = encoder_hidden_states.shape
168
+ encoder_hidden_states = encoder_hidden_states.view(
169
+ batch_size, channel, height * width
170
+ ).transpose(1, 2)
171
+
172
+ batch_size = encoder_hidden_states.shape[0]
173
+
174
+ # `sample` projections.
175
+ query = attn.to_q(hidden_states)
176
+ key = attn.to_k(hidden_states)
177
+ value = attn.to_v(hidden_states)
178
+
179
+ inner_dim = key.shape[-1]
180
+ head_dim = inner_dim // attn.heads
181
+
182
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
183
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
184
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
185
+
186
+ if attn.norm_q is not None:
187
+ query = attn.norm_q(query)
188
+ if attn.norm_k is not None:
189
+ key = attn.norm_k(key)
190
+
191
+ # `context` projections.
192
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
193
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
194
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
195
+
196
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
197
+ batch_size, -1, attn.heads, head_dim
198
+ ).transpose(1, 2)
199
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
200
+ batch_size, -1, attn.heads, head_dim
201
+ ).transpose(1, 2)
202
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
203
+ batch_size, -1, attn.heads, head_dim
204
+ ).transpose(1, 2)
205
+
206
+ if attn.norm_added_q is not None:
207
+ encoder_hidden_states_query_proj = attn.norm_added_q(
208
+ encoder_hidden_states_query_proj
209
+ )
210
+ if attn.norm_added_k is not None:
211
+ encoder_hidden_states_key_proj = attn.norm_added_k(
212
+ encoder_hidden_states_key_proj
213
+ )
214
+
215
+ # attention
216
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
217
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
218
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
219
+
220
+ if image_rotary_emb is not None:
221
+
222
+ query = apply_rotary_emb(query, image_rotary_emb)
223
+ key = apply_rotary_emb(key, image_rotary_emb)
224
+
225
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
226
+ hidden_states = fa3_sdpa(query, key, value)
227
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
228
+
229
+ hidden_states = hidden_states.transpose(1, 2).reshape(
230
+ batch_size, -1, attn.heads * head_dim
231
+ )
232
+ hidden_states = hidden_states.to(query.dtype)
233
+
234
+ encoder_hidden_states, hidden_states = (
235
+ hidden_states[:, : encoder_hidden_states.shape[1]],
236
+ hidden_states[:, encoder_hidden_states.shape[1] :],
237
+ )
238
+
239
+ # linear proj
240
+ hidden_states = attn.to_out[0](hidden_states)
241
+ # dropout
242
+ hidden_states = attn.to_out[1](hidden_states)
243
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
244
+
245
+ if input_ndim == 4:
246
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
247
+ batch_size, channel, height, width
248
+ )
249
+ if context_input_ndim == 4:
250
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
251
+ batch_size, channel, height, width
252
+ )
253
+
254
+ return hidden_states, encoder_hidden_states
255
+
256
+
257
+
258
  class FluxFusedSDPAProcessor:
259
  """
260
  Fused QKV processor using PyTorch's scaled_dot_product_attention.
@@ -1070,6 +1275,27 @@ class LibreFluxTransformer2DModel(
1275
 
1276
  return Transformer2DModelOutput(sample=output)
1277
 
1278
+ ###################################
1279
+ # END TRANS MERGE
1280
+ ####################################
1281
+
1282
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
1283
+ #
1284
+ # Licensed under the Apache License, Version 2.0 (the "License");
1285
+ # you may not use this file except in compliance with the License.
1286
+ # You may obtain a copy of the License at
1287
+ #
1288
+ # http://www.apache.org/licenses/LICENSE-2.0
1289
+ #
1290
+ # Unless required by applicable law or agreed to in writing, software
1291
+ # distributed under the License is distributed on an "AS IS" BASIS,
1292
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1293
+ # See the License for the specific language governing permissions and
1294
+ # limitations under the License.
1295
+ #
1296
+ # This was modified from the ControlNet repo
1297
+
1298
+
1299
  ####################################
1300
  ##### CONTROL NET MODEL MERGE ######
1301
  ####################################
@@ -1505,3 +1731,4 @@ class LibreFluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
1731
  controlnet_block_samples=controlnet_block_samples,
1732
  controlnet_single_block_samples=controlnet_single_block_samples,
1733
  )
1734
+
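The new fa3_sdpa helper above permutes (batch, heads, seq_len, head_dim) tensors into the (batch, seq_len, heads, head_dim) layout that FlashAttention expects, but the hunks shown never import flash_attn_func. A hedged sketch of the same contract with a plain SDPA fallback; the flash_attn_interface import path and the wrapper name are assumptions, not part of this commit:

import torch.nn.functional as F

try:
    # FlashAttention-3 interface (assumed source of flash_attn_func in this repo).
    from flash_attn_interface import flash_attn_func
    _HAS_FA3 = True
except ImportError:
    _HAS_FA3 = False

def fa3_sdpa_or_fallback(q, k, v):
    """q, k, v: (batch, heads, seq_len, head_dim), as produced by the processors above."""
    if _HAS_FA3:
        q, k, v = [x.permute(0, 2, 1, 3) for x in (q, k, v)]  # B H L D -> B L H D
        out = flash_attn_func(q, k, v)[0]                      # assumed to return (out, softmax_lse)
        return out.permute(0, 2, 1, 3)                         # back to B H L D
    # PyTorch SDPA already takes (batch, heads, seq_len, head_dim).
    return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)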
pipeline.py CHANGED
@@ -749,62 +749,61 @@ class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSi
749
  else:
750
  inner_module = self.controlnet
751
 
752
- if isinstance(inner_module, LibreFluxControlNetModel):
753
- control_image = self.prepare_image(
754
- image=control_image,
755
- width=width,
756
- height=height,
757
- batch_size=batch_size * num_images_per_prompt,
758
- num_images_per_prompt=num_images_per_prompt,
759
- device=device,
760
- dtype=dtype,
761
- )
762
 
763
- if control_image_undo_centering:
764
- if not self.image_processor.do_normalize:
765
- raise ValueError(
766
- "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
767
- )
768
- control_image = control_image*0.5 + 0.5
769
-
770
- height, width = control_image.shape[-2:]
771
-
772
- #logger.warning(
773
- # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}"
774
- #)
775
-
776
- # vae encode
777
- control_image = _maybe_to(control_image, device=self.vae.device)
778
- control_image = self.vae.encode(control_image).latent_dist.sample()
779
- control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
780
- control_image = _maybe_to(control_image, device=device)
781
- # pack
782
- height_control_image, width_control_image = control_image.shape[2:]
783
- control_image = self._pack_latents(
784
- control_image,
785
- batch_size * num_images_per_prompt,
786
- num_channels_latents,
787
- height_control_image,
788
- width_control_image,
789
- )
790
 
791
- # set control mode
792
- if control_mode is not None:
793
- control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
794
- control_mode = control_mode.reshape([-1, 1])
795
-
796
-
797
- # set control mode
798
- control_mode_ = []
799
- if isinstance(control_mode, list):
800
- for cmode in control_mode:
801
- if cmode is None:
802
- control_mode_.append(-1)
803
- else:
804
- control_mode_.append(cmode)
805
- control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
752
+ control_image = self.prepare_image(
753
+ image=control_image,
754
+ width=width,
755
+ height=height,
756
+ batch_size=batch_size * num_images_per_prompt,
757
+ num_images_per_prompt=num_images_per_prompt,
758
+ device=device,
759
+ dtype=dtype,
760
+ )
 
761
 
762
+ if control_image_undo_centering:
763
+ if not self.image_processor.do_normalize:
764
+ raise ValueError(
765
+ "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
766
+ )
767
+ control_image = control_image*0.5 + 0.5
768
+
769
+ height, width = control_image.shape[-2:]
770
+
771
+ #logger.warning(
772
+ # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}"
773
+ #)
774
+
775
+ # vae encode
776
+ control_image = _maybe_to(control_image, device=self.vae.device)
777
+ control_image = self.vae.encode(control_image).latent_dist.sample()
778
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
779
+ control_image = _maybe_to(control_image, device=device)
780
+ # pack
781
+ height_control_image, width_control_image = control_image.shape[2:]
782
+ control_image = self._pack_latents(
783
+ control_image,
784
+ batch_size * num_images_per_prompt,
785
+ num_channels_latents,
786
+ height_control_image,
787
+ width_control_image,
788
+ )
789
 
790
+ # set control mode
791
+ if control_mode is not None:
792
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
793
  control_mode = control_mode.reshape([-1, 1])
794
 
795
+
796
+ # set control mode
797
+ control_mode_ = []
798
+ if isinstance(control_mode, list):
799
+ for cmode in control_mode:
800
+ if cmode is None:
801
+ control_mode_.append(-1)
802
+ else:
803
+ control_mode_.append(cmode)
804
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
805
+ control_mode = control_mode.reshape([-1, 1])
806
+
807
  # 4. Prepare latent variables
808
  num_channels_latents = self.transformer.config.in_channels // 4
809
  latents, latent_image_ids = self.prepare_latents(
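A hedged call sketch for the reworked control-image path: control_image, control_mode, and control_image_undo_centering are the arguments touched by this diff, while the checkpoint path, prompt, and remaining keywords are illustrative and follow the usual diffusers Flux ControlNet conventions rather than anything shown here.

import torch
from diffusers.utils import load_image

# Assumes the pipeline class is importable as in __init__.py; the checkpoint id is hypothetical.
pipe = LibreFluxControlNetPipeline.from_pretrained(
    "path/to/libreflux-controlnet", torch_dtype=torch.bfloat16
).to("cuda")

control = load_image("canny.png")        # any conditioning image
result = pipe(
    prompt="a photo of a dancer on a rooftop",
    control_image=control,               # VAE-encoded and packed unconditionally after this change
    control_mode=[0],                    # a list is accepted; None entries are mapped to -1
    control_image_undo_centering=False,  # only valid when the image processor has do_normalize=True
    height=1024,
    width=1024,
)
image = result.images[0]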
transformer/trans.py CHANGED
@@ -1,3 +1,512 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This was modified from the ControlNet repo
16
+
17
+
18
+ import inspect
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
22
+
23
+ import numpy as np
24
+ import torch
25
+ from transformers import (
26
+ CLIPTextModel,
27
+ CLIPTokenizer,
28
+ T5EncoderModel,
29
+ T5TokenizerFast,
30
+ )
31
+
32
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
33
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
34
+ from diffusers.models.autoencoders import AutoencoderKL
35
+ ### MERGING THESE ###
36
+ # from src.models.transformer import FluxTransformer2DModel
37
+ # from src.models.controlnet_flux import FluxControlNetModel
38
+ #############
39
+
40
+ ##########################################
41
+ ########### ATTENTION MERGE ##############
42
+ ##########################################
43
+
44
+ import torch
45
+ from torch import Tensor, FloatTensor
46
+ from torch.nn import functional as F
47
+ from einops import rearrange
48
+ from diffusers.models.attention_processor import Attention
49
+ from diffusers.models.embeddings import apply_rotary_emb
50
+
51
+
52
+
53
+ def fa3_sdpa(
54
+ q,
55
+ k,
56
+ v,
57
+ ):
58
+ # flash attention 3 sdpa drop-in replacement
59
+ q, k, v = [x.permute(0, 2, 1, 3) for x in [q, k, v]]
60
+ out = flash_attn_func(q, k, v)[0]
61
+ return out.permute(0, 2, 1, 3)
62
+
63
+
64
+ class FluxSingleAttnProcessor3_0:
65
+ r"""
66
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
67
+ """
68
+
69
+ def __init__(self):
70
+ if not hasattr(F, "scaled_dot_product_attention"):
71
+ raise ImportError(
72
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
73
+ )
74
+
75
+ def __call__(
76
+ self,
77
+ attn,
78
+ hidden_states: Tensor,
79
+ encoder_hidden_states: Tensor = None,
80
+ attention_mask: FloatTensor = None,
81
+ image_rotary_emb: Tensor = None,
82
+ ) -> Tensor:
83
+ input_ndim = hidden_states.ndim
84
+
85
+ if input_ndim == 4:
86
+ batch_size, channel, height, width = hidden_states.shape
87
+ hidden_states = hidden_states.view(
88
+ batch_size, channel, height * width
89
+ ).transpose(1, 2)
90
+
91
+ batch_size, _, _ = (
92
+ hidden_states.shape
93
+ if encoder_hidden_states is None
94
+ else encoder_hidden_states.shape
95
+ )
96
+
97
+ query = attn.to_q(hidden_states)
98
+ if encoder_hidden_states is None:
99
+ encoder_hidden_states = hidden_states
100
+
101
+ key = attn.to_k(encoder_hidden_states)
102
+ value = attn.to_v(encoder_hidden_states)
103
+
104
+ inner_dim = key.shape[-1]
105
+ head_dim = inner_dim // attn.heads
106
+
107
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
108
+
109
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
110
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
111
+
112
+ if attn.norm_q is not None:
113
+ query = attn.norm_q(query)
114
+ if attn.norm_k is not None:
115
+ key = attn.norm_k(key)
116
+
117
+ # Apply RoPE if needed
118
+ if image_rotary_emb is not None:
119
+ query = apply_rotary_emb(query, image_rotary_emb)
120
+ key = apply_rotary_emb(key, image_rotary_emb)
121
+
122
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
123
+ # TODO: add support for attn.scale when we move to Torch 2.1
124
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
125
+ hidden_states = fa3_sdpa(query, key, value)
126
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
127
+
128
+ hidden_states = hidden_states.transpose(1, 2).reshape(
129
+ batch_size, -1, attn.heads * head_dim
130
+ )
131
+ hidden_states = hidden_states.to(query.dtype)
132
+
133
+ if input_ndim == 4:
134
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
135
+ batch_size, channel, height, width
136
+ )
137
+
138
+ return hidden_states
139
+
140
+
141
+ class FluxAttnProcessor3_0:
142
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
143
+
144
+ def __init__(self):
145
+ if not hasattr(F, "scaled_dot_product_attention"):
146
+ raise ImportError(
147
+ "FluxAttnProcessor3_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
148
+ )
149
+
150
+ def __call__(
151
+ self,
152
+ attn,
153
+ hidden_states: FloatTensor,
154
+ encoder_hidden_states: FloatTensor = None,
155
+ attention_mask: FloatTensor = None,
156
+ image_rotary_emb: Tensor = None,
157
+ ) -> FloatTensor:
158
+ input_ndim = hidden_states.ndim
159
+ if input_ndim == 4:
160
+ batch_size, channel, height, width = hidden_states.shape
161
+ hidden_states = hidden_states.view(
162
+ batch_size, channel, height * width
163
+ ).transpose(1, 2)
164
+ context_input_ndim = encoder_hidden_states.ndim
165
+ if context_input_ndim == 4:
166
+ batch_size, channel, height, width = encoder_hidden_states.shape
167
+ encoder_hidden_states = encoder_hidden_states.view(
168
+ batch_size, channel, height * width
169
+ ).transpose(1, 2)
170
+
171
+ batch_size = encoder_hidden_states.shape[0]
172
+
173
+ # `sample` projections.
174
+ query = attn.to_q(hidden_states)
175
+ key = attn.to_k(hidden_states)
176
+ value = attn.to_v(hidden_states)
177
+
178
+ inner_dim = key.shape[-1]
179
+ head_dim = inner_dim // attn.heads
180
+
181
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
182
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
183
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
184
+
185
+ if attn.norm_q is not None:
186
+ query = attn.norm_q(query)
187
+ if attn.norm_k is not None:
188
+ key = attn.norm_k(key)
189
+
190
+ # `context` projections.
191
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
192
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
193
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
194
+
195
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
196
+ batch_size, -1, attn.heads, head_dim
197
+ ).transpose(1, 2)
198
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
199
+ batch_size, -1, attn.heads, head_dim
200
+ ).transpose(1, 2)
201
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
202
+ batch_size, -1, attn.heads, head_dim
203
+ ).transpose(1, 2)
204
+
205
+ if attn.norm_added_q is not None:
206
+ encoder_hidden_states_query_proj = attn.norm_added_q(
207
+ encoder_hidden_states_query_proj
208
+ )
209
+ if attn.norm_added_k is not None:
210
+ encoder_hidden_states_key_proj = attn.norm_added_k(
211
+ encoder_hidden_states_key_proj
212
+ )
213
+
214
+ # attention
215
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
216
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
217
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
218
+
219
+ if image_rotary_emb is not None:
220
+
221
+ query = apply_rotary_emb(query, image_rotary_emb)
222
+ key = apply_rotary_emb(key, image_rotary_emb)
223
+
224
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
225
+ hidden_states = fa3_sdpa(query, key, value)
226
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
227
+
228
+ hidden_states = hidden_states.transpose(1, 2).reshape(
229
+ batch_size, -1, attn.heads * head_dim
230
+ )
231
+ hidden_states = hidden_states.to(query.dtype)
232
+
233
+ encoder_hidden_states, hidden_states = (
234
+ hidden_states[:, : encoder_hidden_states.shape[1]],
235
+ hidden_states[:, encoder_hidden_states.shape[1] :],
236
+ )
237
+
238
+ # linear proj
239
+ hidden_states = attn.to_out[0](hidden_states)
240
+ # dropout
241
+ hidden_states = attn.to_out[1](hidden_states)
242
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
243
+
244
+ if input_ndim == 4:
245
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
246
+ batch_size, channel, height, width
247
+ )
248
+ if context_input_ndim == 4:
249
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
250
+ batch_size, channel, height, width
251
+ )
252
+
253
+ return hidden_states, encoder_hidden_states
254
+
255
+
256
+
257
+ class FluxFusedSDPAProcessor:
258
+ """
259
+ Fused QKV processor using PyTorch's scaled_dot_product_attention.
260
+ Uses fused projections but splits for attention computation.
261
+ """
262
+
263
+ def __init__(self):
264
+ if not hasattr(F, "scaled_dot_product_attention"):
265
+ raise ImportError(
266
+ "FluxFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
267
+ )
268
+
269
+ def __call__(
270
+ self,
271
+ attn,
272
+ hidden_states: FloatTensor,
273
+ encoder_hidden_states: FloatTensor = None,
274
+ attention_mask: FloatTensor = None,
275
+ image_rotary_emb: Tensor = None,
276
+ ) -> FloatTensor:
277
+ input_ndim = hidden_states.ndim
278
+ if input_ndim == 4:
279
+ batch_size, channel, height, width = hidden_states.shape
280
+ hidden_states = hidden_states.view(
281
+ batch_size, channel, height * width
282
+ ).transpose(1, 2)
283
+
284
+ context_input_ndim = (
285
+ encoder_hidden_states.ndim if encoder_hidden_states is not None else None
286
+ )
287
+ if context_input_ndim == 4:
288
+ batch_size, channel, height, width = encoder_hidden_states.shape
289
+ encoder_hidden_states = encoder_hidden_states.view(
290
+ batch_size, channel, height * width
291
+ ).transpose(1, 2)
292
+
293
+ batch_size = (
294
+ encoder_hidden_states.shape[0]
295
+ if encoder_hidden_states is not None
296
+ else hidden_states.shape[0]
297
+ )
298
+
299
+ # Single attention case (no encoder states)
300
+ if encoder_hidden_states is None:
301
+ # Use fused QKV projection
302
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
303
+ inner_dim = qkv.shape[-1] // 3
304
+ head_dim = inner_dim // attn.heads
305
+ seq_len = hidden_states.shape[1]
306
+
307
+ # Split and reshape
308
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
309
+ query, key, value = qkv.unbind(
310
+ dim=2
311
+ ) # Each is (batch, seq_len, heads, head_dim)
312
+
313
+ # Transpose to (batch, heads, seq_len, head_dim)
314
+ query = query.transpose(1, 2)
315
+ key = key.transpose(1, 2)
316
+ value = value.transpose(1, 2)
317
+
318
+ # Apply norms if needed
319
+ if attn.norm_q is not None:
320
+ query = attn.norm_q(query)
321
+ if attn.norm_k is not None:
322
+ key = attn.norm_k(key)
323
+
324
+ # Apply RoPE if needed
325
+ if image_rotary_emb is not None:
326
+ query = apply_rotary_emb(query, image_rotary_emb)
327
+ key = apply_rotary_emb(key, image_rotary_emb)
328
+
329
+ # SDPA
330
+ hidden_states = F.scaled_dot_product_attention(
331
+ query,
332
+ key,
333
+ value,
334
+ attn_mask=attention_mask,
335
+ dropout_p=0.0,
336
+ is_causal=False,
337
+ )
338
+
339
+ # Reshape back
340
+ hidden_states = hidden_states.transpose(1, 2).reshape(
341
+ batch_size, -1, attn.heads * head_dim
342
+ )
343
+ hidden_states = hidden_states.to(query.dtype)
344
+
345
+ if input_ndim == 4:
346
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
347
+ batch_size, channel, height, width
348
+ )
349
+
350
+ return hidden_states
351
+
352
+ # Joint attention case (with encoder states)
353
+ else:
354
+ # Process self-attention QKV
355
+ qkv = attn.to_qkv(hidden_states)
356
+ inner_dim = qkv.shape[-1] // 3
357
+ head_dim = inner_dim // attn.heads
358
+ seq_len = hidden_states.shape[1]
359
+
360
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
361
+ query, key, value = qkv.unbind(dim=2)
362
+
363
+ # Transpose to (batch, heads, seq_len, head_dim)
364
+ query = query.transpose(1, 2)
365
+ key = key.transpose(1, 2)
366
+ value = value.transpose(1, 2)
367
+
368
+ # Apply norms if needed
369
+ if attn.norm_q is not None:
370
+ query = attn.norm_q(query)
371
+ if attn.norm_k is not None:
372
+ key = attn.norm_k(key)
373
+
374
+ # Process encoder QKV
375
+ encoder_seq_len = encoder_hidden_states.shape[1]
376
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
377
+ encoder_qkv = encoder_qkv.view(
378
+ batch_size, encoder_seq_len, 3, attn.heads, head_dim
379
+ )
380
+ encoder_query, encoder_key, encoder_value = encoder_qkv.unbind(dim=2)
381
+
382
+ # Transpose to (batch, heads, seq_len, head_dim)
383
+ encoder_query = encoder_query.transpose(1, 2)
384
+ encoder_key = encoder_key.transpose(1, 2)
385
+ encoder_value = encoder_value.transpose(1, 2)
386
+
387
+ # Apply encoder norms if needed
388
+ if attn.norm_added_q is not None:
389
+ encoder_query = attn.norm_added_q(encoder_query)
390
+ if attn.norm_added_k is not None:
391
+ encoder_key = attn.norm_added_k(encoder_key)
392
+
393
+ # Concatenate encoder and self-attention
394
+ query = torch.cat([encoder_query, query], dim=2)
395
+ key = torch.cat([encoder_key, key], dim=2)
396
+ value = torch.cat([encoder_value, value], dim=2)
397
+
398
+ # Apply RoPE if needed
399
+ if image_rotary_emb is not None:
400
+ query = apply_rotary_emb(query, image_rotary_emb)
401
+ key = apply_rotary_emb(key, image_rotary_emb)
402
+
403
+ # SDPA
404
+ hidden_states = F.scaled_dot_product_attention(
405
+ query,
406
+ key,
407
+ value,
408
+ attn_mask=attention_mask,
409
+ dropout_p=0.0,
410
+ is_causal=False,
411
+ )
412
+
413
+ # Reshape: (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
414
+ hidden_states = hidden_states.transpose(1, 2).reshape(
415
+ batch_size, -1, attn.heads * head_dim
416
+ )
417
+ hidden_states = hidden_states.to(query.dtype)
418
+
419
+ # Split encoder and self outputs
420
+ encoder_hidden_states = hidden_states[:, :encoder_seq_len]
421
+ hidden_states = hidden_states[:, encoder_seq_len:]
422
+
423
+ # Output projections
424
+ hidden_states = attn.to_out[0](hidden_states)
425
+ hidden_states = attn.to_out[1](hidden_states) # dropout
426
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
427
+
428
+ # Reshape if needed
429
+ if input_ndim == 4:
430
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
431
+ batch_size, channel, height, width
432
+ )
433
+ if context_input_ndim == 4:
434
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
435
+ batch_size, channel, height, width
436
+ )
437
+
438
+ return hidden_states, encoder_hidden_states
439
+
440
+
441
+ class FluxSingleFusedSDPAProcessor:
442
+ """
443
+ Fused QKV processor for single attention (no encoder states).
444
+ Simpler version for self-attention only blocks.
445
+ """
446
+
447
+ def __init__(self):
448
+ if not hasattr(F, "scaled_dot_product_attention"):
449
+ raise ImportError(
450
+ "FluxSingleFusedSDPAProcessor requires PyTorch 2.0+ for scaled_dot_product_attention"
451
+ )
452
+
453
+ def __call__(
454
+ self,
455
+ attn,
456
+ hidden_states: Tensor,
457
+ encoder_hidden_states: Tensor = None,
458
+ attention_mask: FloatTensor = None,
459
+ image_rotary_emb: Tensor = None,
460
+ ) -> Tensor:
461
+ input_ndim = hidden_states.ndim
462
+ if input_ndim == 4:
463
+ batch_size, channel, height, width = hidden_states.shape
464
+ hidden_states = hidden_states.view(
465
+ batch_size, channel, height * width
466
+ ).transpose(1, 2)
467
+
468
+ batch_size, seq_len, _ = hidden_states.shape
469
+
470
+ # Use fused QKV projection
471
+ qkv = attn.to_qkv(hidden_states) # (batch, seq_len, 3 * inner_dim)
472
+ inner_dim = qkv.shape[-1] // 3
473
+ head_dim = inner_dim // attn.heads
474
+
475
+ # Split and reshape in one go
476
+ qkv = qkv.view(batch_size, seq_len, 3, attn.heads, head_dim)
477
+ qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, L, D) – still strided
478
+ query, key, value = [
479
+ t.contiguous() for t in qkv.unbind(0) # make each view dense
480
+ ]
481
+ # Now each is (batch, heads, seq_len, head_dim)
482
+
483
+ # Apply norms if needed
484
+ if attn.norm_q is not None:
485
+ query = attn.norm_q(query)
486
+ if attn.norm_k is not None:
487
+ key = attn.norm_k(key)
488
+
489
+ # Apply RoPE if needed
490
+ if image_rotary_emb is not None:
491
+ query = apply_rotary_emb(query, image_rotary_emb)
492
+ key = apply_rotary_emb(key, image_rotary_emb)
493
+
494
+ # SDPA
495
+ hidden_states = F.scaled_dot_product_attention(
496
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
497
+ )
498
+
499
+ # Reshape back
500
+ hidden_states = rearrange(hidden_states, "B H L D -> B L (H D)")
501
+ hidden_states = hidden_states.to(query.dtype)
502
+
503
+ if input_ndim == 4:
504
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
505
+ batch_size, channel, height, width
506
+ )
507
+
508
+ return hidden_states
509
+
510
  #################################
511
  ##### TRANSFORMER MERGE #########
512
  #################################
@@ -763,4 +1272,4 @@ class LibreFluxTransformer2DModel(
1272
  if not return_dict:
1273
  return (output,)
1274
 
766
- return Transformer2DModelOutput(sample=output)
1275
+ return Transformer2DModelOutput(sample=output)
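To exercise the fused processors added in this file, one would typically swap them in through the standard diffusers attention-processor API. A hedged sketch; whether LibreFluxTransformer2DModel exposes fuse_qkv_projections and set_attn_processor exactly like the upstream FluxTransformer2DModel, and the import paths used, are assumptions:

import torch
from transformer.trans import (
    LibreFluxTransformer2DModel,    # assumed to be defined/exported by this module
    FluxFusedSDPAProcessor,
    FluxSingleFusedSDPAProcessor,
)

transformer = LibreFluxTransformer2DModel.from_pretrained(
    "path/to/libreflux", subfolder="transformer", torch_dtype=torch.bfloat16
)

# The fused processors read attn.to_qkv / attn.to_added_qkv, so the QKV
# projections must be fused before they are installed.
transformer.fuse_qkv_projections()

processors = {}
for name in transformer.attn_processors:
    # Flux single-stream blocks have no added (context) projections.
    if "single_transformer_blocks" in name:
        processors[name] = FluxSingleFusedSDPAProcessor()
    else:
        processors[name] = FluxFusedSDPAProcessor()
transformer.set_attn_processor(processors)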