neuralvfx committed (verified)
Commit de32761
1 Parent(s): da98671

Update pipeline.py

Files changed (1)
  1. pipeline.py +973 -971
pipeline.py CHANGED
@@ -1,972 +1,974 @@
1
- # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- #
15
- # This was modified from the ControlNet repo
16
-
17
-
18
- import inspect
19
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
-
21
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
22
-
23
- import numpy as np
24
- import torch
25
- from transformers import (
26
- CLIPTextModel,
27
- CLIPTokenizer,
28
- T5EncoderModel,
29
- T5TokenizerFast,
30
- )
31
-
32
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
33
- from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
34
- from diffusers.models.autoencoders import AutoencoderKL
35
-
36
- from .controlnet.net import LibreFluxControlNetModel
37
- from .transformer.trans import LibreFluxTransformer2DModel
38
-
39
- ####################################
40
- ##### ACTUAL PIPELINE STUFF ########
41
- ####################################
42
-
43
-
44
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
45
- from diffusers.utils import (
46
- USE_PEFT_BACKEND,
47
- is_torch_xla_available,
48
- logging,
49
- replace_example_docstring,
50
- scale_lora_layers,
51
- unscale_lora_layers,
52
- )
53
- from diffusers.utils.torch_utils import randn_tensor
54
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
55
- from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
56
-
57
-
58
- if is_torch_xla_available():
59
- import torch_xla.core.xla_model as xm
60
-
61
- XLA_AVAILABLE = True
62
- else:
63
- XLA_AVAILABLE = False
64
-
65
- # TODO(Chris): why won't this emit messages at the INFO level???
66
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
67
-
68
- EXAMPLE_DOC_STRING = """
69
- Examples:
70
- ```py
71
- >>> import torch
72
- >>> from diffusers.utils import load_image
73
- >>> from diffusers import FluxControlNetPipeline
74
- >>> from diffusers import FluxControlNetModel
75
-
76
-  >>> base_model = "black-forest-labs/FLUX.1-dev"
-  >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
77
- >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
78
- >>> pipe = FluxControlNetPipeline.from_pretrained(
79
- ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
80
- ... )
81
- >>> pipe.to("cuda")
82
- >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
83
- >>> prompt = "A girl in city, 25 years old, cool, futuristic"
84
- >>> image = pipe(
85
- ... prompt,
86
- ... control_image=control_image,
87
- ... controlnet_conditioning_scale=0.6,
88
- ... num_inference_steps=28,
89
- ... guidance_scale=3.5,
90
- ... ).images[0]
91
- >>> image.save("flux.png")
92
- ```
93
- """
94
-
95
- def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
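-     # Move/cast `x` only when the requested device or dtype actually differs;
-     # otherwise return it unchanged to avoid redundant `.to()` copies
-     # (useful when modules live on different devices, e.g. under FSDP or CPU offload).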
96
- if device is None and dtype is None:
97
- return x
98
- need_dev = device is not None and str(getattr(x, "device", None)) != str(device)
99
- need_dt = dtype is not None and getattr(x, "dtype", None) != dtype
100
- return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x
101
-
102
-
103
- # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
104
- def calculate_shift(
105
- image_seq_len,
106
- base_seq_len: int = 256,
107
- max_seq_len: int = 4096,
108
- base_shift: float = 0.5,
109
- max_shift: float = 1.16,
110
- ):
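-     # Linearly interpolate the flow-matching shift `mu` between `base_shift` and
-     # `max_shift` as the image token sequence length grows from `base_seq_len` to `max_seq_len`.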
111
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
112
- b = base_shift - m * base_seq_len
113
- mu = image_seq_len * m + b
114
- return mu
115
-
116
-
117
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
118
- def retrieve_timesteps(
119
- scheduler,
120
- num_inference_steps: Optional[int] = None,
121
- device: Optional[Union[str, torch.device]] = None,
122
- timesteps: Optional[List[int]] = None,
123
- sigmas: Optional[List[float]] = None,
124
- **kwargs,
125
- ):
126
- """
127
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
128
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
129
-
130
- Args:
131
- scheduler (`SchedulerMixin`):
132
- The scheduler to get timesteps from.
133
- num_inference_steps (`int`):
134
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
135
- must be `None`.
136
- device (`str` or `torch.device`, *optional*):
137
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
138
- timesteps (`List[int]`, *optional*):
139
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
140
- `num_inference_steps` and `sigmas` must be `None`.
141
- sigmas (`List[float]`, *optional*):
142
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
143
- `num_inference_steps` and `timesteps` must be `None`.
144
-
145
- Returns:
146
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
147
- second element is the number of inference steps.
148
- """
149
- if timesteps is not None and sigmas is not None:
150
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
151
- if timesteps is not None:
152
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
153
- if not accepts_timesteps:
154
- raise ValueError(
155
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
156
- f" timestep schedules. Please check whether you are using the correct scheduler."
157
- )
158
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
159
- timesteps = scheduler.timesteps
160
- num_inference_steps = len(timesteps)
161
- elif sigmas is not None:
162
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
163
- if not accept_sigmas:
164
- raise ValueError(
165
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
166
- f" sigmas schedules. Please check whether you are using the correct scheduler."
167
- )
168
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
169
- timesteps = scheduler.timesteps
170
- num_inference_steps = len(timesteps)
171
- else:
172
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
173
- timesteps = scheduler.timesteps
174
- return timesteps, num_inference_steps
175
-
176
-
177
- class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
178
- r"""
179
- The Flux pipeline for text-to-image generation.
180
-
181
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
182
-
183
- Args:
184
- transformer ([`FluxTransformer2DModel`]):
185
- Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
186
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
187
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
188
- vae ([`AutoencoderKL`]):
189
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
190
- text_encoder ([`CLIPTextModel`]):
191
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
192
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
193
- text_encoder_2 ([`T5EncoderModel`]):
194
- [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
195
- the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
196
- tokenizer (`CLIPTokenizer`):
197
- Tokenizer of class
198
- [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
199
- tokenizer_2 (`T5TokenizerFast`):
200
- Second Tokenizer of class
201
- [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
202
- """
203
-
204
- model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
205
- _optional_components = []
206
- _callback_tensor_inputs = ["latents", "prompt_embeds"]
207
-
208
- def __init__(
209
- self,
210
- scheduler: FlowMatchEulerDiscreteScheduler,
211
- vae: AutoencoderKL,
212
- text_encoder: CLIPTextModel,
213
- tokenizer: CLIPTokenizer,
214
- text_encoder_2: T5EncoderModel,
215
- tokenizer_2: T5TokenizerFast,
216
- transformer: LibreFluxTransformer2DModel,
217
- controlnet: Union[
218
- LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel],
219
- ],
220
- ):
221
- super().__init__()
222
-
223
- self.register_modules(
224
- vae=vae,
225
- text_encoder=text_encoder,
226
- text_encoder_2=text_encoder_2,
227
- tokenizer=tokenizer,
228
- tokenizer_2=tokenizer_2,
229
- transformer=transformer,
230
- scheduler=scheduler,
231
- controlnet=controlnet,
232
- )
233
- self.vae_scale_factor = (
234
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
235
- )
236
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
- self.tokenizer_max_length = (
238
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
239
- )
240
- self.default_sample_size = 64
241
-
242
- def _get_t5_prompt_embeds(
243
- self,
244
- prompt: Union[str, List[str]] = None,
245
- num_images_per_prompt: int = 1,
246
- max_sequence_length: int = 512,
247
- device: Optional[torch.device] = None,
248
- dtype: Optional[torch.dtype] = None,
249
- ):
250
- device = device or self._execution_device
251
- dtype = dtype or self.text_encoder.dtype
252
-
253
- prompt = [prompt] if isinstance(prompt, str) else prompt
254
- batch_size = len(prompt)
255
-
256
- text_inputs = self.tokenizer_2(
257
- prompt,
258
- padding="max_length",
259
- max_length=max_sequence_length,
260
- truncation=True,
261
- return_length=False,
262
- return_overflowing_tokens=False,
263
- return_tensors="pt",
264
- )
265
- text_input_ids = text_inputs.input_ids
266
- untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
267
-
268
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
269
- removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
270
- logger.warning(
271
- "The following part of your input was truncated because `max_sequence_length` is set to "
272
- f" {max_sequence_length} tokens: {removed_text}"
273
- )
274
-
275
- prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0]
276
- #prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
277
-
278
- dtype = self.text_encoder_2.dtype
279
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
280
-
281
- _, seq_len, _ = prompt_embeds.shape
282
-
283
- # duplicate text embeddings for each generation per prompt
284
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
285
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
286
-
287
- # ADD THIS: Get the attention mask and repeat it for each image
288
- prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype)
289
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
290
-
291
- # ADD THIS: Return the attention mask
292
- return prompt_embeds, prompt_attention_mask
293
-
294
- def _get_clip_prompt_embeds(
295
- self,
296
- prompt: Union[str, List[str]],
297
- num_images_per_prompt: int = 1,
298
- device: Optional[torch.device] = None,
299
- ):
300
- device = device or self._execution_device
301
-
302
- prompt = [prompt] if isinstance(prompt, str) else prompt
303
- batch_size = len(prompt)
304
-
305
- text_inputs = self.tokenizer(
306
- prompt,
307
- padding="max_length",
308
- max_length=self.tokenizer_max_length,
309
- truncation=True,
310
- return_overflowing_tokens=False,
311
- return_length=False,
312
- return_tensors="pt",
313
- )
314
-
315
- text_input_ids = text_inputs.input_ids
316
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
318
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
319
- logger.warning(
320
- "The following part of your input was truncated because CLIP can only handle sequences up to"
321
- f" {self.tokenizer_max_length} tokens: {removed_text}"
322
- )
323
- prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False)
324
- #prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
325
-
326
- # Use pooled output of CLIPTextModel
327
- prompt_embeds = prompt_embeds.pooler_output
328
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
329
-
330
- # duplicate text embeddings for each generation per prompt, using mps friendly method
331
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
332
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
333
-
334
- return prompt_embeds
335
-
336
- def encode_prompt(
337
- self,
338
- prompt: Union[str, List[str]],
339
- prompt_2: Union[str, List[str]],
340
- device: Optional[torch.device] = None,
341
- num_images_per_prompt: int = 1,
342
- prompt_embeds: Optional[torch.FloatTensor] = None,
343
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
344
- max_sequence_length: int = 512,
345
- lora_scale: Optional[float] = None,
346
- ):
347
- device = device or self._execution_device
348
-
349
- if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
350
- self._lora_scale = lora_scale
351
- if self.text_encoder is not None and USE_PEFT_BACKEND:
352
- scale_lora_layers(self.text_encoder, lora_scale)
353
- if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
354
- scale_lora_layers(self.text_encoder_2, lora_scale)
355
-
356
- prompt = [prompt] if isinstance(prompt, str) else prompt
357
-
358
- if prompt_embeds is None:
359
- prompt_2 = prompt_2 or prompt
360
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
361
-
362
- pooled_prompt_embeds = self._get_clip_prompt_embeds(
363
- prompt=prompt,
364
- device=device,
365
- num_images_per_prompt=num_images_per_prompt,
366
- )
367
-
368
- # ADD THIS: Initialize mask and capture it from the T5 embedder
369
- prompt_attention_mask = None
370
- prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
371
- prompt=prompt_2,
372
- num_images_per_prompt=num_images_per_prompt,
373
- max_sequence_length=max_sequence_length,
374
- device=device,
375
- )
376
-
377
- if self.text_encoder is not None:
378
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
379
- unscale_lora_layers(self.text_encoder, lora_scale)
380
- if self.text_encoder_2 is not None:
381
- if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
382
- unscale_lora_layers(self.text_encoder_2, lora_scale)
383
-
384
- # FIX: Get batch_size and create text_ids with the correct shape
385
- batch_size = prompt_embeds.shape[0]
386
- dtype = self.transformer.dtype
387
- text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
388
-
389
- return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
390
-
391
- def check_inputs(
392
- self,
393
- prompt,
394
- prompt_2,
395
- height,
396
- width,
397
- prompt_embeds=None,
398
- pooled_prompt_embeds=None,
399
- callback_on_step_end_tensor_inputs=None,
400
- max_sequence_length=None,
401
- ):
402
- if height % 8 != 0 or width % 8 != 0:
403
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
404
-
405
- if callback_on_step_end_tensor_inputs is not None and not all(
406
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
407
- ):
408
- raise ValueError(
409
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
410
- )
411
-
412
- if prompt is not None and prompt_embeds is not None:
413
- raise ValueError(
414
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
415
- " only forward one of the two."
416
- )
417
- elif prompt_2 is not None and prompt_embeds is not None:
418
- raise ValueError(
419
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
420
- " only forward one of the two."
421
- )
422
- elif prompt is None and prompt_embeds is None:
423
- raise ValueError(
424
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
425
- )
426
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
427
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
428
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
429
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
430
-
431
- if prompt_embeds is not None and pooled_prompt_embeds is None:
432
- raise ValueError(
433
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
434
- )
435
-
436
- if max_sequence_length is not None and max_sequence_length > 512:
437
- raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
438
-
439
- @staticmethod
440
- # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
441
- # FIX: Correctly creates batched image IDs
442
- def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
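-         # Build one (0, row, col) positional id per 2x2 latent patch, batch it, and flatten to
-         # (batch, num_patches, 3) for the transformer's rotary position embeddings.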
443
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
444
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
445
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
446
-
447
- latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1)
448
-
449
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:]
450
-
451
- latent_image_ids = latent_image_ids.reshape(
452
- batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
453
- )
454
-
455
- return latent_image_ids.to(device=device, dtype=dtype)
456
-
457
- @staticmethod
458
- # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
459
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
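-         # Patchify: rearrange (B, C, H, W) latents into (B, H/2 * W/2, C * 4) tokens,
-         # grouping each 2x2 spatial patch into a single sequence element.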
460
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
461
- latents = latents.permute(0, 2, 4, 1, 3, 5)
462
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
463
-
464
- return latents
465
-
466
- @staticmethod
467
- # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
468
- def _unpack_latents(latents, height, width, vae_scale_factor):
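-         # Inverse of `_pack_latents`: fold (B, num_patches, C) tokens back into a
-         # (B, C/4, H*2, W*2) latent grid, with H and W derived from the pixel size and VAE scale factor.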
469
- batch_size, num_patches, channels = latents.shape
470
-
471
- height = height // vae_scale_factor
472
- width = width // vae_scale_factor
473
-
474
- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
475
- latents = latents.permute(0, 3, 1, 4, 2, 5)
476
-
477
- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
478
-
479
- return latents
480
-
481
- # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
482
- def prepare_latents(
483
- self,
484
- batch_size,
485
- num_channels_latents,
486
- height,
487
- width,
488
- dtype,
489
- device,
490
- generator,
491
- latents=None,
492
- ):
493
- height = 2 * (int(height) // self.vae_scale_factor)
494
- width = 2 * (int(width) // self.vae_scale_factor)
495
-
496
- shape = (batch_size, num_channels_latents, height, width)
497
-
498
- if latents is not None:
499
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
500
- return latents.to(device=device, dtype=dtype), latent_image_ids
501
-
502
- if isinstance(generator, list) and len(generator) != batch_size:
503
- raise ValueError(
504
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
505
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
506
- )
507
-
508
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
509
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
510
-
511
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
512
-
513
- return latents, latent_image_ids
514
-
515
- # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
516
- def prepare_image(
517
- self,
518
- image,
519
- width,
520
- height,
521
- batch_size,
522
- num_images_per_prompt,
523
- device,
524
- dtype,
525
- do_classifier_free_guidance=False,
526
- guess_mode=False,
527
- ):
528
- if isinstance(image, torch.Tensor):
529
- pass
530
- else:
531
- image = self.image_processor.preprocess(image, height=height, width=width)
532
-
533
- image_batch_size = image.shape[0]
534
-
535
- if image_batch_size == 1:
536
- repeat_by = batch_size
537
- else:
538
- # image batch size is the same as prompt batch size
539
- repeat_by = num_images_per_prompt
540
-
541
- image = image.repeat_interleave(repeat_by, dim=0)
542
-
543
- image = image.to(device=device, dtype=dtype)
544
-
545
- if do_classifier_free_guidance and not guess_mode:
546
- image = torch.cat([image] * 2)
547
-
548
- return image
549
-
550
- @property
551
- def guidance_scale(self):
552
- return self._guidance_scale
553
-
554
- @property
555
- def joint_attention_kwargs(self):
556
- return self._joint_attention_kwargs
557
-
558
- @property
559
- def num_timesteps(self):
560
- return self._num_timesteps
561
-
562
- @property
563
- def interrupt(self):
564
- return self._interrupt
565
-
566
- @torch.no_grad()
567
- @replace_example_docstring(EXAMPLE_DOC_STRING)
568
- def __call__(
569
- self,
570
- prompt: Union[str, List[str]] = None,
571
- prompt_2: Optional[Union[str, List[str]]] = None,
572
- height: Optional[int] = None,
573
- width: Optional[int] = None,
574
- num_inference_steps: int = 28,
575
- timesteps: List[int] = None,
576
- guidance_scale: float = 7.0,
577
- control_image: PipelineImageInput = None,
578
- control_mode: Optional[Union[int, List[int]]] = None,
579
- control_image_undo_centering: bool = False,
580
- controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
581
- num_images_per_prompt: Optional[int] = 1,
582
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
583
- latents: Optional[torch.FloatTensor] = None,
584
- prompt_embeds: Optional[torch.FloatTensor] = None,
585
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
586
- output_type: Optional[str] = "pil",
587
- return_dict: bool = True,
588
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
589
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
590
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
591
- max_sequence_length: int = 512,
592
- negative_prompt: Optional[Union[str, List[str]]] = "",
593
- negative_prompt_2: Optional[Union[str, List[str]]] = "",
594
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
595
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
596
- ):
597
- r"""
598
- Function invoked when calling the pipeline for generation.
599
-
600
- Args:
601
- prompt (`str` or `List[str]`, *optional*):
602
-                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
603
- instead.
604
- prompt_2 (`str` or `List[str]`, *optional*):
605
-                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
606
-                 will be used instead.
607
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
608
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
609
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
610
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
611
- num_inference_steps (`int`, *optional*, defaults to 50):
612
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
613
- expense of slower inference.
614
- timesteps (`List[int]`, *optional*):
615
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
616
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
617
- passed will be used. Must be in descending order.
618
- guidance_scale (`float`, *optional*, defaults to 7.0):
619
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
620
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
621
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
622
-                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
623
- usually at the expense of lower image quality.
624
- control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
625
- `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
626
- The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
627
- specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
628
- as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
629
- width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
630
- images must be passed as a list such that each element of the list can be correctly batched for input
631
- to a single ControlNet.
632
- controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
633
- The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
634
- to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
635
- the corresponding scale as a list.
636
-             control_mode (`int` or `List[int]`, *optional*, defaults to None):
637
- The control mode when applying ControlNet-Union.
638
- num_images_per_prompt (`int`, *optional*, defaults to 1):
639
- The number of images to generate per prompt.
640
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
642
- to make generation deterministic.
643
- latents (`torch.FloatTensor`, *optional*):
644
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
645
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
646
-                 tensor will be generated by sampling using the supplied random `generator`.
647
- prompt_embeds (`torch.FloatTensor`, *optional*):
648
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
649
- provided, text embeddings will be generated from `prompt` input argument.
650
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
651
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
652
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
653
- output_type (`str`, *optional*, defaults to `"pil"`):
654
-                 The output format of the generated image. Choose between
655
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
656
- return_dict (`bool`, *optional*, defaults to `True`):
657
- Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
658
- joint_attention_kwargs (`dict`, *optional*):
659
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
660
- `self.processor` in
661
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
662
- callback_on_step_end (`Callable`, *optional*):
663
-                 A function called at the end of each denoising step during inference. The function is called
664
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
665
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
666
- `callback_on_step_end_tensor_inputs`.
667
- callback_on_step_end_tensor_inputs (`List`, *optional*):
668
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
669
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
670
- `._callback_tensor_inputs` attribute of your pipeline class.
671
- max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
672
-
673
- Examples:
674
-
675
- Returns:
676
- [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
677
- is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
678
- images.
679
- """
680
-
681
- height = height or self.default_sample_size * self.vae_scale_factor
682
- width = width or self.default_sample_size * self.vae_scale_factor
683
-
684
- # 1. Check inputs. Raise error if not correct
685
- self.check_inputs(
686
- prompt,
687
- prompt_2,
688
- height,
689
- width,
690
- prompt_embeds=prompt_embeds,
691
- pooled_prompt_embeds=pooled_prompt_embeds,
692
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
693
- max_sequence_length=max_sequence_length,
694
- )
695
-
696
- self._guidance_scale = guidance_scale
697
- self._joint_attention_kwargs = joint_attention_kwargs
698
- self._interrupt = False
699
-
700
- # 2. Define call parameters
701
- if prompt is not None and isinstance(prompt, str):
702
- batch_size = 1
703
- elif prompt is not None and isinstance(prompt, list):
704
- batch_size = len(prompt)
705
- else:
706
- batch_size = prompt_embeds.shape[0]
707
-
708
- device = self._execution_device
709
- dtype = self.transformer.dtype
710
-
711
- lora_scale = (
712
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
713
- )
714
- # 💡 ADD THIS: Capture the attention_mask from encode_prompt
715
- (
716
- prompt_embeds,
717
- pooled_prompt_embeds,
718
- text_ids,
719
- attention_mask,
720
- ) = self.encode_prompt(
721
- prompt=prompt,
722
- prompt_2=prompt_2,
723
- prompt_embeds=prompt_embeds,
724
- pooled_prompt_embeds=pooled_prompt_embeds,
725
- device=device,
726
- num_images_per_prompt=num_images_per_prompt,
727
- max_sequence_length=max_sequence_length,
728
- lora_scale=lora_scale,
729
- )
730
-
731
- # ✨ FIX: Encode negative prompts for CFG
732
- do_classifier_free_guidance = guidance_scale > 1.0
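-         # True classifier-free guidance: when guidance_scale > 1, an unconditional
-         # (negative-prompt) branch is encoded and run alongside the conditional one.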
733
- if do_classifier_free_guidance:
734
- if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
735
- negative_prompt = negative_prompt or ""
736
- negative_prompt_2 = negative_prompt_2 or negative_prompt
737
- (negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt(
738
- prompt=negative_prompt, prompt_2=negative_prompt_2, device=device,
739
- num_images_per_prompt=num_images_per_prompt,
740
- max_sequence_length=max_sequence_length, lora_scale=lora_scale,
741
- )
742
-
743
-
744
- # 3. Prepare control image
745
- num_channels_latents = self.transformer.config.in_channels // 4
746
-
747
- if type(self.controlnet) == FullyShardedDataParallel:
748
- inner_module = self.controlnet._fsdp_wrapped_module
749
- else:
750
- inner_module = self.controlnet
751
-
752
- control_image = self.prepare_image(
753
- image=control_image,
754
- width=width,
755
- height=height,
756
- batch_size=batch_size * num_images_per_prompt,
757
- num_images_per_prompt=num_images_per_prompt,
758
- device=device,
759
- dtype=dtype,
760
- )
761
-
762
- if control_image_undo_centering:
763
- if not self.image_processor.do_normalize:
764
- raise ValueError(
765
- "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
766
- )
767
- control_image = control_image*0.5 + 0.5
768
-
769
- height, width = control_image.shape[-2:]
770
-
771
- #logger.warning(
772
- # f"pipeline_flux_controlnet, control_image: {control_image.min()} {control_image.max()}"
773
- #)
774
-
775
- # vae encode
776
- control_image = _maybe_to(control_image, device=self.vae.device)
777
- control_image = self.vae.encode(control_image).latent_dist.sample()
778
- control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
779
- control_image = _maybe_to(control_image, device=device)
780
- # pack
781
- height_control_image, width_control_image = control_image.shape[2:]
782
- control_image = self._pack_latents(
783
- control_image,
784
- batch_size * num_images_per_prompt,
785
- num_channels_latents,
786
- height_control_image,
787
- width_control_image,
788
- )
789
-
790
- # set control mode
791
- if control_mode is not None:
792
- control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
793
- control_mode = control_mode.reshape([-1, 1])
794
-
795
-
796
- # set control mode
797
- control_mode_ = []
798
- if isinstance(control_mode, list):
799
- for cmode in control_mode:
800
- if cmode is None:
801
- control_mode_.append(-1)
802
- else:
803
- control_mode_.append(cmode)
804
- control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
805
- control_mode = control_mode.reshape([-1, 1])
806
-
807
- # 4. Prepare latent variables
808
- num_channels_latents = self.transformer.config.in_channels // 4
809
- latents, latent_image_ids = self.prepare_latents(
810
- batch_size * num_images_per_prompt,
811
- num_channels_latents,
812
- height,
813
- width,
814
- prompt_embeds.dtype,
815
- device,
816
- generator,
817
- latents,
818
- )
819
-
820
- # 5. Prepare timesteps
821
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
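-         # Flow-matching schedule: sigmas decrease linearly from 1.0 to 1/num_inference_steps;
-         # `mu` below shifts the schedule according to the image token sequence length.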
822
- image_seq_len = latents.shape[1]
823
- mu = calculate_shift(
824
- image_seq_len,
825
- self.scheduler.config.base_image_seq_len,
826
- self.scheduler.config.max_image_seq_len,
827
- self.scheduler.config.base_shift,
828
- self.scheduler.config.max_shift,
829
- )
830
- timesteps, num_inference_steps = retrieve_timesteps(
831
- self.scheduler,
832
- num_inference_steps,
833
- device,
834
- timesteps,
835
- sigmas,
836
- mu=mu,
837
- )
838
-
839
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
840
- self._num_timesteps = len(timesteps)
841
-
842
- # 6. Denoising loop
843
- target_device = self.transformer.device
844
- self.controlnet.to(target_device)
845
- with self.progress_bar(total=num_inference_steps) as progress_bar:
846
- for i, t in enumerate(timesteps):
847
- if self.interrupt:
848
- continue
849
-
850
-
851
- # FIX: BATCH INPUTS FOR CFG
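-                 # Negative (unconditional) inputs are concatenated in front of the conditional
-                 # ones so both CFG branches share a single forward pass.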
852
- if do_classifier_free_guidance:
853
- latent_model_input = torch.cat([latents] * 2)
854
- current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
855
- current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
856
- current_attention_mask = torch.cat([negative_attention_mask, attention_mask])
857
- current_text_ids = torch.cat([negative_text_ids, text_ids])
858
- current_img_ids = torch.cat([latent_image_ids] * 2)
859
- current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image]
860
- else:
861
- latent_model_input = latents
862
- current_prompt_embeds = prompt_embeds
863
- current_pooled_embeds = pooled_prompt_embeds
864
- current_attention_mask = attention_mask
865
- current_text_ids = text_ids
866
- current_img_ids = latent_image_ids
867
- current_control_image = control_image
868
-
869
- # FIX: Integrate with device handling
870
- target_device = self.transformer.device
871
-
872
- # Move all inputs to the target device
873
- latent_model_input = _maybe_to(latent_model_input, device=target_device)
874
- current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device)
875
- current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device)
876
- current_attention_mask = _maybe_to(current_attention_mask, device=target_device)
877
- current_text_ids = _maybe_to(current_text_ids, device=target_device)
878
- current_img_ids = _maybe_to(current_img_ids, device=target_device)
879
- if isinstance(current_control_image, torch.Tensor):
880
- current_control_image = _maybe_to(current_control_image, device=target_device)
881
- else:
882
- current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ]
883
- control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None
884
-
885
- t_model = t.expand(latent_model_input.shape[0]).to(target_device)
886
-
887
-
888
- # Model calls
889
- controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
890
- hidden_states=latent_model_input,
891
- controlnet_cond=current_control_image,
892
- controlnet_mode=control_mode,
893
- conditioning_scale=controlnet_conditioning_scale,
894
- timestep=(t_model / 1000),
895
- guidance=None,
896
- pooled_projections=current_pooled_embeds,
897
- encoder_hidden_states=current_prompt_embeds,
898
- attention_mask=current_attention_mask,
899
- txt_ids=current_text_ids,
900
- img_ids=current_img_ids,
901
- joint_attention_kwargs=self.joint_attention_kwargs,
902
- return_dict=False
903
- )
904
-
905
- controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples]
906
- controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples]
907
-
908
- noise_pred = self.transformer(
909
- hidden_states=latent_model_input,
910
- timestep=(t_model / 1000),
911
- guidance=None,
912
- pooled_projections=current_pooled_embeds,
913
- encoder_hidden_states=current_prompt_embeds,
914
- attention_mask=current_attention_mask,
915
- controlnet_block_samples=controlnet_block_samples,
916
- controlnet_single_block_samples=controlnet_single_block_samples,
917
- txt_ids=current_text_ids,
918
- img_ids=current_img_ids,
919
- joint_attention_kwargs=self.joint_attention_kwargs,
920
- return_dict=False
921
- )[0]
922
-
923
- # FIX: Apply CFG formula
924
- if do_classifier_free_guidance:
925
- noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
926
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
927
-
928
- ## Probably not needed
929
- #noise_pred = noise_pred.to(latents.device)
930
-
931
- latents_dtype = latents.dtype
932
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
933
-
934
- if latents.dtype != latents_dtype:
935
- if torch.backends.mps.is_available():
936
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
937
- latents = latents.to(latents_dtype)
938
-
939
- if callback_on_step_end is not None:
940
- callback_kwargs = {}
941
- for k in callback_on_step_end_tensor_inputs:
942
- callback_kwargs[k] = locals()[k]
943
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
944
-
945
- latents = callback_outputs.pop("latents", latents)
946
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
947
-
948
- # call the callback, if provided
949
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
950
- progress_bar.update()
951
-
952
- if XLA_AVAILABLE:
953
- xm.mark_step()
954
-
955
- if output_type == "latent":
956
- image = latents
957
-
958
- else:
959
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
960
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
961
-
962
- latents = _maybe_to(latents, device=self.vae.device)
963
- image = self.vae.decode(latents, return_dict=False)[0]
964
- image = self.image_processor.postprocess(image, output_type=output_type)
965
-
966
- # Offload all models
967
- self.maybe_free_model_hooks()
968
-
969
- if not return_dict:
970
- return (image,)
971
-
 
 
972
  return FluxPipelineOutput(images=image)
 
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # This was modified from the ControlNet repo
16
+
17
+
18
+ import inspect
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
22
+
23
+ import numpy as np
24
+ import torch
25
+ from transformers import (
26
+ CLIPTextModel,
27
+ CLIPTokenizer,
28
+ T5EncoderModel,
29
+ T5TokenizerFast,
30
+ )
31
+
32
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
33
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
34
+ from diffusers.models.autoencoders import AutoencoderKL
35
+
36
+ from .controlnet.net import LibreFluxControlNetModel
37
+ from .transformer.trans import LibreFluxTransformer2DModel
38
+
39
+ ####################################
40
+ ##### ACTUAL PIPELINE STUFF ########
41
+ ####################################
42
+
43
+
44
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
45
+ from diffusers.utils import (
46
+ USE_PEFT_BACKEND,
47
+ is_torch_xla_available,
48
+ logging,
49
+ replace_example_docstring,
50
+ scale_lora_layers,
51
+ unscale_lora_layers,
52
+ )
53
+ from diffusers.utils.torch_utils import randn_tensor
54
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
55
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
56
+
57
+
58
+ if is_torch_xla_available():
59
+ import torch_xla.core.xla_model as xm
60
+
61
+ XLA_AVAILABLE = True
62
+ else:
63
+ XLA_AVAILABLE = False
64
+
65
+ # TODO(Chris): why won't this emit messages at the INFO level???
66
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
67
+
68
+ EXAMPLE_DOC_STRING = """
69
+ Examples:
70
+ ```py
71
+ >>> import torch
72
+ >>> from diffusers.utils import load_image
73
+ >>> from diffusers import FluxControlNetPipeline
74
+ >>> from diffusers import FluxControlNetModel
75
+
76
+  >>> base_model = "black-forest-labs/FLUX.1-dev"
+  >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
77
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
78
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
79
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
80
+ ... )
81
+ >>> pipe.to("cuda")
82
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
83
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
84
+ >>> image = pipe(
85
+ ... prompt,
86
+ ... control_image=control_image,
87
+ ... controlnet_conditioning_scale=0.6,
88
+ ... num_inference_steps=28,
89
+ ... guidance_scale=3.5,
90
+ ... ).images[0]
91
+ >>> image.save("flux.png")
92
+ ```
93
+ """
94
+
95
+ def _maybe_to(x: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
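+     # Move/cast `x` only when the requested device or dtype actually differs;
+     # otherwise return it unchanged to avoid redundant `.to()` copies
+     # (useful when modules live on different devices, e.g. under FSDP or CPU offload).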
96
+ if device is None and dtype is None:
97
+ return x
98
+ need_dev = device is not None and str(getattr(x, "device", None)) != str(device)
99
+ need_dt = dtype is not None and getattr(x, "dtype", None) != dtype
100
+ return x.to(device=device if need_dev else x.device, dtype=dtype if need_dt else x.dtype) if (need_dev or need_dt) else x
101
+
102
+
103
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
104
+ def calculate_shift(
105
+ image_seq_len,
106
+ base_seq_len: int = 256,
107
+ max_seq_len: int = 4096,
108
+ base_shift: float = 0.5,
109
+ max_shift: float = 1.16,
110
+ ):
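+     # Linearly interpolate the flow-matching shift `mu` between `base_shift` and
+     # `max_shift` as the image token sequence length grows from `base_seq_len` to `max_seq_len`.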
111
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
112
+ b = base_shift - m * base_seq_len
113
+ mu = image_seq_len * m + b
114
+ return mu
115
+
116
+
117
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
118
+ def retrieve_timesteps(
119
+ scheduler,
120
+ num_inference_steps: Optional[int] = None,
121
+ device: Optional[Union[str, torch.device]] = None,
122
+ timesteps: Optional[List[int]] = None,
123
+ sigmas: Optional[List[float]] = None,
124
+ **kwargs,
125
+ ):
126
+ """
127
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
128
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
129
+
130
+ Args:
131
+ scheduler (`SchedulerMixin`):
132
+ The scheduler to get timesteps from.
133
+ num_inference_steps (`int`):
134
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
135
+ must be `None`.
136
+ device (`str` or `torch.device`, *optional*):
137
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
138
+ timesteps (`List[int]`, *optional*):
139
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
140
+ `num_inference_steps` and `sigmas` must be `None`.
141
+ sigmas (`List[float]`, *optional*):
142
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
143
+ `num_inference_steps` and `timesteps` must be `None`.
144
+
145
+ Returns:
146
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
147
+ second element is the number of inference steps.
148
+ """
149
+ if timesteps is not None and sigmas is not None:
150
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
151
+ if timesteps is not None:
152
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
153
+ if not accepts_timesteps:
154
+ raise ValueError(
155
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
156
+ f" timestep schedules. Please check whether you are using the correct scheduler."
157
+ )
158
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
159
+ timesteps = scheduler.timesteps
160
+ num_inference_steps = len(timesteps)
161
+ elif sigmas is not None:
162
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
163
+ if not accept_sigmas:
164
+ raise ValueError(
165
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
166
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
167
+ )
168
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
169
+ timesteps = scheduler.timesteps
170
+ num_inference_steps = len(timesteps)
171
+ else:
172
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
173
+ timesteps = scheduler.timesteps
174
+ return timesteps, num_inference_steps
175
+
176
+
177
+ class LibreFluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
178
+ r"""
179
+ The Flux pipeline for text-to-image generation.
180
+
181
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
182
+
183
+ Args:
184
+ transformer ([`FluxTransformer2DModel`]):
185
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
186
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
187
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
188
+ vae ([`AutoencoderKL`]):
189
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
190
+ text_encoder ([`CLIPTextModel`]):
191
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
192
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
193
+ text_encoder_2 ([`T5EncoderModel`]):
194
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
195
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
196
+ tokenizer (`CLIPTokenizer`):
197
+ Tokenizer of class
198
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
199
+ tokenizer_2 (`T5TokenizerFast`):
200
+ Second Tokenizer of class
201
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
202
+ """
203
+
204
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
205
+ _optional_components = []
206
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
207
+
208
+ def __init__(
209
+ self,
210
+ scheduler: FlowMatchEulerDiscreteScheduler,
211
+ vae: AutoencoderKL,
212
+ text_encoder: CLIPTextModel,
213
+ tokenizer: CLIPTokenizer,
214
+ text_encoder_2: T5EncoderModel,
215
+ tokenizer_2: T5TokenizerFast,
216
+ transformer: LibreFluxTransformer2DModel,
217
+ controlnet: Union[
218
+ LibreFluxControlNetModel, List[LibreFluxControlNetModel], Tuple[LibreFluxControlNetModel],
219
+ ],
220
+ ):
221
+ super().__init__()
222
+
223
+ self.register_modules(
224
+ vae=vae,
225
+ text_encoder=text_encoder,
226
+ text_encoder_2=text_encoder_2,
227
+ tokenizer=tokenizer,
228
+ tokenizer_2=tokenizer_2,
229
+ transformer=transformer,
230
+ scheduler=scheduler,
231
+ controlnet=controlnet,
232
+ )
233
+ self.vae_scale_factor = (
234
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
235
+ )
236
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+ self.tokenizer_max_length = (
238
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
239
+ )
240
+ self.default_sample_size = 64
241
+
242
+ def _get_t5_prompt_embeds(
243
+ self,
244
+ prompt: Union[str, List[str]] = None,
245
+ num_images_per_prompt: int = 1,
246
+ max_sequence_length: int = 512,
247
+ device: Optional[torch.device] = None,
248
+ dtype: Optional[torch.dtype] = None,
249
+ ):
250
+ device = device or self._execution_device
251
+ dtype = dtype or self.text_encoder.dtype
252
+
253
+ prompt = [prompt] if isinstance(prompt, str) else prompt
254
+ batch_size = len(prompt)
255
+
256
+ text_inputs = self.tokenizer_2(
257
+ prompt,
258
+ padding="max_length",
259
+ max_length=max_sequence_length,
260
+ truncation=True,
261
+ return_length=False,
262
+ return_overflowing_tokens=False,
263
+ return_tensors="pt",
264
+ )
265
+ text_input_ids = text_inputs.input_ids
266
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
267
+
268
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
269
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
270
+ logger.warning(
271
+ "The following part of your input was truncated because `max_sequence_length` is set to "
272
+ f" {max_sequence_length} tokens: {removed_text}"
273
+ )
274
+
275
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(self.text_encoder_2.device), output_hidden_states=False)[0]
276
+ #prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
277
+
278
+ dtype = self.text_encoder_2.dtype
279
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
280
+
281
+ _, seq_len, _ = prompt_embeds.shape
282
+
283
+ # duplicate text embeddings for each generation per prompt
284
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
285
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
286
+
287
+ # ADD THIS: Get the attention mask and repeat it for each image
288
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device, dtype=dtype)
289
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
290
+
291
+ # ADD THIS: Return the attention mask
292
+ return prompt_embeds, prompt_attention_mask
293
+
294
+ def _get_clip_prompt_embeds(
295
+ self,
296
+ prompt: Union[str, List[str]],
297
+ num_images_per_prompt: int = 1,
298
+ device: Optional[torch.device] = None,
299
+ ):
300
+ device = device or self._execution_device
301
+
302
+ prompt = [prompt] if isinstance(prompt, str) else prompt
303
+ batch_size = len(prompt)
304
+
305
+ text_inputs = self.tokenizer(
306
+ prompt,
307
+ padding="max_length",
308
+ max_length=self.tokenizer_max_length,
309
+ truncation=True,
310
+ return_overflowing_tokens=False,
311
+ return_length=False,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
318
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
319
+ logger.warning(
320
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
321
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
322
+ )
323
+ prompt_embeds = self.text_encoder(text_input_ids.to(self.text_encoder.device), output_hidden_states=False)
324
+ #prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
325
+
326
+ # Use pooled output of CLIPTextModel
327
+ prompt_embeds = prompt_embeds.pooler_output
328
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
329
+
330
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
331
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
332
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
333
+
334
+ return prompt_embeds
335
+
336
+ def encode_prompt(
337
+ self,
338
+ prompt: Union[str, List[str]],
339
+ prompt_2: Union[str, List[str]],
340
+ device: Optional[torch.device] = None,
341
+ num_images_per_prompt: int = 1,
342
+ prompt_embeds: Optional[torch.FloatTensor] = None,
343
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
344
+ max_sequence_length: int = 512,
345
+ lora_scale: Optional[float] = None,
346
+ ):
347
+ device = device or self._execution_device
348
+
349
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
350
+ self._lora_scale = lora_scale
351
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
352
+ scale_lora_layers(self.text_encoder, lora_scale)
353
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
354
+ scale_lora_layers(self.text_encoder_2, lora_scale)
355
+
356
+ prompt = [prompt] if isinstance(prompt, str) else prompt
357
+
358
+ if prompt_embeds is None:
359
+ prompt_2 = prompt_2 or prompt
360
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
361
+
362
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
363
+ prompt=prompt,
364
+ device=device,
365
+ num_images_per_prompt=num_images_per_prompt,
366
+ )
367
+
368
+ # Capture the attention mask returned by the T5 embedder
369
+ prompt_attention_mask = None
370
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
371
+ prompt=prompt_2,
372
+ num_images_per_prompt=num_images_per_prompt,
373
+ max_sequence_length=max_sequence_length,
374
+ device=device,
375
+ )
376
+
377
+ if self.text_encoder is not None:
378
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
379
+ unscale_lora_layers(self.text_encoder, lora_scale)
380
+ if self.text_encoder_2 is not None:
381
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
382
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
383
+
384
+ # Build batched text position ids matching the prompt embeddings
385
+ batch_size = prompt_embeds.shape[0]
386
+ dtype = self.transformer.dtype
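+ # text_ids are all-zero (batch, seq_len, 3) position ids: Flux's rotary embedding gives every
+ # text token position (0, 0, 0), while image tokens carry their (row, col) grid positions.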
387
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
388
+
389
+ return prompt_embeds, pooled_prompt_embeds, text_ids, prompt_attention_mask
390
+
391
+ def check_inputs(
392
+ self,
393
+ prompt,
394
+ prompt_2,
395
+ height,
396
+ width,
397
+ prompt_embeds=None,
398
+ pooled_prompt_embeds=None,
399
+ callback_on_step_end_tensor_inputs=None,
400
+ max_sequence_length=None,
401
+ ):
402
+ if height % 8 != 0 or width % 8 != 0:
403
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
404
+
405
+ if callback_on_step_end_tensor_inputs is not None and not all(
406
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
407
+ ):
408
+ raise ValueError(
409
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
410
+ )
411
+
412
+ if prompt is not None and prompt_embeds is not None:
413
+ raise ValueError(
414
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
415
+ " only forward one of the two."
416
+ )
417
+ elif prompt_2 is not None and prompt_embeds is not None:
418
+ raise ValueError(
419
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
420
+ " only forward one of the two."
421
+ )
422
+ elif prompt is None and prompt_embeds is None:
423
+ raise ValueError(
424
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
425
+ )
426
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
427
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
428
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
429
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
430
+
431
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
432
+ raise ValueError(
433
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
434
+ )
435
+
436
+ if max_sequence_length is not None and max_sequence_length > 512:
437
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
438
+
439
+ @staticmethod
440
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
441
+ # Creates batched per-token position ids for the packed latent grid
442
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
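+ # Builds (batch, (height//2)*(width//2), 3) position ids for the packed 2x2 latent patches:
+ # channel 0 stays 0, channel 1 holds the patch row index, channel 2 the patch column index.
+ # Example (assuming the standard Flux VAE): a 1024x1024 image -> 128x128 latent -> 64*64 = 4096 ids.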
443
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
444
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
445
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
446
+
447
+ latent_image_ids = latent_image_ids.unsqueeze(0).repeat(batch_size, 1, 1, 1)
448
+
449
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape[1:]
450
+
451
+ latent_image_ids = latent_image_ids.reshape(
452
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
453
+ )
454
+
455
+ return latent_image_ids.to(device=device, dtype=dtype)
456
+
457
+ @staticmethod
458
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
459
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
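+ # Rearranges (B, C, H, W) latents into a sequence of 2x2 patches with shape (B, (H/2)*(W/2), C*4);
+ # e.g. 16 latent channels give 64 features per packed token, the transformer's `in_channels`.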
460
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
461
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
462
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
463
+
464
+ return latents
465
+
466
+ @staticmethod
467
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
468
+ def _unpack_latents(latents, height, width, vae_scale_factor):
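+ # Inverse of `_pack_latents`: rebuilds (B, C, H, W) latents from the packed token sequence.
+ # `height`/`width` are the target image size in pixels and are reduced by `vae_scale_factor` below.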
469
+ batch_size, num_patches, channels = latents.shape
470
+
471
+ height = height // vae_scale_factor
472
+ width = width // vae_scale_factor
473
+
474
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
475
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
476
+
477
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
478
+
479
+ return latents
480
+
481
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
482
+ def prepare_latents(
483
+ self,
484
+ batch_size,
485
+ num_channels_latents,
486
+ height,
487
+ width,
488
+ dtype,
489
+ device,
490
+ generator,
491
+ latents=None,
492
+ ):
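+ # `vae_scale_factor` (16 for the standard Flux VAE) folds in the 2x2 packing factor; multiplying
+ # by 2 recovers the actual VAE latent resolution (pixels / 8), e.g. 1024 -> 128.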
493
+ height = 2 * (int(height) // self.vae_scale_factor)
494
+ width = 2 * (int(width) // self.vae_scale_factor)
495
+
496
+ shape = (batch_size, num_channels_latents, height, width)
497
+
498
+ if latents is not None:
499
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
500
+ return latents.to(device=device, dtype=dtype), latent_image_ids
501
+
502
+ if isinstance(generator, list) and len(generator) != batch_size:
503
+ raise ValueError(
504
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
505
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
506
+ )
507
+
508
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
509
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
510
+
511
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
512
+
513
+ return latents, latent_image_ids
514
+
515
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
516
+ def prepare_image(
517
+ self,
518
+ image,
519
+ width,
520
+ height,
521
+ batch_size,
522
+ num_images_per_prompt,
523
+ device,
524
+ dtype,
525
+ do_classifier_free_guidance=False,
526
+ guess_mode=False,
527
+ ):
528
+ if isinstance(image, torch.Tensor):
529
+ pass
530
+ else:
531
+ image = self.image_processor.preprocess(image, height=height, width=width)
532
+
533
+ image_batch_size = image.shape[0]
534
+
535
+ if image_batch_size == 1:
536
+ repeat_by = batch_size
537
+ else:
538
+ # image batch size is the same as prompt batch size
539
+ repeat_by = num_images_per_prompt
540
+
541
+ image = image.repeat_interleave(repeat_by, dim=0)
542
+
543
+ image = image.to(device=device, dtype=dtype)
544
+
545
+ if do_classifier_free_guidance and not guess_mode:
546
+ image = torch.cat([image] * 2)
547
+
548
+ return image
549
+
550
+ @property
551
+ def guidance_scale(self):
552
+ return self._guidance_scale
553
+
554
+ @property
555
+ def joint_attention_kwargs(self):
556
+ return self._joint_attention_kwargs
557
+
558
+ @property
559
+ def num_timesteps(self):
560
+ return self._num_timesteps
561
+
562
+ @property
563
+ def interrupt(self):
564
+ return self._interrupt
565
+
566
+ @torch.no_grad()
567
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
568
+ def __call__(
569
+ self,
570
+ prompt: Union[str, List[str]] = None,
571
+ prompt_2: Optional[Union[str, List[str]]] = None,
572
+ height: Optional[int] = None,
573
+ width: Optional[int] = None,
574
+ num_inference_steps: int = 28,
575
+ timesteps: List[int] = None,
576
+ guidance_scale: float = 7.0,
577
+ control_image: PipelineImageInput = None,
578
+ control_mode: Optional[Union[int, List[int]]] = None,
579
+ control_image_undo_centering: bool = False,
580
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
581
+ num_images_per_prompt: Optional[int] = 1,
582
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
583
+ latents: Optional[torch.FloatTensor] = None,
584
+ prompt_embeds: Optional[torch.FloatTensor] = None,
585
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
586
+ output_type: Optional[str] = "pil",
587
+ return_dict: bool = True,
588
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
589
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
590
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
591
+ max_sequence_length: int = 512,
592
+ negative_prompt: Optional[Union[str, List[str]]] = "",
593
+ negative_prompt_2: Optional[Union[str, List[str]]] = "",
594
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
595
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
596
+ ):
597
+ r"""
598
+ Function invoked when calling the pipeline for generation.
599
+
600
+ Args:
601
+ prompt (`str` or `List[str]`, *optional*):
602
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
603
+ instead.
604
+ prompt_2 (`str` or `List[str]`, *optional*):
605
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
606
+ will be used instead.
607
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
608
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
609
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
610
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
611
+ num_inference_steps (`int`, *optional*, defaults to 28):
612
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
613
+ expense of slower inference.
614
+ timesteps (`List[int]`, *optional*):
615
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
616
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
617
+ passed will be used. Must be in descending order.
618
+ guidance_scale (`float`, *optional*, defaults to 7.0):
619
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
620
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
621
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
622
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
623
+ usually at the expense of lower image quality.
624
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
625
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
626
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
627
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
628
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
629
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
630
+ images must be passed as a list such that each element of the list can be correctly batched for input
631
+ to a single ControlNet.
632
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
633
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
634
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
635
+ the corresponding scale as a list.
636
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
637
+ The control mode when applying ControlNet-Union.
638
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
639
+ The number of images to generate per prompt.
640
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
641
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
642
+ to make generation deterministic.
643
+ latents (`torch.FloatTensor`, *optional*):
644
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
645
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
646
+ tensor will be generated by sampling using the supplied random `generator`.
647
+ prompt_embeds (`torch.FloatTensor`, *optional*):
648
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
649
+ provided, text embeddings will be generated from `prompt` input argument.
650
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
651
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
652
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
653
+ output_type (`str`, *optional*, defaults to `"pil"`):
654
+ The output format of the generated image. Choose between
655
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
656
+ return_dict (`bool`, *optional*, defaults to `True`):
657
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
658
+ joint_attention_kwargs (`dict`, *optional*):
659
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
660
+ `self.processor` in
661
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
662
+ callback_on_step_end (`Callable`, *optional*):
663
+ A function that is called at the end of each denoising step during inference. The function is called
664
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
665
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
666
+ `callback_on_step_end_tensor_inputs`.
667
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
668
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
669
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
670
+ `._callback_tensor_inputs` attribute of your pipeline class.
671
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
672
+
673
+ Examples:
674
+
675
+ Returns:
676
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
677
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
678
+ images.
679
+ """
680
+
681
+ height = height or self.default_sample_size * self.vae_scale_factor
682
+ width = width or self.default_sample_size * self.vae_scale_factor
683
+
684
+ # 1. Check inputs. Raise error if not correct
685
+ self.check_inputs(
686
+ prompt,
687
+ prompt_2,
688
+ height,
689
+ width,
690
+ prompt_embeds=prompt_embeds,
691
+ pooled_prompt_embeds=pooled_prompt_embeds,
692
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
693
+ max_sequence_length=max_sequence_length,
694
+ )
695
+
696
+ self._guidance_scale = guidance_scale
697
+ self._joint_attention_kwargs = joint_attention_kwargs
698
+ self._interrupt = False
699
+
700
+ # 2. Define call parameters
701
+ if prompt is not None and isinstance(prompt, str):
702
+ batch_size = 1
703
+ elif prompt is not None and isinstance(prompt, list):
704
+ batch_size = len(prompt)
705
+ else:
706
+ batch_size = prompt_embeds.shape[0]
707
+
708
+ device = self._execution_device
709
+ dtype = self.transformer.dtype
710
+
711
+ lora_scale = (
712
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
713
+ )
714
+ # Capture the attention_mask returned by encode_prompt
715
+ (
716
+ prompt_embeds,
717
+ pooled_prompt_embeds,
718
+ text_ids,
719
+ attention_mask,
720
+ ) = self.encode_prompt(
721
+ prompt=prompt,
722
+ prompt_2=prompt_2,
723
+ prompt_embeds=prompt_embeds,
724
+ pooled_prompt_embeds=pooled_prompt_embeds,
725
+ device=device,
726
+ num_images_per_prompt=num_images_per_prompt,
727
+ max_sequence_length=max_sequence_length,
728
+ lora_scale=lora_scale,
729
+ )
730
+
731
+ # Encode negative prompts for classifier-free guidance
732
+ do_classifier_free_guidance = guidance_scale > 1.0
733
+ if do_classifier_free_guidance:
734
+ if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
735
+ negative_prompt = negative_prompt or ""
736
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
737
+ (negative_prompt_embeds, negative_pooled_prompt_embeds, negative_text_ids, negative_attention_mask) = self.encode_prompt(
738
+ prompt=negative_prompt, prompt_2=negative_prompt_2, device=device,
739
+ num_images_per_prompt=num_images_per_prompt,
740
+ max_sequence_length=max_sequence_length, lora_scale=lora_scale,
741
+ )
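+ # Note: this is true classifier-free guidance with a separately encoded negative prompt, so each
+ # denoising step runs both branches through the ControlNet and transformer (roughly 2x the compute).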
742
+
743
+
744
+ # 3. Prepare control image
745
+ num_channels_latents = self.transformer.config.in_channels // 4
746
+
747
+ if isinstance(self.controlnet, FullyShardedDataParallel):
748
+ inner_module = self.controlnet._fsdp_wrapped_module
749
+ else:
750
+ inner_module = self.controlnet
751
+
752
+ control_image = self.prepare_image(
753
+ image=control_image,
754
+ width=width,
755
+ height=height,
756
+ batch_size=batch_size * num_images_per_prompt,
757
+ num_images_per_prompt=num_images_per_prompt,
758
+ device=device,
759
+ dtype=dtype,
760
+ )
761
+
762
+ if control_image_undo_centering:
763
+ if not self.image_processor.do_normalize:
764
+ raise ValueError(
765
+ "`control_image_undo_centering` only makes sense if `do_normalize==True` in the image processor"
766
+ )
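+ # Map the normalized control image from [-1, 1] back to [0, 1], e.g. for ControlNets
+ # trained on un-centered inputs.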
767
+ control_image = control_image*0.5 + 0.5
768
+
769
+ height, width = control_image.shape[-2:]
770
+
774
+
775
+ # vae encode
776
+ control_image = _maybe_to(control_image, device=self.vae.device)
777
+ control_image = self.vae.encode(control_image).latent_dist.sample()
778
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
779
+ control_image = _maybe_to(control_image, device=device)
780
+ # pack
781
+ height_control_image, width_control_image = control_image.shape[2:]
782
+ control_image = self._pack_latents(
783
+ control_image,
784
+ batch_size * num_images_per_prompt,
785
+ num_channels_latents,
786
+ height_control_image,
787
+ width_control_image,
788
+ )
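+ # The control image is now a packed token sequence with the same shape and position ids as the
+ # noise latents, so the ControlNet can consume it token-for-token alongside the denoised latents.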
789
+
790
+ # set control mode
791
+ if control_mode is not None and not isinstance(control_mode, list):
792
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
793
+ control_mode = control_mode.reshape([-1, 1])
794
+
795
+
796
+ # set control mode when provided as a list (None entries map to -1)
797
+ control_mode_ = []
798
+ if isinstance(control_mode, list):
799
+ for cmode in control_mode:
800
+ if cmode is None:
801
+ control_mode_.append(-1)
802
+ else:
803
+ control_mode_.append(cmode)
804
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
805
+ control_mode = control_mode.reshape([-1, 1])
806
+
807
+ # 4. Prepare latent variables
808
+ num_channels_latents = self.transformer.config.in_channels // 4
809
+ latents, latent_image_ids = self.prepare_latents(
810
+ batch_size * num_images_per_prompt,
811
+ num_channels_latents,
812
+ height,
813
+ width,
814
+ prompt_embeds.dtype,
815
+ device,
816
+ generator,
817
+ latents,
818
+ )
819
+
820
+ # 5. Prepare timesteps
821
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
822
+ image_seq_len = latents.shape[1]
823
+ mu = calculate_shift(
824
+ image_seq_len,
825
+ self.scheduler.config.base_image_seq_len,
826
+ self.scheduler.config.max_image_seq_len,
827
+ self.scheduler.config.base_shift,
828
+ self.scheduler.config.max_shift,
829
+ )
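+ # `mu` shifts the flow-matching sigma schedule based on the number of image tokens: longer
+ # sequences (larger images) get a stronger shift, interpolated between `base_shift` and `max_shift`.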
830
+ timesteps, num_inference_steps = retrieve_timesteps(
831
+ self.scheduler,
832
+ num_inference_steps,
833
+ device,
834
+ timesteps,
835
+ sigmas,
836
+ mu=mu,
837
+ )
838
+
839
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
840
+ self._num_timesteps = len(timesteps)
841
+
842
+ # 6. Denoising loop
843
+ target_device = self.transformer.device
844
+ self.controlnet.to(target_device)
845
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
846
+ for i, t in enumerate(timesteps):
847
+ if self.interrupt:
848
+ continue
849
+
850
+
851
+ # Batch the unconditional and conditional inputs for classifier-free guidance
852
+ if do_classifier_free_guidance:
853
+ latent_model_input = torch.cat([latents] * 2)
854
+ current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
855
+ current_pooled_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
856
+ current_attention_mask = torch.cat([negative_attention_mask, attention_mask])
857
+ logger.debug(f"negative_text_ids: {negative_text_ids.shape}, text_ids: {text_ids.shape}, latent_image_ids: {latent_image_ids.shape}")
858
+
859
+ current_text_ids = torch.cat([negative_text_ids, text_ids])
860
+ current_img_ids = torch.cat([latent_image_ids] * 2)
861
+ current_control_image = torch.cat([control_image] * 2) if isinstance(control_image, torch.Tensor) else [torch.cat([c_img] * 2) for c_img in control_image]
862
+ else:
863
+ latent_model_input = latents
864
+ current_prompt_embeds = prompt_embeds
865
+ current_pooled_embeds = pooled_prompt_embeds
866
+ current_attention_mask = attention_mask
867
+ current_text_ids = text_ids
868
+ current_img_ids = latent_image_ids
869
+ current_control_image = control_image
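+ # When doing CFG, the unconditional (negative) half is concatenated first, so `noise_pred.chunk(2)`
+ # below yields (uncond, cond) in that order.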
870
+
871
+ # Resolve the device the transformer currently lives on
872
+ target_device = self.transformer.device
873
+
874
+ # Move all inputs to the target device
875
+ latent_model_input = _maybe_to(latent_model_input, device=target_device)
876
+ current_prompt_embeds = _maybe_to(current_prompt_embeds, device=target_device)
877
+ current_pooled_embeds = _maybe_to(current_pooled_embeds, device=target_device)
878
+ current_attention_mask = _maybe_to(current_attention_mask, device=target_device)
879
+ current_text_ids = _maybe_to(current_text_ids, device=target_device)
880
+ current_img_ids = _maybe_to(current_img_ids, device=target_device)
881
+ if isinstance(current_control_image, torch.Tensor):
882
+ current_control_image = _maybe_to(current_control_image, device=target_device)
883
+ else:
884
+ current_control_image = [ _maybe_to(c, device=target_device) for c in current_control_image ]
885
+ control_mode = _maybe_to(control_mode, device=target_device) if control_mode is not None else None
886
+
887
+ t_model = t.expand(latent_model_input.shape[0]).to(target_device)
888
+
889
+
890
+ # Model calls
891
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
892
+ hidden_states=latent_model_input,
893
+ controlnet_cond=current_control_image,
894
+ controlnet_mode=control_mode,
895
+ conditioning_scale=controlnet_conditioning_scale,
896
+ timestep=(t_model / 1000),
897
+ guidance=None,
898
+ pooled_projections=current_pooled_embeds,
899
+ encoder_hidden_states=current_prompt_embeds,
900
+ attention_mask=current_attention_mask,
901
+ txt_ids=current_text_ids,
902
+ img_ids=current_img_ids,
903
+ joint_attention_kwargs=self.joint_attention_kwargs,
904
+ return_dict=False
905
+ )
906
+
907
+ controlnet_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_block_samples]
908
+ controlnet_single_block_samples = [elem.to(dtype=latents.dtype, device=target_device) for elem in controlnet_single_block_samples]
909
+
910
+ noise_pred = self.transformer(
911
+ hidden_states=latent_model_input,
912
+ timestep=(t_model / 1000),
913
+ guidance=None,
914
+ pooled_projections=current_pooled_embeds,
915
+ encoder_hidden_states=current_prompt_embeds,
916
+ attention_mask=current_attention_mask,
917
+ controlnet_block_samples=controlnet_block_samples,
918
+ controlnet_single_block_samples=controlnet_single_block_samples,
919
+ txt_ids=current_text_ids,
920
+ img_ids=current_img_ids,
921
+ joint_attention_kwargs=self.joint_attention_kwargs,
922
+ return_dict=False
923
+ )[0]
924
+
925
+ # Apply the classifier-free guidance formula
926
+ if do_classifier_free_guidance:
927
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
928
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
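+ # guidance_scale extrapolates along the (cond - uncond) direction; a value of 1.0 would reduce
+ # to the conditional prediction (and CFG is skipped entirely when guidance_scale <= 1.0).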
929
+
932
+
933
+ latents_dtype = latents.dtype
934
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
935
+
936
+ if latents.dtype != latents_dtype:
937
+ if torch.backends.mps.is_available():
938
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
939
+ latents = latents.to(latents_dtype)
940
+
941
+ if callback_on_step_end is not None:
942
+ callback_kwargs = {}
943
+ for k in callback_on_step_end_tensor_inputs:
944
+ callback_kwargs[k] = locals()[k]
945
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
946
+
947
+ latents = callback_outputs.pop("latents", latents)
948
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
949
+
950
+ # call the callback, if provided
951
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
952
+ progress_bar.update()
953
+
954
+ if XLA_AVAILABLE:
955
+ xm.mark_step()
956
+
957
+ if output_type == "latent":
958
+ image = latents
959
+
960
+ else:
961
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
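+ # Undo the latent normalization applied at encode time ((z - shift) * scale) before decoding.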
962
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
963
+
964
+ latents = _maybe_to(latents, device=self.vae.device)
965
+ image = self.vae.decode(latents, return_dict=False)[0]
966
+ image = self.image_processor.postprocess(image, output_type=output_type)
967
+
968
+ # Offload all models
969
+ self.maybe_free_model_hooks()
970
+
971
+ if not return_dict:
972
+ return (image,)
973
+
974
  return FluxPipelineOutput(images=image)