Lakonik committed on
Commit
2d8f0dd
·
1 Parent(s): 71877e6

Add gradio app

.gitignore ADDED
@@ -0,0 +1,25 @@
1
+ /.idea/
2
+ /work_dirs*
3
+ .vscode/
4
+ /tmp
5
+ /data
6
+ /checkpoints
7
+ *.so
8
+ *.patch
9
+ __pycache__/
10
+ *.egg-info/
11
+ /viz*
12
+ /submit*
13
+ build/
14
+ *.pyd
15
+ /cache*
16
+ *.stl
17
+ *.pth
18
+ /venv/
19
+ .nk8s
20
+ *.mp4
21
+ .vs
22
+ /exp/
23
+ /dev/
24
+ *.pyi
25
+ !/data/imagenet/imagenet1000_clsidx_to_labels.txt
README.md CHANGED
@@ -1,13 +1,25 @@
1
  ---
2
- title: Pi Qwen
3
- emoji: 🏃
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Pi-Qwen Demo
3
+ emoji: 🚀
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.18.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
+ Official demo of the paper:
14
+
15
+ **pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation**
16
+ <br>
17
+ [Hansheng Chen](https://lakonik.github.io/)<sup>1</sup>,
18
+ [Kai Zhang](https://kai-46.github.io/website/)<sup>2</sup>,
19
+ [Hao Tan](https://research.adobe.com/person/hao-tan/)<sup>2</sup>,
20
+ [Leonidas Guibas](https://geometry.stanford.edu/?member=guibas)<sup>1</sup>,
21
+ [Gordon Wetzstein](http://web.stanford.edu/~gordonwz/)<sup>1</sup>,
22
+ [Sai Bi](https://sai-bi.github.io/)<sup>2</sup><br>
23
+ <sup>1</sup>Stanford University, <sup>2</sup>Adobe Research
24
+ <br>
25
+ [[arXiv]()] [[pi-Qwen Demo🤗](https://huggingface.co/spaces/Lakonik/pi-Qwen)] [[pi-FLUX Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.1)]
app.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+
3
+ torch.backends.cuda.matmul.allow_tf32 = True
4
+ torch.backends.cudnn.allow_tf32 = True
5
+
6
+ import os
7
+ import gradio as gr
8
+ import spaces
9
+ from diffusers import FlowMatchEulerDiscreteScheduler
10
+ from lakonlab.ui.gradio.create_text_to_img import create_interface_text_to_img
11
+ from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline
12
+
13
+ from huggingface_hub import login
14
+ login(token=os.getenv('HF_TOKEN'))
15
+
16
+
17
+ DEFAULT_PROMPT = ('Photo of a coffee shop entrance featuring a chalkboard sign reading "π-Qwen Coffee 😊 $2 per cup," '
18
+ 'with a neon light beside it displaying "π-通义千问". Next to it hangs a poster showing a beautiful '
19
+ 'Chinese woman, and beneath the poster is written "e≈2.71828-18284-59045-23536-02874-71352".')
20
+
21
+
22
+ pipe = PiQwenImagePipeline.from_pretrained(
23
+ 'Qwen/Qwen-Image',
24
+ torch_dtype=torch.bfloat16)
25
+ pipe.load_piflow_adapter(
26
+ 'Lakonik/pi-Qwen-Image',
27
+ subfolder='gmqwen_k8_piid_4step',
28
+ target_module_name='transformer')
29
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config( # use fixed shift=3.2
30
+ pipe.scheduler.config, shift=3.2, shift_terminal=None, use_dynamic_shifting=False)
31
+ pipe = pipe.to('cuda')
32
+
33
+
34
+ @spaces.GPU
35
+ def generate(seed, prompt, width, height, steps):
36
+ return pipe(
37
+ prompt=prompt,
38
+ width=width,
39
+ height=height,
40
+ num_inference_steps=steps,
41
+ generator=torch.Generator().manual_seed(seed),
42
+ ).images[0]
43
+
44
+
45
+ with gr.Blocks(analytics_enabled=False,
46
+ title='pi-Qwen Demo',
47
+ css='lakonlab/ui/gradio/style.css'
48
+ ) as demo:
49
+
50
+ md_txt = '# pi-Qwen Demo\n\n' \
51
+ 'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](). ' \
52
+ '**Base model:** [Qwen-Image](https://huggingface.co/Qwen/Qwen-Image). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).'
53
+ gr.Markdown(md_txt)
54
+
55
+ create_interface_text_to_img(
56
+ generate,
57
+ prompt=DEFAULT_PROMPT,
58
+ steps=4, guidance_scale=None,
59
+ args=['last_seed', 'prompt', 'width', 'height', 'steps'])
60
+ demo.queue().launch()
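
The Space-specific parts of `app.py` (the `spaces.GPU` decorator, the `HF_TOKEN` login, and the `lakonlab` Gradio helper) are only needed on Hugging Face Spaces. A minimal local sketch, assuming the `lakonlab` package from this commit is importable and a CUDA GPU is available:

```python
# Minimal local sketch (not part of this commit): 4-step pi-Qwen sampling without
# the Spaces decorator, HF login, or the Gradio UI helper.
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline

pipe = PiQwenImagePipeline.from_pretrained('Qwen/Qwen-Image', torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-Qwen-Image',
    subfolder='gmqwen_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(  # fixed shift=3.2, as in app.py
    pipe.scheduler.config, shift=3.2, shift_terminal=None, use_dynamic_shifting=False)
pipe = pipe.to('cuda')

image = pipe(
    prompt='Photo of a coffee shop entrance with a chalkboard sign',
    width=1024, height=1024,
    num_inference_steps=4,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save('piqwen_sample.png')
```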
lakonlab/__init__.py ADDED
File without changes
lakonlab/models/__init__.py ADDED
File without changes
lakonlab/models/architecture/__init__.py ADDED
File without changes
lakonlab/models/architecture/gmflow/__init__.py ADDED
File without changes
lakonlab/models/architecture/gmflow/gm_output.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ from dataclasses import dataclass
3
+ from diffusers.utils import BaseOutput
4
+
5
+
6
+ @dataclass
7
+ class GMFlowModelOutput(BaseOutput):
8
+ """
9
+ The output of GMFlow models.
10
+
11
+ Args:
12
+ means (`torch.Tensor` of shape `(batch_size, num_gaussians, num_channels, height, width)` or
13
+ `(batch_size, num_gaussians, num_channels, frame, height, width)`):
14
+ Gaussian mixture means.
15
+ logweights (`torch.Tensor` of shape `(batch_size, num_gaussians, 1, height, width)` or
16
+ `(batch_size, num_gaussians, 1, frame, height, width)`):
17
+ Gaussian mixture log-weights (logits).
18
+ logstds (`torch.Tensor` of shape `(batch_size, 1, 1, 1, 1)` or `(batch_size, 1, 1, 1, 1, 1)`):
19
+ Gaussian mixture log-standard-deviations (logstds are shared across all Gaussians and channels).
20
+ """
21
+
22
+ means: torch.Tensor
23
+ logweights: torch.Tensor
24
+ logstds: torch.Tensor
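
For reference, a small shape-level sketch (not part of the commit) constructing a `GMFlowModelOutput` with the image-shaped layout described above; the sizes are arbitrary placeholders:

```python
import torch
from lakonlab.models.architecture.gmflow.gm_output import GMFlowModelOutput

B, K, C, H, W = 2, 8, 16, 32, 32  # batch, mixture components, channels, spatial size
out = GMFlowModelOutput(
    means=torch.randn(B, K, C, H, W),                          # per-component means
    logweights=torch.randn(B, K, 1, H, W).log_softmax(dim=1),  # per-pixel logits over components
    logstds=torch.zeros(B, 1, 1, 1, 1),                        # one shared log-std per sample
)
print(out.means.shape, out.logweights.shape, out.logstds.shape)
```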
lakonlab/models/architecture/gmflow/gmflux.py ADDED
@@ -0,0 +1,225 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from typing import Any, Dict, Optional, Tuple
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.models.transformers.transformer_flux import (
8
+ FluxTransformer2DModel, FluxPosEmbed, FluxTransformerBlock, FluxSingleTransformerBlock)
9
+ from diffusers.models.embeddings import (
10
+ CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings)
11
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
12
+ from diffusers.configuration_utils import register_to_config
13
+ from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
14
+ from .gm_output import GMFlowModelOutput
15
+
16
+
17
+ class _GMFluxTransformer2DModel(FluxTransformer2DModel):
18
+
19
+ @register_to_config
20
+ def __init__(
21
+ self,
22
+ num_gaussians=16,
23
+ constant_logstd=None,
24
+ logstd_inner_dim=1024,
25
+ gm_num_logstd_layers=2,
26
+ logweights_channels=1,
27
+ in_channels: int = 64,
28
+ out_channels: Optional[int] = None,
29
+ num_layers: int = 19,
30
+ num_single_layers: int = 38,
31
+ attention_head_dim: int = 128,
32
+ num_attention_heads: int = 24,
33
+ joint_attention_dim: int = 4096,
34
+ pooled_projection_dim: int = 768,
35
+ guidance_embeds: bool = False,
36
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)):
37
+ super(FluxTransformer2DModel, self).__init__()
38
+
39
+ self.num_gaussians = num_gaussians
40
+ self.logweights_channels = logweights_channels
41
+
42
+ self.out_channels = out_channels or in_channels
43
+ self.inner_dim = num_attention_heads * attention_head_dim
44
+
45
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
46
+
47
+ text_time_guidance_cls = (
48
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
49
+ )
50
+ self.time_text_embed = text_time_guidance_cls(
51
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
52
+ )
53
+
54
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
55
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
56
+
57
+ self.transformer_blocks = nn.ModuleList(
58
+ [
59
+ FluxTransformerBlock(
60
+ dim=self.inner_dim,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=attention_head_dim,
63
+ )
64
+ for _ in range(num_layers)
65
+ ]
66
+ )
67
+
68
+ self.single_transformer_blocks = nn.ModuleList(
69
+ [
70
+ FluxSingleTransformerBlock(
71
+ dim=self.inner_dim,
72
+ num_attention_heads=num_attention_heads,
73
+ attention_head_dim=attention_head_dim,
74
+ )
75
+ for _ in range(num_single_layers)
76
+ ]
77
+ )
78
+
79
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
80
+ self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
81
+ self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
82
+ self.constant_logstd = constant_logstd
83
+
84
+ if self.constant_logstd is None:
85
+ assert gm_num_logstd_layers >= 1
86
+ in_dim = self.inner_dim
87
+ logstd_layers = []
88
+ for _ in range(gm_num_logstd_layers - 1):
89
+ logstd_layers.extend([
90
+ nn.SiLU(),
91
+ nn.Linear(in_dim, logstd_inner_dim)])
92
+ in_dim = logstd_inner_dim
93
+ self.proj_out_logstds = nn.Sequential(
94
+ *logstd_layers,
95
+ nn.SiLU(),
96
+ nn.Linear(in_dim, 1))
97
+
98
+ self.gradient_checkpointing = False
99
+
100
+ def forward(
101
+ self,
102
+ hidden_states: torch.Tensor,
103
+ encoder_hidden_states: torch.Tensor = None,
104
+ pooled_projections: torch.Tensor = None,
105
+ timestep: torch.Tensor = None,
106
+ img_ids: torch.Tensor = None,
107
+ txt_ids: torch.Tensor = None,
108
+ guidance: torch.Tensor = None,
109
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
110
+ controlnet_block_samples=None,
111
+ controlnet_single_block_samples=None,
112
+ controlnet_blocks_repeat: bool = False):
113
+ if joint_attention_kwargs is not None:
114
+ joint_attention_kwargs = joint_attention_kwargs.copy()
115
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
116
+ else:
117
+ lora_scale = 1.0
118
+
119
+ if USE_PEFT_BACKEND:
120
+ scale_lora_layers(self, lora_scale)
121
+ else:
122
+ assert joint_attention_kwargs is None or joint_attention_kwargs.get('scale', None) is None
123
+
124
+ hidden_states = self.x_embedder(hidden_states)
125
+
126
+ timestep = timestep.to(hidden_states.dtype) * 1000
127
+ if guidance is not None:
128
+ guidance = guidance.to(hidden_states.dtype) * 1000
129
+
130
+ temb = (
131
+ self.time_text_embed(timestep, pooled_projections)
132
+ if guidance is None
133
+ else self.time_text_embed(timestep, guidance, pooled_projections)
134
+ )
135
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
136
+
137
+ ids = torch.cat((txt_ids, img_ids), dim=0)
138
+ image_rotary_emb = self.pos_embed(ids)
139
+ image_rotary_emb = tuple([x.to(hidden_states.dtype) for x in image_rotary_emb])
140
+
141
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
142
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
143
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
144
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
145
+
146
+ for index_block, block in enumerate(self.transformer_blocks):
147
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
148
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
149
+ block,
150
+ hidden_states,
151
+ encoder_hidden_states,
152
+ temb,
153
+ image_rotary_emb,
154
+ joint_attention_kwargs,
155
+ )
156
+
157
+ else:
158
+ encoder_hidden_states, hidden_states = block(
159
+ hidden_states=hidden_states,
160
+ encoder_hidden_states=encoder_hidden_states,
161
+ temb=temb,
162
+ image_rotary_emb=image_rotary_emb,
163
+ joint_attention_kwargs=joint_attention_kwargs,
164
+ )
165
+
166
+ # controlnet residual
167
+ if controlnet_block_samples is not None:
168
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
169
+ interval_control = int(np.ceil(interval_control))
170
+ # For Xlabs ControlNet.
171
+ if controlnet_blocks_repeat:
172
+ hidden_states = (
173
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
174
+ )
175
+ else:
176
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
177
+
178
+ for index_block, block in enumerate(self.single_transformer_blocks):
179
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
180
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
181
+ block,
182
+ hidden_states,
183
+ encoder_hidden_states,
184
+ temb,
185
+ image_rotary_emb,
186
+ joint_attention_kwargs,
187
+ )
188
+
189
+ else:
190
+ encoder_hidden_states, hidden_states = block(
191
+ hidden_states=hidden_states,
192
+ encoder_hidden_states=encoder_hidden_states,
193
+ temb=temb,
194
+ image_rotary_emb=image_rotary_emb,
195
+ joint_attention_kwargs=joint_attention_kwargs,
196
+ )
197
+
198
+ # controlnet residual
199
+ if controlnet_single_block_samples is not None:
200
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
201
+ interval_control = int(np.ceil(interval_control))
202
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
203
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
204
+ + controlnet_single_block_samples[index_block // interval_control]
205
+ )
206
+
207
+ hidden_states = self.norm_out(hidden_states, temb)
208
+
209
+ bs, seq_len, _ = hidden_states.size()
210
+ out_means = self.proj_out_means(hidden_states).reshape(
211
+ bs, seq_len, self.num_gaussians, self.out_channels)
212
+ out_logweights = self.proj_out_logweights(hidden_states).reshape(
213
+ bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
214
+ if self.constant_logstd is None:
215
+ out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
216
+ else:
217
+ out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
218
+
219
+ if USE_PEFT_BACKEND:
220
+ unscale_lora_layers(self, lora_scale)
221
+
222
+ return GMFlowModelOutput(
223
+ means=out_means,
224
+ logweights=out_logweights,
225
+ logstds=out_logstds)
lakonlab/models/architecture/gmflow/gmqwen.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Any, Dict, Optional, Tuple, List
5
+ from diffusers.models import ModelMixin
6
+ from diffusers.models.transformers.transformer_qwenimage import (
7
+ QwenImageTransformer2DModel, QwenEmbedRope, QwenImageTransformerBlock, QwenTimestepProjEmbeddings)
8
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
9
+ from diffusers.configuration_utils import register_to_config
10
+ from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
11
+ from .gm_output import GMFlowModelOutput
12
+
13
+
14
+ class _GMQwenImageTransformer2DModel(QwenImageTransformer2DModel):
15
+
16
+ @register_to_config
17
+ def __init__(
18
+ self,
19
+ num_gaussians=16,
20
+ constant_logstd=None,
21
+ logstd_inner_dim=1024,
22
+ gm_num_logstd_layers=2,
23
+ logweights_channels=1,
24
+ in_channels: int = 64,
25
+ out_channels: Optional[int] = None,
26
+ num_layers: int = 60,
27
+ attention_head_dim: int = 128,
28
+ num_attention_heads: int = 24,
29
+ joint_attention_dim: int = 3584,
30
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)):
31
+ super(QwenImageTransformer2DModel, self).__init__()
32
+
33
+ self.num_gaussians = num_gaussians
34
+ self.logweights_channels = logweights_channels
35
+
36
+ self.out_channels = out_channels or in_channels
37
+ self.inner_dim = num_attention_heads * attention_head_dim
38
+
39
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
40
+
41
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
42
+
43
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
44
+
45
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
46
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
47
+
48
+ self.transformer_blocks = nn.ModuleList(
49
+ [
50
+ QwenImageTransformerBlock(
51
+ dim=self.inner_dim,
52
+ num_attention_heads=num_attention_heads,
53
+ attention_head_dim=attention_head_dim,
54
+ )
55
+ for _ in range(num_layers)
56
+ ]
57
+ )
58
+
59
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
60
+ self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
61
+ self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
62
+ self.constant_logstd = constant_logstd
63
+
64
+ if self.constant_logstd is None:
65
+ assert gm_num_logstd_layers >= 1
66
+ in_dim = self.inner_dim
67
+ logstd_layers = []
68
+ for _ in range(gm_num_logstd_layers - 1):
69
+ logstd_layers.extend([
70
+ nn.SiLU(),
71
+ nn.Linear(in_dim, logstd_inner_dim)])
72
+ in_dim = logstd_inner_dim
73
+ self.proj_out_logstds = nn.Sequential(
74
+ *logstd_layers,
75
+ nn.SiLU(),
76
+ nn.Linear(in_dim, 1))
77
+
78
+ self.gradient_checkpointing = False
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ encoder_hidden_states: torch.Tensor = None,
84
+ encoder_hidden_states_mask: torch.Tensor = None,
85
+ timestep: torch.LongTensor = None,
86
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
87
+ txt_seq_lens: Optional[List[int]] = None,
88
+ attention_kwargs: Optional[Dict[str, Any]] = None):
89
+ if attention_kwargs is not None:
90
+ attention_kwargs = attention_kwargs.copy()
91
+ lora_scale = attention_kwargs.pop("scale", 1.0)
92
+ else:
93
+ lora_scale = 1.0
94
+
95
+ if USE_PEFT_BACKEND:
96
+ scale_lora_layers(self, lora_scale)
97
+ else:
98
+ assert attention_kwargs is None or attention_kwargs.get('scale', None) is None
99
+
100
+ hidden_states = self.img_in(hidden_states)
101
+
102
+ timestep = timestep.to(hidden_states.dtype)
103
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
104
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
105
+
106
+ temb = self.time_text_embed(timestep, hidden_states)
107
+
108
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
109
+
110
+ for index_block, block in enumerate(self.transformer_blocks):
111
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
112
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
113
+ block,
114
+ hidden_states,
115
+ encoder_hidden_states,
116
+ encoder_hidden_states_mask,
117
+ temb,
118
+ image_rotary_emb,
119
+ )
120
+
121
+ else:
122
+ encoder_hidden_states, hidden_states = block(
123
+ hidden_states=hidden_states,
124
+ encoder_hidden_states=encoder_hidden_states,
125
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
126
+ temb=temb,
127
+ image_rotary_emb=image_rotary_emb,
128
+ joint_attention_kwargs=attention_kwargs,
129
+ )
130
+
131
+ hidden_states = self.norm_out(hidden_states, temb)
132
+
133
+ bs, seq_len, _ = hidden_states.size()
134
+ out_means = self.proj_out_means(hidden_states).reshape(
135
+ bs, seq_len, self.num_gaussians, self.out_channels)
136
+ out_logweights = self.proj_out_logweights(hidden_states).reshape(
137
+ bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
138
+ if self.constant_logstd is None:
139
+ out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
140
+ else:
141
+ out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
142
+
143
+ if USE_PEFT_BACKEND:
144
+ unscale_lora_layers(self, lora_scale)
145
+
146
+ return GMFlowModelOutput(
147
+ means=out_means,
148
+ logweights=out_logweights,
149
+ logstds=out_logstds)
lakonlab/models/diffusions/__init__.py ADDED
File without changes
lakonlab/models/diffusions/piflow_policies/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .dx import DXPolicy
2
+ from .gmflow import GMFlowPolicy
3
+
4
+
5
+ POLICY_CLASSES = dict(
6
+ DX=DXPolicy,
7
+ GMFlow=GMFlowPolicy
8
+ )
lakonlab/models/diffusions/piflow_policies/base.py ADDED
@@ -0,0 +1,21 @@
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+
4
+ class BasePolicy(metaclass=ABCMeta):
5
+
6
+ @abstractmethod
7
+ def u(self, x_t, sigma_t):
8
+ """Compute the flow velocity at (x_t, t).
9
+
10
+ Args:
11
+ x_t (torch.Tensor): Noisy input at time t.
12
+ sigma_t (torch.Tensor): Noise level at time t.
13
+
14
+ Returns:
15
+ torch.Tensor: The computed flow velocity u_t.
16
+ """
17
+ pass
18
+
19
+ @abstractmethod
20
+ def detach(self):
21
+ pass
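
Downstream, a policy produced by a single network evaluation is integrated with cheap ODE substeps (see `total_substeps` in the pipelines below). A conceptual Euler-integration sketch, not taken from this commit:

```python
import torch

def euler_integrate(policy, x_t, sigma_src, sigma_dst, num_substeps=32):
    """Integrate the flow ODE dx/dsigma = policy.u(x, sigma) from sigma_src down to sigma_dst."""
    sigmas = torch.linspace(sigma_src, sigma_dst, num_substeps + 1)
    for i in range(num_substeps):
        sigma = torch.full((x_t.size(0),), sigmas[i].item(), device=x_t.device, dtype=x_t.dtype)
        d_sigma = (sigmas[i + 1] - sigmas[i]).item()  # negative while denoising
        x_t = x_t + d_sigma * policy.u(x_t, sigma)    # one substep, no extra network call
    return x_t
```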
lakonlab/models/diffusions/piflow_policies/dx.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import torch
4
+ from .base import BasePolicy
5
+
6
+
7
+ class DXPolicy(BasePolicy):
8
+ """DX policy. The number of grid points N is inferred from the denoising output.
9
+
10
+ Note: segment_size and shift are intrinsic parameters of the DX policy. For elastic inference (i.e., changing
11
+ the number of function evaluations or noise schedule at test time), these parameters should be kept unchanged.
12
+
13
+ Args:
14
+ denoising_output (torch.Tensor): The output of the denoising model. Shape (B, N, C, H, W) or (B, N, C, T, H, W).
15
+ x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
16
+ sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
17
+ segment_size (float): The size of each DX policy time segment. Defaults to 1.0.
18
+ shift (float): The shift parameter for the DX policy noise schedule. Defaults to 1.0.
19
+ eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
20
+ """
21
+
22
+ def __init__(
23
+ self,
24
+ denoising_output: torch.Tensor,
25
+ x_t_src: torch.Tensor,
26
+ sigma_t_src: torch.Tensor,
27
+ segment_size: float = 1.0,
28
+ shift: float = 1.0,
29
+ eps: float = 1e-4):
30
+ self.x_t_src = x_t_src
31
+ self.ndim = x_t_src.dim()
32
+ self.shift = shift
33
+ self.eps = eps
34
+
35
+ self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
36
+ self.raw_t_src = self._unwarp_t(self.sigma_t_src)
37
+ self.raw_t_dst = (self.raw_t_src - segment_size).clamp(min=0)
38
+ self.segment_size = (self.raw_t_src - self.raw_t_dst).clamp(min=eps)
39
+
40
+ self.denoising_output_x_0 = self._u_to_x_0(
41
+ denoising_output, self.x_t_src, self.sigma_t_src)
42
+
43
+ def _unwarp_t(self, sigma_t):
44
+ return sigma_t / (self.shift + (1 - self.shift) * sigma_t)
45
+
46
+ @staticmethod
47
+ def _u_to_x_0(denoising_output, x_t, sigma_t):
48
+ x_0 = x_t.unsqueeze(1) - sigma_t.unsqueeze(1) * denoising_output
49
+ return x_0
50
+
51
+ @staticmethod
52
+ def _interpolate(x, t):
53
+ """
54
+ Args:
55
+ x (torch.Tensor): (B, N, *)
56
+ t (torch.Tensor): (B, *) in [0, 1]
57
+
58
+ Returns:
59
+ torch.Tensor: (B, *)
60
+ """
61
+ n = x.size(1)
62
+ if n < 2:
63
+ return x.squeeze(1)
64
+ t = t.clamp(min=0, max=1) * (n - 1)
65
+ t0 = t.floor().to(torch.long).clamp(min=0, max=n - 2)
66
+ t1 = t0 + 1
67
+ t0t1 = torch.stack([t0, t1], dim=1) # (B, 2, *)
68
+ x0x1 = torch.gather(x, dim=1, index=t0t1.expand(-1, -1, *x.shape[2:]))
69
+ x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
70
+ return x_interp
71
+
72
+ def u(self, x_t, sigma_t):
73
+ """Compute the flow velocity at (x_t, t).
74
+
75
+ Args:
76
+ x_t (torch.Tensor): Noisy input at time t.
77
+ sigma_t (torch.Tensor): Noise level at time t.
78
+
79
+ Returns:
80
+ torch.Tensor: The computed flow velocity u_t.
81
+ """
82
+ sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
83
+ raw_t = self._unwarp_t(sigma_t)
84
+ x_0 = self._interpolate(
85
+ self.denoising_output_x_0, (raw_t - self.raw_t_dst) / self.segment_size)
86
+ u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
87
+ return u
88
+
89
+ def copy(self):
90
+ new_policy = DXPolicy.__new__(DXPolicy)
91
+ new_policy.x_t_src = self.x_t_src
92
+ new_policy.ndim = self.ndim
93
+ new_policy.shift = self.shift
94
+ new_policy.eps = self.eps
95
+ new_policy.sigma_t_src = self.sigma_t_src
96
+ new_policy.raw_t_src = self.raw_t_src
97
+ new_policy.raw_t_dst = self.raw_t_dst
98
+ new_policy.segment_size = self.segment_size
99
+ new_policy.denoising_output_x_0 = self.denoising_output_x_0
100
+ return new_policy
101
+
102
+ def detach_(self):
103
+ self.denoising_output_x_0 = self.denoising_output_x_0.detach()
104
+ return self
105
+
106
+ def detach(self):
107
+ new_policy = self.copy()
108
+ return new_policy.detach_()
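
A shape-level sketch (not from this commit) exercising `DXPolicy` with random tensors; the batch size and the N=4 grid are arbitrary examples:

```python
import torch
from lakonlab.models.diffusions.piflow_policies import DXPolicy

B, N, C, H, W = 2, 4, 16, 32, 32
x_t_src = torch.randn(B, C, H, W)
sigma_t_src = torch.full((B,), 0.8)
denoising_output = torch.randn(B, N, C, H, W)  # stands in for the model's N-point velocity grid

policy = DXPolicy(denoising_output, x_t_src, sigma_t_src, segment_size=1.0, shift=1.0)
u_src = policy.u(x_t_src, sigma_t_src)            # velocity at the source noise level
u_mid = policy.u(x_t_src, torch.full((B,), 0.4))  # x_0 is linearly interpolated at lower levels
print(u_src.shape, u_mid.shape)                   # both torch.Size([2, 16, 32, 32])
```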
lakonlab/models/diffusions/piflow_policies/gmflow.py ADDED
@@ -0,0 +1,175 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import math
+ import torch
4
+
5
+ from typing import Dict
6
+ from .base import BasePolicy
7
+
8
+
9
+ @torch.jit.script
10
+ def gmflow_posterior_mean_jit(
11
+ sigma_t_src, sigma_t, x_t_src, x_t,
12
+ gm_means, gm_vars, gm_logweights,
13
+ eps: float, gm_dim: int = -4, channel_dim: int = -3):
14
+ alpha_t_src = 1 - sigma_t_src
15
+ alpha_t = 1 - sigma_t
16
+
17
+ sigma_t_src_sq = sigma_t_src.square()
18
+ sigma_t_sq = sigma_t.square()
19
+
20
+ # compute gaussian params
21
+ denom = (alpha_t.square() * sigma_t_src_sq - alpha_t_src.square() * sigma_t_sq).clamp(min=eps) # ζ
22
+ g_mean = (alpha_t * sigma_t_src_sq * x_t - alpha_t_src * sigma_t_sq * x_t_src) / denom # ν / ζ
23
+ g_var = sigma_t_sq * sigma_t_src_sq / denom
24
+
25
+ # gm_mul_iso_gaussian
26
+ g_mean = g_mean.unsqueeze(gm_dim) # (bs, *, 1, out_channels, h, w)
27
+ g_var = g_var.unsqueeze(gm_dim) # (bs, *, 1, 1, 1, 1)
28
+
29
+ gm_diffs = gm_means - g_mean # (bs, *, num_gaussians, out_channels, h, w)
30
+ norm_factor = (g_var + gm_vars).clamp(min=eps)
31
+
32
+ out_means = (g_var * gm_means + gm_vars * g_mean) / norm_factor
33
+ # (bs, *, num_gaussians, 1, h, w)
34
+ logweights_delta = gm_diffs.square().sum(dim=channel_dim, keepdim=True) * (-0.5 / norm_factor)
35
+ out_weights = (gm_logweights + logweights_delta).softmax(dim=gm_dim)
36
+
37
+ out_mean = (out_means * out_weights).sum(dim=gm_dim)
38
+
39
+ return out_mean
40
+
41
+
42
+ def gm_temperature(gm, temperature, gm_dim=-4, eps=1e-6):
43
+ gm = gm.copy()
44
+ temperature = max(temperature, eps)
45
+ gm['logweights'] = (gm['logweights'] / temperature).log_softmax(dim=gm_dim)
46
+ if 'logstds' in gm:
47
+ gm['logstds'] = gm['logstds'] + (0.5 * math.log(temperature))
48
+ if 'gm_vars' in gm:
49
+ gm['gm_vars'] = gm['gm_vars'] * temperature
50
+ return gm
51
+
52
+
53
+ class GMFlowPolicy(BasePolicy):
54
+ """GMFlow policy. The number of components K is inferred from the denoising output.
55
+
56
+ Args:
57
+ denoising_output (dict): The output of the denoising model, containing:
58
+ means (torch.Tensor): The means of the Gaussian components. Shape (B, K, C, H, W) or (B, K, C, T, H, W).
59
+ logstds (torch.Tensor): The log standard deviations of the Gaussian components. Shape (B, K, 1, 1, 1)
60
+ or (B, K, 1, 1, 1, 1).
61
+ logweights (torch.Tensor): The log weights of the Gaussian components. Shape (B, K, 1, H, W) or
62
+ (B, K, 1, T, H, W).
63
+ x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
64
+ sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
65
+ checkpointing (bool): Whether to use gradient checkpointing to save memory. Defaults to True.
66
+ eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ denoising_output: Dict[str, torch.Tensor],
72
+ x_t_src: torch.Tensor,
73
+ sigma_t_src: torch.Tensor,
74
+ checkpointing: bool = True,
75
+ eps: float = 1e-4):
76
+ self.x_t_src = x_t_src
77
+ self.ndim = x_t_src.dim()
78
+ self.checkpointing = checkpointing
79
+ self.eps = eps
80
+
81
+ self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
82
+ self.denoising_output_x_0 = self._u_to_x_0(
83
+ denoising_output, self.x_t_src, self.sigma_t_src)
84
+
85
+ @staticmethod
86
+ def _u_to_x_0(denoising_output, x_t, sigma_t):
87
+ x_t = x_t.unsqueeze(1)
88
+ sigma_t = sigma_t.unsqueeze(1)
89
+ means_x_0 = x_t - sigma_t * denoising_output['means']
90
+ gm_vars = (denoising_output['logstds'] * 2).exp() * sigma_t.square()
91
+ return dict(
92
+ means=means_x_0,
93
+ gm_vars=gm_vars,
94
+ logweights=denoising_output['logweights'])
95
+
96
+ def u(self, x_t, sigma_t):
97
+ """Compute the flow velocity at (x_t, t).
98
+
99
+ Args:
100
+ x_t (torch.Tensor): Noisy input at time t.
101
+ sigma_t (torch.Tensor): Noise level at time t.
102
+
103
+ Returns:
104
+ torch.Tensor: The computed flow velocity u_t.
105
+ """
106
+ sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
107
+ means = self.denoising_output_x_0['means']
108
+ gm_vars = self.denoising_output_x_0['gm_vars']
109
+ logweights = self.denoising_output_x_0['logweights']
110
+ if (sigma_t == self.sigma_t_src).all() and (x_t == self.x_t_src).all():
111
+ x_0 = (logweights.softmax(dim=1) * means).sum(dim=1)
112
+ else:
113
+ if self.checkpointing and torch.is_grad_enabled():
114
+ x_0 = torch.utils.checkpoint.checkpoint(
115
+ gmflow_posterior_mean_jit,
116
+ self.sigma_t_src, sigma_t, self.x_t_src, x_t,
117
+ means,
118
+ gm_vars,
119
+ logweights,
120
+ self.eps, 1, 2,
121
+ use_reentrant=True) # use_reentrant=False does not work with jit
122
+ else:
123
+ x_0 = gmflow_posterior_mean_jit(
124
+ self.sigma_t_src, sigma_t, self.x_t_src, x_t,
125
+ means,
126
+ gm_vars,
127
+ logweights,
128
+ self.eps, 1, 2)
129
+ u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
130
+ return u
131
+
132
+ def copy(self):
133
+ new_policy = GMFlowPolicy.__new__(GMFlowPolicy)
134
+ new_policy.x_t_src = self.x_t_src
135
+ new_policy.ndim = self.ndim
136
+ new_policy.checkpointing = self.checkpointing
137
+ new_policy.eps = self.eps
138
+ new_policy.sigma_t_src = self.sigma_t_src
139
+ new_policy.denoising_output_x_0 = self.denoising_output_x_0.copy()
140
+ return new_policy
141
+
142
+ def detach_(self):
143
+ self.denoising_output_x_0 = {k: v.detach() for k, v in self.denoising_output_x_0.items()}
144
+ return self
145
+
146
+ def detach(self):
147
+ new_policy = self.copy()
148
+ return new_policy.detach_()
149
+
150
+ def dropout_(self, p):
151
+ if p <= 0 or p >= 1:
152
+ return self
153
+ logweights = self.denoising_output_x_0['logweights']
154
+ dropout_mask = torch.rand(
155
+ (*logweights.shape[:2], *((self.ndim - 1) * [1])), device=logweights.device) < p
156
+ is_all_dropout = dropout_mask.all(dim=1, keepdim=True)
157
+ dropout_mask &= ~is_all_dropout
158
+ self.denoising_output_x_0['logweights'] = logweights.masked_fill(
159
+ dropout_mask, float('-inf'))
160
+ return self
161
+
162
+ def dropout(self, p):
163
+ new_policy = self.copy()
164
+ return new_policy.dropout_(p)
165
+
166
+ def temperature_(self, temp):
167
+ if temp >= 1.0:
168
+ return self
169
+ self.denoising_output_x_0 = gm_temperature(
170
+ self.denoising_output_x_0, temp, gm_dim=1, eps=self.eps)
171
+ return self
172
+
173
+ def temperature(self, temp):
174
+ new_policy = self.copy()
175
+ return new_policy.temperature_(temp)
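
A matching sketch (not from this commit) for `GMFlowPolicy` with a random GM output; K=8 and all sizes are placeholders. Querying `u` away from the source point goes through the analytic posterior mean in `gmflow_posterior_mean_jit`:

```python
import torch
from lakonlab.models.diffusions.piflow_policies import GMFlowPolicy

B, K, C, H, W = 2, 8, 16, 32, 32
x_t_src = torch.randn(B, C, H, W)
sigma_t_src = torch.full((B,), 0.8)
denoising_output = dict(
    means=torch.randn(B, K, C, H, W),
    logweights=torch.randn(B, K, 1, H, W).log_softmax(dim=1),
    logstds=torch.full((B, 1, 1, 1, 1), -1.0),
)

policy = GMFlowPolicy(denoising_output, x_t_src, sigma_t_src, checkpointing=False)
u_src = policy.u(x_t_src, sigma_t_src)          # velocity from the mixture-mean x_0 at the source point
x_mid = x_t_src - 0.4 * u_src                   # one Euler step toward sigma = 0.4
u_mid = policy.u(x_mid, torch.full((B,), 0.4))  # analytic GM posterior mean at the new point
print(u_src.shape, u_mid.shape)                 # both torch.Size([2, 16, 32, 32])
```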
lakonlab/pipelines/__init__.py ADDED
File without changes
lakonlab/pipelines/piflow_loader.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import os
4
+ from typing import Union, Optional
5
+
6
+ import torch
7
+ import accelerate
8
+ import diffusers
9
+ from diffusers.models import AutoModel
10
+ from diffusers.models.modeling_utils import (
11
+ load_state_dict,
12
+ _LOW_CPU_MEM_USAGE_DEFAULT,
13
+ no_init_weights,
14
+ ContextManagers
15
+ )
16
+ from diffusers.utils import (
17
+ SAFETENSORS_WEIGHTS_NAME,
18
+ WEIGHTS_NAME,
19
+ _add_variant,
20
+ _get_model_file,
21
+ is_accelerate_available,
22
+ is_torch_version,
23
+ logging,
24
+ )
25
+ from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
26
+ from lakonlab.models.architecture.gmflow.gmflux import _GMFluxTransformer2DModel
27
+ from lakonlab.models.architecture.gmflow.gmqwen import _GMQwenImageTransformer2DModel
28
+
29
+
30
+ LOCAL_CLASS_MAPPING = {
31
+ "GMFluxTransformer2DModel": _GMFluxTransformer2DModel,
32
+ "GMQwenImageTransformer2DModel": _GMQwenImageTransformer2DModel,
33
+ }
34
+
35
+ _SET_ADAPTER_SCALE_FN_MAPPING.update(
36
+ _GMFluxTransformer2DModel=lambda model_cls, weights: weights,
37
+ _GMQwenImageTransformer2DModel=lambda model_cls, weights: weights,
38
+ )
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ class PiFlowLoaderMixin:
44
+
45
+ def load_piflow_adapter(
46
+ self,
47
+ pretrained_model_name_or_path: Union[str, os.PathLike],
48
+ target_module_name: str = "transformer",
49
+ adapter_name: Optional[str] = None,
50
+ **kwargs
51
+ ):
52
+ r"""
53
+ Load a PiFlow adapter from a pretrained model repository into the target module.
54
+
55
+ Args:
56
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
57
+ Can be either:
58
+
59
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
60
+ the Hub.
61
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
62
+ with [`~ModelMixin.save_pretrained`].
63
+
64
+ target_module_name (`str`, *optional*, defaults to `"transformer"`):
65
+ The module name in the model to load the PiFlow adapter into.
66
+ adapter_name (`str`, *optional*):
67
+ The name to assign to the loaded adapter. If not provided, it defaults to
68
+ `"{target_module_name}_piflow"`.
69
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
70
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
71
+ is not used.
72
+ force_download (`bool`, *optional*, defaults to `False`):
73
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
74
+ cached versions if they exist.
75
+ proxies (`Dict[str, str]`, *optional*):
76
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
77
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
78
+ local_files_only(`bool`, *optional*, defaults to `False`):
79
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
80
+ won't be downloaded from the Hub.
81
+ token (`str` or *bool*, *optional*):
82
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
83
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
84
+ revision (`str`, *optional*, defaults to `"main"`):
85
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
86
+ allowed by Git.
87
+ subfolder (`str`, *optional*, defaults to `""`):
88
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
89
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
90
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
91
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
92
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
93
+ argument to `True` will raise an error.
94
+ variant (`str`, *optional*):
95
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
96
+ loading `from_flax`.
97
+ use_safetensors (`bool`, *optional*, defaults to `None`):
98
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
99
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
100
+ weights. If set to `False`, `safetensors` weights are not loaded.
101
+ disable_mmap (`bool`, *optional*, defaults to `False`):
102
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
103
+ is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
104
+
105
+ Returns:
106
+ `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found.
107
+ """
108
+ cache_dir = kwargs.pop("cache_dir", None)
109
+ force_download = kwargs.pop("force_download", False)
110
+ proxies = kwargs.pop("proxies", None)
111
+ token = kwargs.pop("token", None)
112
+ local_files_only = kwargs.pop("local_files_only", False)
113
+ revision = kwargs.pop("revision", None)
114
+ subfolder = kwargs.pop("subfolder", None)
115
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
116
+ variant = kwargs.pop("variant", None)
117
+ use_safetensors = kwargs.pop("use_safetensors", None)
118
+ disable_mmap = kwargs.pop("disable_mmap", False)
119
+
120
+ allow_pickle = False
121
+ if use_safetensors is None:
122
+ use_safetensors = True
123
+ allow_pickle = True
124
+
125
+ if low_cpu_mem_usage and not is_accelerate_available():
126
+ low_cpu_mem_usage = False
127
+ logger.warning(
128
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
129
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
130
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
131
+ " install accelerate\n```\n."
132
+ )
133
+
134
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
135
+ raise NotImplementedError(
136
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
137
+ " `low_cpu_mem_usage=False`."
138
+ )
139
+
140
+ user_agent = {
141
+ "diffusers": diffusers.__version__,
142
+ "file_type": "model",
143
+ "framework": "pytorch",
144
+ }
145
+
146
+ # 1. Determine model class from config
147
+
148
+ load_config_kwargs = {
149
+ "cache_dir": cache_dir,
150
+ "force_download": force_download,
151
+ "proxies": proxies,
152
+ "token": token,
153
+ "local_files_only": local_files_only,
154
+ "revision": revision,
155
+ }
156
+
157
+ config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)
158
+
159
+ orig_class_name = config["_class_name"]
160
+
161
+ if orig_class_name in LOCAL_CLASS_MAPPING:
162
+ model_cls = LOCAL_CLASS_MAPPING[orig_class_name]
163
+
164
+ else:
165
+ load_config_kwargs.update({"subfolder": subfolder})
166
+
167
+ from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
168
+
169
+ model_cls, _ = get_class_obj_and_candidates(
170
+ library_name="diffusers",
171
+ class_name=orig_class_name,
172
+ importable_classes=ALL_IMPORTABLE_CLASSES,
173
+ pipelines=None,
174
+ is_pipeline_module=False,
175
+ )
176
+
177
+ if model_cls is None:
178
+ raise ValueError(f"Can't find a model linked to {orig_class_name}.")
179
+
180
+ # 2. Get model file
181
+
182
+ model_file = None
183
+
184
+ if use_safetensors:
185
+ try:
186
+ model_file = _get_model_file(
187
+ pretrained_model_name_or_path,
188
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
189
+ cache_dir=cache_dir,
190
+ force_download=force_download,
191
+ proxies=proxies,
192
+ local_files_only=local_files_only,
193
+ token=token,
194
+ revision=revision,
195
+ subfolder=subfolder,
196
+ user_agent=user_agent,
197
+ )
198
+
199
+ except IOError as e:
200
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
201
+ if not allow_pickle:
202
+ raise
203
+ logger.warning(
204
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
205
+ )
206
+
207
+ if model_file is None:
208
+ model_file = _get_model_file(
209
+ pretrained_model_name_or_path,
210
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
211
+ cache_dir=cache_dir,
212
+ force_download=force_download,
213
+ proxies=proxies,
214
+ local_files_only=local_files_only,
215
+ token=token,
216
+ revision=revision,
217
+ subfolder=subfolder,
218
+ user_agent=user_agent,
219
+ )
220
+
221
+ # 3. Initialize model
222
+
223
+ base_module = getattr(self, target_module_name)
224
+
225
+ torch_dtype = base_module.dtype
226
+ device = base_module.device
227
+ dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
228
+
229
+ init_contexts = [no_init_weights()]
230
+
231
+ if low_cpu_mem_usage:
232
+ init_contexts.append(accelerate.init_empty_weights())
233
+
234
+ with ContextManagers(init_contexts):
235
+ piflow_module = model_cls.from_config(config).eval()
236
+
237
+ torch.set_default_dtype(dtype_orig)
238
+
239
+ # 4. Load model weights
240
+
241
+ if model_file is not None:
242
+ base_state_dict = base_module.state_dict()
243
+ lora_state_dict = dict()
244
+
245
+ adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
246
+ for k in adapter_state_dict.keys():
247
+ adapter_state_dict[k] = adapter_state_dict[k].to(dtype=torch_dtype, device=device)
248
+ if "lora" in k:
249
+ lora_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
250
+ else:
251
+ base_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
252
+
253
+ if len(lora_state_dict) == 0:
254
+ adapter_name = None
255
+
256
+ else:
257
+ if adapter_name is None:
258
+ adapter_name = f"{target_module_name}_piflow"
259
+
260
+ piflow_module.load_state_dict(
261
+ base_state_dict, strict=False, assign=True)
262
+ piflow_module.load_lora_adapter(
263
+ lora_state_dict, prefix=None, adapter_name=adapter_name)
264
+
265
+ setattr(self, target_module_name, piflow_module)
266
+
267
+ else:
268
+ adapter_name = None
269
+
270
+ if adapter_name is None:
271
+ logger.warning(
272
+ f"No LoRA weights were found in {pretrained_model_name_or_path}."
273
+ )
274
+
275
+ return adapter_name
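
For context, a usage sketch mirroring `app.py` (not part of this file): `load_piflow_adapter` reads the checkpoint config, swaps the target module for the corresponding GM transformer, loads the non-LoRA weights into it, and attaches any LoRA weights as an adapter, returning the adapter name:

```python
import torch
from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline

pipe = PiQwenImagePipeline.from_pretrained('Qwen/Qwen-Image', torch_dtype=torch.bfloat16)
adapter_name = pipe.load_piflow_adapter(
    'Lakonik/pi-Qwen-Image',             # checkpoint repo used by the demo
    subfolder='gmqwen_k8_piid_4step',
    target_module_name='transformer')    # pipe.transformer becomes the GM Qwen transformer
print(adapter_name)                      # 'transformer_piflow' when LoRA weights are present
```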
lakonlab/pipelines/piflux_pipeline.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from typing import Dict, List, Optional, Union, Any, Callable
7
+ from functools import partial
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ CLIPVisionModelWithProjection,
13
+ T5EncoderModel,
14
+ T5TokenizerFast,
15
+ )
16
+ from diffusers.utils import is_torch_xla_available
17
+ from diffusers.image_processor import PipelineImageInput
18
+ from diffusers.models import AutoencoderKL, FluxTransformer2DModel
19
+ from diffusers.pipelines.flux.pipeline_flux import (
20
+ FluxPipeline, calculate_shift, FluxPipelineOutput, retrieve_timesteps)
21
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
22
+ from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
23
+ from .piflow_loader import PiFlowLoaderMixin
24
+
25
+
26
+ if is_torch_xla_available():
27
+ import torch_xla.core.xla_model as xm
28
+
29
+ XLA_AVAILABLE = True
30
+ else:
31
+ XLA_AVAILABLE = False
32
+
33
+
34
+ def retrieve_raw_timesteps(
35
+ num_inference_steps: int,
36
+ total_substeps: int,
37
+ final_step_size_scale: float
38
+ ):
39
+ r"""
40
+ Retrieve the raw times and the number of substeps for each inference step.
41
+
42
+ Args:
43
+ num_inference_steps (`int`):
44
+ Number of inference steps.
45
+ total_substeps (`int`):
46
+ Total number of substeps (e.g., 128).
47
+ final_step_size_scale (`float`):
48
+ Scale for the final step size (e.g., 0.5).
49
+
50
+ Returns:
51
+ `Tuple[List[float], List[int], int]`: A tuple where the first element is the raw timestep schedule, the second
52
+ element is the number of substeps for each inference step, and the third element is the rounded total number of
53
+ substeps.
54
+ """
55
+ base_segment_size = 1 / (num_inference_steps - 1 + final_step_size_scale)
56
+ raw_timesteps = []
57
+ num_inference_substeps = []
58
+ _raw_t = 1.0
59
+ for i in range(num_inference_steps):
60
+ if i < num_inference_steps - 1:
61
+ segment_size = base_segment_size
62
+ else:
63
+ segment_size = base_segment_size * final_step_size_scale
64
+ _num_inference_substeps = max(round(segment_size * total_substeps), 1)
65
+ num_inference_substeps.append(_num_inference_substeps)
66
+ raw_timesteps.extend(np.linspace(
67
+ _raw_t, _raw_t - segment_size, _num_inference_substeps, endpoint=False).clip(min=0.0).tolist())
68
+ _raw_t = _raw_t - segment_size
69
+ total_substeps = sum(num_inference_substeps)
70
+ return raw_timesteps, num_inference_substeps, total_substeps
71
+
72
+
73
+ class PiFluxPipeline(FluxPipeline, PiFlowLoaderMixin):
74
+ r"""
75
+ The policy-based Flux pipeline for text-to-image generation.
76
+
77
+ Reference: Todo: add paper link
78
+
79
+ Args:
80
+ transformer ([`FluxTransformer2DModel`]):
81
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
82
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
83
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
84
+ vae ([`AutoencoderKL`]):
85
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
86
+ text_encoder ([`CLIPTextModel`]):
87
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
88
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
89
+ text_encoder_2 ([`T5EncoderModel`]):
90
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
91
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
92
+ tokenizer (`CLIPTokenizer`):
93
+ Tokenizer of class
94
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
95
+ tokenizer_2 (`T5TokenizerFast`):
96
+ Second Tokenizer of class
97
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
98
+ policy_type (`str`, *optional*, defaults to `"GMFlow"`):
99
+ The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
100
+ policy_kwargs (`Dict`, *optional*):
101
+ Additional keyword arguments to pass to the policy class.
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ scheduler: FlowMatchEulerDiscreteScheduler,
107
+ vae: AutoencoderKL,
108
+ text_encoder: CLIPTextModel,
109
+ tokenizer: CLIPTokenizer,
110
+ text_encoder_2: T5EncoderModel,
111
+ tokenizer_2: T5TokenizerFast,
112
+ transformer: FluxTransformer2DModel,
113
+ image_encoder: CLIPVisionModelWithProjection = None,
114
+ feature_extractor: CLIPImageProcessor = None,
115
+ policy_type: str = 'GMFlow',
116
+ policy_kwargs: Optional[Dict[str, Any]] = None,
117
+ ):
118
+ super().__init__(
119
+ scheduler,
120
+ vae,
121
+ text_encoder,
122
+ tokenizer,
123
+ text_encoder_2,
124
+ tokenizer_2,
125
+ transformer,
126
+ image_encoder,
127
+ feature_extractor
128
+ )
129
+ assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
130
+ self.policy_type = policy_type
131
+ self.policy_class = partial(
132
+ POLICY_CLASSES[policy_type], **policy_kwargs
133
+ ) if policy_kwargs else POLICY_CLASSES[policy_type]
134
+
135
+ def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
136
+ c = num_channels_latents * patch_size * patch_size
137
+ h = (int(height) // (self.vae_scale_factor * patch_size))
138
+ w = (int(width) // (self.vae_scale_factor * patch_size))
139
+ bs = gm['means'].size(0)
140
+ k = self.transformer.num_gaussians
141
+ scale = patch_size // gm_patch_size
142
+ gm['means'] = gm['means'].reshape(
143
+ bs, h, w, k, c // (scale * scale), scale, scale
144
+ ).permute(
145
+ 0, 3, 4, 1, 5, 2, 6
146
+ ).reshape(
147
+ bs, k, c // (scale * scale), h * scale, w * scale)
148
+ gm['logweights'] = gm['logweights'].reshape(
149
+ bs, h, w, k, 1, scale, scale
150
+ ).permute(
151
+ 0, 3, 4, 1, 5, 2, 6
152
+ ).reshape(
153
+ bs, k, 1, h * scale, w * scale)
154
+ gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
155
+ return gm
156
+
157
+ @staticmethod
158
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size=1, target_patch_size=2):
159
+ scale = target_patch_size // patch_size
160
+ latents = latents.view(
161
+ batch_size,
162
+ num_channels_latents * patch_size * patch_size,
163
+ height // target_patch_size, scale, width // target_patch_size, scale)
164
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
165
+ latents = latents.reshape(
166
+ batch_size,
167
+ (height // target_patch_size) * (width // target_patch_size),
168
+ num_channels_latents * target_patch_size * target_patch_size)
169
+
170
+ return latents
171
+
172
+ @staticmethod
173
+ def _unpack_latents(latents, height, width, vae_scale_factor, patch_size=2, target_patch_size=1):
174
+ batch_size, num_patches, channels = latents.shape
175
+ scale = patch_size // target_patch_size
176
+
177
+ # VAE applies 8x compression on images but we must also account for packing which requires
178
+ # latent height and width to be divisible by 2.
179
+ height = (int(height) // (vae_scale_factor * patch_size))
180
+ width = (int(width) // (vae_scale_factor * patch_size))
181
+
182
+ latents = latents.view(
183
+ batch_size, height, width, channels // (scale * scale), scale, scale)
184
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
185
+
186
+ latents = latents.reshape(batch_size, channels // (scale * scale), height * scale, width * scale)
187
+
188
+ return latents
189
+
190
+ @torch.inference_mode()
191
+ def __call__(
192
+ self,
193
+ prompt: Union[str, List[str]] = None,
194
+ prompt_2: Optional[Union[str, List[str]]] = None,
195
+ height: Optional[int] = None,
196
+ width: Optional[int] = None,
197
+ num_inference_steps: int = 4,
198
+ total_substeps: int = 128,
199
+ final_step_size_scale: float = 0.5,
200
+ temperature: Union[float, str] = 'auto',
201
+ guidance_scale: float = 3.5,
202
+ num_images_per_prompt: Optional[int] = 1,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ latents: Optional[torch.FloatTensor] = None,
205
+ prompt_embeds: Optional[torch.FloatTensor] = None,
206
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
207
+ ip_adapter_image: Optional[PipelineImageInput] = None,
208
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
209
+ output_type: Optional[str] = "pil",
210
+ return_dict: bool = True,
211
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
212
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
213
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
214
+ max_sequence_length: int = 512,
215
+ ):
216
+ r"""
217
+ Function invoked when calling the pipeline for generation.
218
+
219
+ Args:
220
+ prompt (`str` or `List[str]`, *optional*):
221
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
222
+ instead.
223
+ prompt_2 (`str` or `List[str]`, *optional*):
224
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
225
+ will be used instead.
226
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
227
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
228
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
229
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
230
+ num_inference_steps (`int`, *optional*, defaults to 4):
231
+ The number of denoising steps.
232
+ total_substeps (`int`, *optional*, defaults to 128):
233
+ The total number of substeps for policy-based flow integration.
234
+ final_step_size_scale (`float`, *optional*, defaults to 0.5):
235
+ The scale for the final step size.
236
+ temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
237
+ The temperature parameter for the flow policy.
238
+ guidance_scale (`float`, *optional*, defaults to 3.5):
239
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
240
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
241
+
242
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
243
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
244
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
245
+ The number of images to generate per prompt.
246
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
247
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
248
+ to make generation deterministic.
249
+ latents (`torch.FloatTensor`, *optional*):
250
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
251
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
252
+ tensor will be generated by sampling using the supplied random `generator`.
253
+ prompt_embeds (`torch.FloatTensor`, *optional*):
254
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
255
+ provided, text embeddings will be generated from `prompt` input argument.
256
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
257
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
258
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
259
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
260
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
261
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
262
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
263
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
264
+ output_type (`str`, *optional*, defaults to `"pil"`):
265
+ The output format of the generated image. Choose between
266
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
267
+ return_dict (`bool`, *optional*, defaults to `True`):
268
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
269
+ joint_attention_kwargs (`dict`, *optional*):
270
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
271
+ `self.processor` in
272
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
273
+ callback_on_step_end (`Callable`, *optional*):
274
+ A function that is called at the end of each denoising step during inference. The function is called
275
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
276
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
277
+ `callback_on_step_end_tensor_inputs`.
278
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
279
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
280
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
281
+ `._callback_tensor_inputs` attribute of your pipeline class.
282
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
283
+
284
+ Returns:
285
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
286
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
287
+ images.
288
+ """
289
+
290
+ height = height or self.default_sample_size * self.vae_scale_factor
291
+ width = width or self.default_sample_size * self.vae_scale_factor
292
+
293
+ # 1. Check inputs. Raise error if not correct
294
+ self.check_inputs(
295
+ prompt,
296
+ prompt_2,
297
+ height,
298
+ width,
299
+ prompt_embeds=prompt_embeds,
300
+ pooled_prompt_embeds=pooled_prompt_embeds,
301
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
302
+ max_sequence_length=max_sequence_length,
303
+ )
304
+
305
+ self._guidance_scale = guidance_scale
306
+ self._joint_attention_kwargs = joint_attention_kwargs
307
+ self._current_timestep = None
308
+ self._interrupt = False
309
+
310
+ # 2. Define call parameters
311
+ if prompt is not None and isinstance(prompt, str):
312
+ batch_size = 1
313
+ elif prompt is not None and isinstance(prompt, list):
314
+ batch_size = len(prompt)
315
+ else:
316
+ batch_size = prompt_embeds.shape[0]
317
+
318
+ device = self._execution_device
319
+
320
+ # 3. Prepare prompt embeddings
321
+ lora_scale = (
322
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
323
+ )
324
+ (
325
+ prompt_embeds,
326
+ pooled_prompt_embeds,
327
+ text_ids,
328
+ ) = self.encode_prompt(
329
+ prompt=prompt,
330
+ prompt_2=prompt_2,
331
+ prompt_embeds=prompt_embeds,
332
+ pooled_prompt_embeds=pooled_prompt_embeds,
333
+ device=device,
334
+ num_images_per_prompt=num_images_per_prompt,
335
+ max_sequence_length=max_sequence_length,
336
+ lora_scale=lora_scale,
337
+ )
338
+
339
+ # 4. Prepare latent variables
340
+ num_channels_latents = self.transformer.config.in_channels // 4
341
+ latents, latent_image_ids = self.prepare_latents(
342
+ batch_size * num_images_per_prompt,
343
+ num_channels_latents,
344
+ height,
345
+ width,
346
+ torch.float32,
347
+ device,
348
+ generator,
349
+ latents,
350
+ )
351
+
352
+ # 5. Prepare timesteps
353
+ raw_timesteps, num_inference_substeps, total_substeps = retrieve_raw_timesteps(
354
+ num_inference_steps, total_substeps, final_step_size_scale)
355
+ image_seq_len = latents.shape[1]
356
+ mu = calculate_shift(
357
+ image_seq_len,
358
+ self.scheduler.config.get("base_image_seq_len", 256),
359
+ self.scheduler.config.get("max_image_seq_len", 4096),
360
+ self.scheduler.config.get("base_shift", 0.5),
361
+ self.scheduler.config.get("max_shift", 1.15),
362
+ )
363
+ timesteps, _ = retrieve_timesteps(
364
+ self.scheduler,
365
+ num_inference_steps,
366
+ device,
367
+ sigmas=raw_timesteps,
368
+ mu=mu,
369
+ )
370
+ assert len(timesteps) == total_substeps
371
+ self._num_timesteps = total_substeps
372
+
373
+ # handle guidance
374
+ if self.transformer.config.guidance_embeds:
375
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
376
+ guidance = guidance.expand(latents.shape[0])
377
+ else:
378
+ guidance = None
379
+
380
+ if self.joint_attention_kwargs is None:
381
+ self._joint_attention_kwargs = {}
382
+
383
+ image_embeds = None
384
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
385
+ image_embeds = self.prepare_ip_adapter_image_embeds(
386
+ ip_adapter_image,
387
+ ip_adapter_image_embeds,
388
+ device,
389
+ batch_size * num_images_per_prompt,
390
+ )
391
+
392
+ # 6. Denoising loop
393
+ self.scheduler.set_begin_index(0)
394
+ timestep_id = 0
395
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
396
+ for i in range(num_inference_steps):
397
+ if self.interrupt:
398
+ continue
399
+
400
+ t_src = timesteps[timestep_id]
401
+ sigma_t_src = t_src / self.scheduler.config.num_train_timesteps
402
+ is_final_step = i == (num_inference_steps - 1)
403
+
404
+ self._current_timestep = t_src
405
+ if image_embeds is not None:
406
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
407
+
408
+ with self.transformer.cache_context("cond"):
409
+ denoising_output = self.transformer(
410
+ hidden_states=latents.to(dtype=self.transformer.dtype),
411
+ timestep=t_src.expand(latents.shape[0]) / 1000,
412
+ guidance=guidance,
413
+ pooled_projections=pooled_prompt_embeds,
414
+ encoder_hidden_states=prompt_embeds,
415
+ txt_ids=text_ids,
416
+ img_ids=latent_image_ids,
417
+ joint_attention_kwargs=self.joint_attention_kwargs,
418
+ )
419
+
420
+ # unpack and create policy
421
+ latents = self._unpack_latents(
422
+ latents, height, width, self.vae_scale_factor, target_patch_size=1)
423
+ if self.policy_type == 'GMFlow':
424
+ denoising_output = self._unpack_gm(
425
+ denoising_output, height, width, num_channels_latents, gm_patch_size=1)
426
+ denoising_output = {k: v.to(torch.float32) for k, v in denoising_output.items()}
427
+ policy = self.policy_class(
428
+ denoising_output, latents, sigma_t_src)
429
+ if not is_final_step:
430
+ if temperature == 'auto':
431
+ temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
432
+ else:
433
+ assert isinstance(temperature, (float, int))
434
+ policy.temperature_(temperature)
435
+ elif self.policy_type == 'DX':
436
+ denoising_output = denoising_output[0]
437
+ denoising_output = self._unpack_latents(
438
+ denoising_output, height, width, self.vae_scale_factor, target_patch_size=1)
439
+ denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
440
+ denoising_output = denoising_output.to(torch.float32)
441
+ policy = self.policy_class(
442
+ denoising_output, latents, sigma_t_src)
443
+ else:
444
+ raise ValueError(f'Unknown policy type: {self.policy_type}.')
445
+
446
+ # compute the previous noisy sample x_t -> x_t-1
447
+ for _ in range(num_inference_substeps[i]):
448
+ t = timesteps[timestep_id]
449
+ sigma_t = t / self.scheduler.config.num_train_timesteps
450
+ u = policy.u(latents, sigma_t)
451
+ latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
452
+ timestep_id += 1
453
+
454
+ # repack
455
+ latents = self._pack_latents(
456
+ latents, latents.size(0), num_channels_latents,
457
+ 2 * (int(height) // (self.vae_scale_factor * 2)),
458
+ 2 * (int(width) // (self.vae_scale_factor * 2)),
459
+ patch_size=1)
460
+
461
+ if callback_on_step_end is not None:
462
+ callback_kwargs = {}
463
+ for k in callback_on_step_end_tensor_inputs:
464
+ callback_kwargs[k] = locals()[k]
465
+ callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)
466
+
467
+ latents = callback_outputs.pop("latents", latents)
468
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
469
+
470
+ progress_bar.update()
471
+
472
+ if XLA_AVAILABLE:
473
+ xm.mark_step()
474
+
475
+ self._current_timestep = None
476
+
477
+ if output_type == "latent":
478
+ image = latents
479
+ else:
480
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
481
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
482
+ image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
483
+ image = self.image_processor.postprocess(image, output_type=output_type)
484
+
485
+ # Offload all models
486
+ self.maybe_free_model_hooks()
487
+
488
+ if not return_dict:
489
+ return (image,)
490
+
491
+ return FluxPipelineOutput(images=image)
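Note on the policy interface: the denoising loop above only assumes that the object returned by `self.policy_class(denoising_output, latents, sigma_t_src)` exposes a velocity method `u(x_t, sigma_t)` and, for GMFlow, an in-place `temperature_()` setter. Below is a minimal sketch of that interface; the class is illustrative only and is not one of the actual POLICY_CLASSES in lakonlab.models.diffusions.piflow_policies.

import torch

class ConstantVelocityPolicy:
    """Toy policy that replays the velocity predicted at the anchor state.

    The real GMFlow/DX policies recompute u(x_t, sigma_t) analytically from
    the network output; this sketch only mirrors the call signature used by
    the denoising loop.
    """

    def __init__(self, denoising_output: torch.Tensor, x_src: torch.Tensor, sigma_src: float):
        self.u0 = denoising_output  # velocity predicted at (x_src, sigma_src)

    def temperature_(self, temperature: float):
        # GMFlow policies re-temper the mixture in place; a constant velocity
        # has nothing to re-temper, so this is a no-op here.
        return self

    def u(self, x_t: torch.Tensor, sigma_t: float) -> torch.Tensor:
        # Velocity consumed by scheduler.step() at each substep.
        return self.u0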
lakonlab/pipelines/piqwen_pipeline.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright (c) 2025 Hansheng Chen
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from typing import Dict, List, Optional, Union, Any, Callable
7
+ from functools import partial
8
+ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
9
+ from diffusers.utils import is_torch_xla_available
10
+ from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
11
+ from diffusers.pipelines.qwenimage.pipeline_qwenimage import (
12
+ QwenImagePipeline, calculate_shift, retrieve_timesteps, QwenImagePipelineOutput)
13
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
14
+ from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
15
+ from .piflow_loader import PiFlowLoaderMixin
16
+
17
+
18
+ if is_torch_xla_available():
19
+ import torch_xla.core.xla_model as xm
20
+
21
+ XLA_AVAILABLE = True
22
+ else:
23
+ XLA_AVAILABLE = False
24
+
25
+
26
+ def retrieve_raw_timesteps(
27
+ num_inference_steps: int,
28
+ total_substeps: int,
29
+ final_step_size_scale: float
30
+ ):
31
+ r"""
32
+ Retrieve the raw times and the number of substeps for each inference step.
33
+
34
+ Args:
35
+ num_inference_steps (`int`):
36
+ Number of inference steps.
37
+ total_substeps (`int`):
38
+ Total number of substeps (e.g., 128).
39
+ final_step_size_scale (`float`):
40
+ Scale for the final step size (e.g., 0.5).
41
+
42
+ Returns:
43
+ `Tuple[List[float], List[int], int]`: A tuple where the first element is the raw timestep schedule, the second
44
+ element is the number of substeps for each inference step, and the third element is the rounded total number of
45
+ substeps.
46
+ """
47
+ base_segment_size = 1 / (num_inference_steps - 1 + final_step_size_scale)
48
+ raw_timesteps = []
49
+ num_inference_substeps = []
50
+ _raw_t = 1.0
51
+ for i in range(num_inference_steps):
52
+ if i < num_inference_steps - 1:
53
+ segment_size = base_segment_size
54
+ else:
55
+ segment_size = base_segment_size * final_step_size_scale
56
+ _num_inference_substeps = max(round(segment_size * total_substeps), 1)
57
+ num_inference_substeps.append(_num_inference_substeps)
58
+ raw_timesteps.extend(np.linspace(
59
+ _raw_t, _raw_t - segment_size, _num_inference_substeps, endpoint=False).clip(min=0.0).tolist())
60
+ _raw_t = _raw_t - segment_size
61
+ total_substeps = sum(num_inference_substeps)
62
+ return raw_timesteps, num_inference_substeps, total_substeps
63
+
64
+
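For intuition, the schedule produced by the defaults used below (`num_inference_steps=4`, `total_substeps=128`, `final_step_size_scale=0.5`) works out as follows; the numbers are derived from the function above, not quoted from elsewhere:

raw_t, substeps, total = retrieve_raw_timesteps(
    num_inference_steps=4, total_substeps=128, final_step_size_scale=0.5)
# base_segment_size = 1 / (4 - 1 + 0.5) ≈ 0.2857, so the three full steps span
# [1.0, 0.714), [0.714, 0.429), [0.429, 0.143), and the half-size final step
# spans [0.143, 0.0).
print(substeps)              # [37, 37, 37, 18]
print(total)                 # 129 -- the requested 128 is a target before rounding
print(raw_t[0], raw_t[-1])   # 1.0  0.0079...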
65
+ class PiQwenImagePipeline(QwenImagePipeline, PiFlowLoaderMixin):
66
+ r"""
67
+ The policy-based QwenImage pipeline for text-to-image generation.
68
+
69
+ Reference: Todo: add paper link
70
+
71
+ Args:
72
+ transformer ([`QwenImageTransformer2DModel`]):
73
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
74
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
75
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
76
+ vae ([`AutoencoderKLQwenImage`]):
77
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
78
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
79
+ The [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) multimodal
80
+ language model, used here as the text encoder.
81
+ tokenizer (`Qwen2Tokenizer`):
82
+ Tokenizer of class
83
+ [`Qwen2Tokenizer`], the tokenizer paired with the Qwen2.5-VL text encoder.
84
+ policy_type (`str`, *optional*, defaults to `"GMFlow"`):
85
+ The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
86
+ policy_kwargs (`Dict`, *optional*):
87
+ Additional keyword arguments to pass to the policy class.
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ scheduler: FlowMatchEulerDiscreteScheduler,
93
+ vae: AutoencoderKLQwenImage,
94
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
95
+ tokenizer: Qwen2Tokenizer,
96
+ transformer: QwenImageTransformer2DModel,
97
+ policy_type: str = 'GMFlow',
98
+ policy_kwargs: Optional[Dict[str, Any]] = None,
99
+ ):
100
+ super().__init__(
101
+ scheduler,
102
+ vae,
103
+ text_encoder,
104
+ tokenizer,
105
+ transformer,
106
+ )
107
+ assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
108
+ self.policy_type = policy_type
109
+ self.policy_class = partial(
110
+ POLICY_CLASSES[policy_type], **policy_kwargs
111
+ ) if policy_kwargs else POLICY_CLASSES[policy_type]
112
+
113
+ def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
114
+ c = num_channels_latents * patch_size * patch_size
115
+ h = (int(height) // (self.vae_scale_factor * patch_size))
116
+ w = (int(width) // (self.vae_scale_factor * patch_size))
117
+ bs = gm['means'].size(0)
118
+ k = self.transformer.num_gaussians
119
+ scale = patch_size // gm_patch_size
120
+ gm['means'] = gm['means'].reshape(
121
+ bs, h, w, k, c // (scale * scale), scale, scale
122
+ ).permute(
123
+ 0, 3, 4, 1, 5, 2, 6
124
+ ).reshape(
125
+ bs, k, c // (scale * scale), h * scale, w * scale)
126
+ gm['logweights'] = gm['logweights'].reshape(
127
+ bs, h, w, k, 1, scale, scale
128
+ ).permute(
129
+ 0, 3, 4, 1, 5, 2, 6
130
+ ).reshape(
131
+ bs, k, 1, h * scale, w * scale)
132
+ gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
133
+ return gm
134
+
135
+ @staticmethod
136
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size=1, target_patch_size=2):
137
+ scale = target_patch_size // patch_size
138
+ latents = latents.view(
139
+ batch_size,
140
+ num_channels_latents * patch_size * patch_size,
141
+ height // target_patch_size, scale, width // target_patch_size, scale)
142
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
143
+ latents = latents.reshape(
144
+ batch_size,
145
+ (height // target_patch_size) * (width // target_patch_size),
146
+ num_channels_latents * target_patch_size * target_patch_size)
147
+
148
+ return latents
149
+
150
+ @staticmethod
151
+ def _unpack_latents(latents, height, width, vae_scale_factor, patch_size=2, target_patch_size=1):
152
+ batch_size, num_patches, channels = latents.shape
153
+ scale = patch_size // target_patch_size
154
+
155
+ # VAE applies 8x compression on images but we must also account for packing which requires
156
+ # latent height and width to be divisible by 2.
157
+ height = (int(height) // (vae_scale_factor * patch_size))
158
+ width = (int(width) // (vae_scale_factor * patch_size))
159
+
160
+ latents = latents.view(
161
+ batch_size, height, width, channels // (scale * scale), scale, scale)
162
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
163
+
164
+ latents = latents.reshape(batch_size, channels // (scale * scale), height * scale, width * scale)
165
+
166
+ return latents
167
+
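The two helpers above are plain staticmethods, so the packing convention can be sanity-checked without instantiating the pipeline. A minimal round-trip sketch, assuming a 1024x1024 image, a VAE scale factor of 8 and 16 latent channels (values chosen for illustration):

import torch
from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline

B, C = 1, 16                                             # batch size, latent channels
H = W = 1024                                             # output image size in pixels
packed = torch.randn(B, (H // 16) * (W // 16), C * 4)    # (1, 4096, 64): 2x2 patch tokens

# packed patch sequence -> dense latent grid at patch size 1
dense = PiQwenImagePipeline._unpack_latents(
    packed, H, W, vae_scale_factor=8, patch_size=2, target_patch_size=1)
assert dense.shape == (B, C, 128, 128)

# dense grid -> back to the packed 2x2-patch sequence
repacked = PiQwenImagePipeline._pack_latents(
    dense, B, C, 128, 128, patch_size=1, target_patch_size=2)
assert torch.equal(repacked, packed)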
168
+ @torch.inference_mode()
169
+ def __call__(
170
+ self,
171
+ prompt: Union[str, List[str]] = None,
172
+ height: Optional[int] = None,
173
+ width: Optional[int] = None,
174
+ num_inference_steps: int = 4,
175
+ total_substeps: int = 128,
176
+ final_step_size_scale: float = 0.5,
177
+ temperature: Union[float, str] = 'auto',
178
+ num_images_per_prompt: Optional[int] = 1,
179
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
180
+ latents: Optional[torch.FloatTensor] = None,
181
+ prompt_embeds: Optional[torch.FloatTensor] = None,
182
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
183
+ output_type: Optional[str] = "pil",
184
+ return_dict: bool = True,
185
+ attention_kwargs: Optional[Dict[str, Any]] = None,
186
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
187
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
188
+ max_sequence_length: int = 512,
189
+ ):
190
+ r"""
191
+ Function invoked when calling the pipeline for generation.
192
+
193
+ Args:
194
+ prompt (`str` or `List[str]`, *optional*):
195
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
196
+ instead.
197
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
198
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
199
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
200
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
201
+ num_inference_steps (`int`, *optional*, defaults to 4):
202
+ The number of denoising steps.
203
+ total_substeps (`int`, *optional*, defaults to 128):
204
+ The total number of substeps for policy-based flow integration.
205
+ final_step_size_scale (`float`, *optional*, defaults to 0.5):
206
+ The scale for the final step size.
207
+ temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
208
+ The temperature parameter for the flow policy.
209
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
210
+ The number of images to generate per prompt.
211
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
212
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
213
+ to make generation deterministic.
214
+ latents (`torch.FloatTensor`, *optional*):
215
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
216
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
217
+ tensor will be generated by sampling using the supplied random `generator`.
218
+ prompt_embeds (`torch.FloatTensor`, *optional*):
219
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
220
+ provided, text embeddings will be generated from `prompt` input argument.
221
+ output_type (`str`, *optional*, defaults to `"pil"`):
222
+ The output format of the generate image. Choose between
223
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
224
+ return_dict (`bool`, *optional*, defaults to `True`):
225
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
226
+ attention_kwargs (`dict`, *optional*):
227
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
228
+ `self.processor` in
229
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
230
+ callback_on_step_end (`Callable`, *optional*):
231
+ A function that is called at the end of each denoising step during inference. The function is called
232
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
233
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
234
+ `callback_on_step_end_tensor_inputs`.
235
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
236
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
237
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
238
+ `._callback_tensor_inputs` attribute of your pipeline class.
239
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
240
+
241
+ Returns:
242
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict`
243
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
244
+ images.
245
+ """
246
+
247
+ height = height or self.default_sample_size * self.vae_scale_factor
248
+ width = width or self.default_sample_size * self.vae_scale_factor
249
+
250
+ # 1. Check inputs. Raise error if not correct
251
+ self.check_inputs(
252
+ prompt,
253
+ height,
254
+ width,
255
+ prompt_embeds=prompt_embeds,
256
+ prompt_embeds_mask=prompt_embeds_mask,
257
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
258
+ max_sequence_length=max_sequence_length,
259
+ )
260
+
261
+ self._attention_kwargs = attention_kwargs
262
+ self._current_timestep = None
263
+ self._interrupt = False
264
+
265
+ # 2. Define call parameters
266
+ if prompt is not None and isinstance(prompt, str):
267
+ batch_size = 1
268
+ elif prompt is not None and isinstance(prompt, list):
269
+ batch_size = len(prompt)
270
+ else:
271
+ batch_size = prompt_embeds.shape[0]
272
+
273
+ device = self._execution_device
274
+
275
+ # 3. Prepare prompt embeddings
276
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
277
+ prompt=prompt,
278
+ prompt_embeds=prompt_embeds,
279
+ prompt_embeds_mask=prompt_embeds_mask,
280
+ device=device,
281
+ num_images_per_prompt=num_images_per_prompt,
282
+ max_sequence_length=max_sequence_length,
283
+ )
284
+
285
+ # 4. Prepare latent variables
286
+ num_channels_latents = self.transformer.config.in_channels // 4
287
+ latents = self.prepare_latents(
288
+ batch_size * num_images_per_prompt,
289
+ num_channels_latents,
290
+ height,
291
+ width,
292
+ torch.float32,
293
+ device,
294
+ generator,
295
+ latents,
296
+ )
297
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
298
+
299
+ # 5. Prepare timesteps
300
+ raw_timesteps, num_inference_substeps, total_substeps = retrieve_raw_timesteps(
301
+ num_inference_steps, total_substeps, final_step_size_scale)
302
+ image_seq_len = latents.shape[1]
303
+ mu = calculate_shift(
304
+ image_seq_len,
305
+ self.scheduler.config.get("base_image_seq_len", 256),
306
+ self.scheduler.config.get("max_image_seq_len", 4096),
307
+ self.scheduler.config.get("base_shift", 0.5),
308
+ self.scheduler.config.get("max_shift", 1.15),
309
+ )
310
+ timesteps, _ = retrieve_timesteps(
311
+ self.scheduler,
312
+ num_inference_steps,
313
+ device,
314
+ sigmas=raw_timesteps,
315
+ mu=mu,
316
+ )
317
+ assert len(timesteps) == total_substeps
318
+ self._num_timesteps = total_substeps
319
+
320
+ if self.attention_kwargs is None:
321
+ self._attention_kwargs = {}
322
+
323
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
324
+
325
+ # 6. Denoising loop
326
+ self.scheduler.set_begin_index(0)
327
+ timestep_id = 0
328
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
329
+ for i in range(num_inference_steps):
330
+ if self.interrupt:
331
+ continue
332
+
333
+ t_src = timesteps[timestep_id]
334
+ sigma_t_src = t_src / self.scheduler.config.num_train_timesteps
335
+ is_final_step = i == (num_inference_steps - 1)
336
+
337
+ self._current_timestep = t_src
338
+
339
+ with self.transformer.cache_context("cond"):
340
+ denoising_output = self.transformer(
341
+ hidden_states=latents.to(dtype=self.transformer.dtype),
342
+ timestep=t_src.expand(latents.shape[0]) / 1000,
343
+ encoder_hidden_states_mask=prompt_embeds_mask,
344
+ encoder_hidden_states=prompt_embeds,
345
+ img_shapes=img_shapes,
346
+ txt_seq_lens=txt_seq_lens,
347
+ attention_kwargs=self.attention_kwargs,
348
+ )
349
+
350
+ # unpack and create policy
351
+ latents = self._unpack_latents(
352
+ latents, height, width, self.vae_scale_factor, target_patch_size=1)
353
+ if self.policy_type == 'GMFlow':
354
+ denoising_output = self._unpack_gm(
355
+ denoising_output, height, width, num_channels_latents, gm_patch_size=1)
356
+ denoising_output = {k: v.to(torch.float32) for k, v in denoising_output.items()}
357
+ policy = self.policy_class(
358
+ denoising_output, latents, sigma_t_src)
359
+ if not is_final_step:
360
+ if temperature == 'auto':
361
+ temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
362
+ else:
363
+ assert isinstance(temperature, (float, int))
364
+ policy.temperature_(temperature)
365
+ elif self.policy_type == 'DX':
366
+ denoising_output = denoising_output[0]
367
+ denoising_output = self._unpack_latents(
368
+ denoising_output, height, width, self.vae_scale_factor, target_patch_size=1)
369
+ denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
370
+ denoising_output = denoising_output.to(torch.float32)
371
+ policy = self.policy_class(
372
+ denoising_output, latents, sigma_t_src)
373
+ else:
374
+ raise ValueError(f'Unknown policy type: {self.policy_type}.')
375
+
376
+ # compute the previous noisy sample x_t -> x_t-1
377
+ for _ in range(num_inference_substeps[i]):
378
+ t = timesteps[timestep_id]
379
+ sigma_t = t / self.scheduler.config.num_train_timesteps
380
+ u = policy.u(latents, sigma_t)
381
+ latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
382
+ timestep_id += 1
383
+
384
+ # repack
385
+ latents = self._pack_latents(
386
+ latents, latents.size(0), num_channels_latents,
387
+ 2 * (int(height) // (self.vae_scale_factor * 2)),
388
+ 2 * (int(width) // (self.vae_scale_factor * 2)),
389
+ patch_size=1)
390
+
391
+ if callback_on_step_end is not None:
392
+ callback_kwargs = {}
393
+ for k in callback_on_step_end_tensor_inputs:
394
+ callback_kwargs[k] = locals()[k]
395
+ callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)
396
+
397
+ latents = callback_outputs.pop("latents", latents)
398
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
399
+
400
+ progress_bar.update()
401
+
402
+ if XLA_AVAILABLE:
403
+ xm.mark_step()
404
+
405
+ self._current_timestep = None
406
+
407
+ if output_type == "latent":
408
+ image = latents
409
+ else:
410
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)[:, :, None]
411
+ latents_mean = (
412
+ torch.tensor(self.vae.config.latents_mean)
413
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
414
+ .to(latents.device, latents.dtype)
415
+ )
416
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
417
+ latents.device, latents.dtype
418
+ )
419
+ latents = latents * latents_std + latents_mean
420
+ image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0][:, :, 0]
421
+ image = self.image_processor.postprocess(image, output_type=output_type)
422
+
423
+ # Offload all models
424
+ self.maybe_free_model_hooks()
425
+
426
+ if not return_dict:
427
+ return (image,)
428
+
429
+ return QwenImagePipelineOutput(images=image)
lakonlab/ui/__init__.py ADDED
File without changes
lakonlab/ui/gradio/__init__.py ADDED
File without changes
lakonlab/ui/gradio/create_text_to_img.py ADDED
@@ -0,0 +1,53 @@
1
+ import gradio as gr
2
+ from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts
3
+
4
+
5
+ def create_interface_text_to_img(
6
+ api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
7
+ height=768, width=1360, hw_slider_step=16,
8
+ guidance_scale=None, temperature=None, api_name='text_to_img',
9
+ create_negative_prompt=False, args=['last_seed', 'prompt', 'width', 'height', 'steps', 'guidance_scale']):
10
+ var_dict = dict()
11
+ with gr.Blocks(analytics_enabled=False) as interface:
12
+ var_dict['output_image'] = gr.Image(
13
+ type='pil', image_mode='RGB', label='Output image', interactive=False, elem_classes=['vh-img'])
14
+ create_prompt_opts(var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt)
15
+ with gr.Column(variant='compact', elem_classes=['custom-spacing']):
16
+ with gr.Row(variant='compact', elem_classes=['force-hide-container']):
17
+ var_dict['width'] = gr.Slider(
18
+ label='Width', minimum=64, maximum=2048, step=hw_slider_step, value=width,
19
+ elem_classes=['force-hide-container'])
20
+ var_dict['switch_hw'] = gr.Button('\U000021C6', elem_classes=['tool'])
21
+ var_dict['height'] = gr.Slider(
22
+ label='Height', minimum=64, maximum=2048, step=hw_slider_step, value=height,
23
+ elem_classes=['force-hide-container'])
24
+ var_dict['switch_hw'].click(
25
+ fn=lambda w, h: (h, w),
26
+ inputs=[var_dict['width'], var_dict['height']],
27
+ outputs=[var_dict['width'], var_dict['height']],
28
+ show_progress=False,
29
+ api_name=False)
30
+ create_generate_bar(var_dict, text='Generate', seed=seed)
31
+ create_base_opts(
32
+ var_dict,
33
+ steps=steps,
34
+ min_steps=min_steps,
35
+ max_steps=max_steps,
36
+ steps_slider_step=steps_slider_step,
37
+ guidance_scale=guidance_scale,
38
+ temperature=temperature)
39
+
40
+ var_dict['run_btn'].click(
41
+ fn=set_seed,
42
+ inputs=var_dict['seed'],
43
+ outputs=var_dict['last_seed'],
44
+ show_progress=False,
45
+ api_name=False
46
+ ).success(
47
+ fn=api,
48
+ inputs=[var_dict[arg] for arg in args],
49
+ outputs=var_dict['output_image'],
50
+ concurrency_id='default_group', api_name=api_name
51
+ )
52
+
53
+ return interface, var_dict
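A minimal way to wire this builder to a backend, assuming a placeholder `dummy_api` whose signature matches the default `args` list (`last_seed, prompt, width, height, steps, guidance_scale`); the function below is illustrative and not part of the repository:

from PIL import Image
from lakonlab.ui.gradio.create_text_to_img import create_interface_text_to_img

def dummy_api(seed, prompt, width, height, steps, guidance_scale):
    # A real app would run a diffusion pipeline here and return its image.
    return Image.new('RGB', (int(width), int(height)), color=(200, 200, 200))

demo, widgets = create_interface_text_to_img(
    dummy_api, prompt='a corgi wearing sunglasses',
    steps=4, max_steps=8, guidance_scale=4.5)

if __name__ == '__main__':
    demo.launch()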
lakonlab/ui/gradio/shared_opts.py ADDED
@@ -0,0 +1,64 @@
1
+ import random
2
+ import gradio as gr
3
+
4
+
5
+ def create_prompt_opts(var_dict, create_negative_prompt=True, prompt='', negative_prompt=''):
6
+ var_dict['prompt'] = gr.Textbox(
7
+ prompt, label='Prompt', show_label=False, lines=2, placeholder='Prompt', container=False, interactive=True)
8
+ if create_negative_prompt:
9
+ var_dict['negative_prompt'] = gr.Textbox(
10
+ negative_prompt, label='Negative prompt', show_label=False, lines=2,
11
+ placeholder='Negative prompt', container=False, interactive=True)
12
+
13
+
14
+ def create_generate_bar(var_dict, text='Generate', variant='primary', seed=-1):
15
+ with gr.Row(equal_height=False):
16
+ var_dict['run_btn'] = gr.Button(text, variant=variant, scale=2)
17
+ var_dict['seed'] = gr.Number(
18
+ label='Seed', value=seed, min_width=100, precision=0, minimum=-1, maximum=2 ** 31,
19
+ elem_classes=['force-hide-container'])
20
+ var_dict['random_seed'] = gr.Button('\U0001f3b2\ufe0f', elem_classes=['tool'])
21
+ var_dict['reuse_seed'] = gr.Button('\u267b\ufe0f', elem_classes=['tool'])
22
+ with gr.Column(visible=False):
23
+ var_dict['last_seed'] = gr.Number(value=seed, label='Last seed')
24
+ var_dict['reuse_seed'].click(
25
+ fn=lambda x: x,
26
+ inputs=var_dict['last_seed'],
27
+ outputs=var_dict['seed'],
28
+ show_progress=False,
29
+ api_name=False)
30
+ var_dict['random_seed'].click(
31
+ fn=lambda: -1,
32
+ outputs=var_dict['seed'],
33
+ show_progress=False,
34
+ api_name=False)
35
+
36
+
37
+ def create_base_opts(var_dict,
38
+ steps=24,
39
+ min_steps=4,
40
+ max_steps=50,
41
+ steps_slider_step=1,
42
+ guidance_scale=None,
43
+ temperature=None,
44
+ render=True):
45
+ with gr.Column(variant='compact', elem_classes=['custom-spacing'], render=render) as base_opts:
46
+ with gr.Row(variant='compact', elem_classes=['force-hide-container']):
47
+ var_dict['steps'] = gr.Slider(
48
+ min_steps, max_steps, value=steps, step=steps_slider_step, label='Sampling steps',
49
+ elem_classes=['force-hide-container'])
50
+ with gr.Row(variant='compact', elem_classes=['force-hide-container']):
51
+ if guidance_scale is not None:
52
+ var_dict['guidance_scale'] = gr.Slider(
53
+ 0.0, 30.0, value=guidance_scale, step=0.5, label='Guidance scale',
54
+ elem_classes=['force-hide-container'])
55
+ if temperature is not None:
56
+ var_dict['temperature'] = gr.Slider(
57
+ 0.0, 1.0, value=temperature, step=0.01, label='Temperature',
58
+ elem_classes=['force-hide-container'])
59
+ return base_opts
60
+
61
+
62
+ def set_seed(seed):
63
+ seed = random.randint(0, 2**31) if seed == -1 else seed
64
+ return seed
lakonlab/ui/gradio/style.css ADDED
@@ -0,0 +1,59 @@
1
+ .force-hide-container {
2
+ margin: 0;
3
+ box-shadow: none;
4
+ --block-border-width: 0;
5
+ background: transparent;
6
+ padding: 0;
7
+ overflow: visible;
8
+ }
9
+
10
+ .svelte-sfqy0y {
11
+ display: flex;
12
+ flex-direction: inherit;
13
+ flex-wrap: wrap;
14
+ gap: 0;
15
+ box-shadow: none;
16
+ border: 0;
17
+ border-radius: 0;
18
+ background: transparent;
19
+ overflow-y: hidden;
20
+ }
21
+
22
+ .custom-spacing {
23
+ padding: 10px;
24
+ gap: 20px;
25
+ flex-grow: 0 !important;
26
+ }
27
+
28
+ .unequal-height {
29
+ align-items: flex-end;
30
+ }
31
+
32
+ .tool{
33
+ max-width: 40px;
34
+ min-width: 40px !important;
35
+ }
36
+
37
+ /* Center the component and allow it to use the full row width */
38
+ .vh-img {
39
+ display: grid;
40
+ justify-items: center;
41
+ }
42
+
43
+ /* Container should size to the image, but never exceed the row width */
44
+ .vh-img .image-container {
45
+ inline-size: fit-content !important; /* prefers image’s natural width */
46
+ max-inline-size: 100% !important; /* ...but clamps to available width */
47
+ margin-inline: auto;
48
+ overflow: hidden; /* avoid odd overflow on iOS */
49
+ }
50
+
51
+ /* Image scales by BOTH constraints: height cap and row width */
52
+ .vh-img .image-container img {
53
+ max-block-size: 700px !important; /* fixed max height cap */
54
+ max-inline-size: 100%; /* never wider than container */
55
+ inline-size: auto; /* keep aspect ratio */
56
+ block-size: auto;
57
+ object-fit: contain;
58
+ display: block;
59
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ numpy==1.26.4
2
+ torch==2.6.0
3
+ diffusers==0.35.1
4
+ peft==0.17.0
5
+ sentencepiece
6
+ accelerate
7
+ transformers==4.54.1
8
+ gradio==4.18.0