Add gradio app
- .gitignore +25 -0
- README.md +18 -6
- app.py +60 -0
- lakonlab/__init__.py +0 -0
- lakonlab/models/__init__.py +0 -0
- lakonlab/models/architecture/__init__.py +0 -0
- lakonlab/models/architecture/gmflow/__init__.py +0 -0
- lakonlab/models/architecture/gmflow/gm_output.py +24 -0
- lakonlab/models/architecture/gmflow/gmflux.py +225 -0
- lakonlab/models/architecture/gmflow/gmqwen.py +149 -0
- lakonlab/models/diffusions/__init__.py +0 -0
- lakonlab/models/diffusions/piflow_policies/__init__.py +8 -0
- lakonlab/models/diffusions/piflow_policies/base.py +21 -0
- lakonlab/models/diffusions/piflow_policies/dx.py +108 -0
- lakonlab/models/diffusions/piflow_policies/gmflow.py +175 -0
- lakonlab/pipelines/__init__.py +0 -0
- lakonlab/pipelines/piflow_loader.py +275 -0
- lakonlab/pipelines/piflux_pipeline.py +491 -0
- lakonlab/pipelines/piqwen_pipeline.py +429 -0
- lakonlab/ui/__init__.py +0 -0
- lakonlab/ui/gradio/__init__.py +0 -0
- lakonlab/ui/gradio/create_text_to_img.py +53 -0
- lakonlab/ui/gradio/shared_opts.py +64 -0
- lakonlab/ui/gradio/style.css +59 -0
- requirements.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1,25 @@
+/.idea/
+/work_dirs*
+.vscode/
+/tmp
+/data
+/checkpoints
+*.so
+*.patch
+__pycache__/
+*.egg-info/
+/viz*
+/submit*
+build/
+*.pyd
+/cache*
+*.stl
+*.pth
+/venv/
+.nk8s
+*.mp4
+.vs
+/exp/
+/dev/
+*.pyi
+!/data/imagenet/imagenet1000_clsidx_to_labels.txt
README.md
CHANGED
@@ -1,13 +1,25 @@
 ---
-title: Pi
-emoji:
-colorFrom:
-colorTo:
+title: Pi-Qwen Demo
+emoji: 🚀
+colorFrom: yellow
+colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 4.18.0
 app_file: app.py
 pinned: false
 license: apache-2.0
 ---
 
-
+Official demo of the paper:
+
+**pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation**
+<br>
+[Hansheng Chen](https://lakonik.github.io/)<sup>1</sup>,
+[Kai Zhang](https://kai-46.github.io/website/)<sup>2</sup>,
+[Hao Tan](https://research.adobe.com/person/hao-tan/)<sup>2</sup>,
+[Leonidas Guibas](https://geometry.stanford.edu/?member=guibas)<sup>1</sup>,
+[Gordon Wetzstein](http://web.stanford.edu/~gordonwz/)<sup>1</sup>,
+[Sai Bi](https://sai-bi.github.io/)<sup>2</sup><br>
+<sup>1</sup>Stanford University, <sup>2</sup>Adobe Research
+<br>
+[[arXiv]()] [[pi-Qwen Demo🤗](https://huggingface.co/spaces/Lakonik/pi-Qwen)] [[pi-FLUX Demo🤗](https://huggingface.co/spaces/Lakonik/pi-FLUX.1)]
app.py
ADDED
@@ -0,0 +1,60 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+import os
+import gradio as gr
+import spaces
+from diffusers import FlowMatchEulerDiscreteScheduler
+from lakonlab.ui.gradio.create_text_to_img import create_interface_text_to_img
+from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline
+
+from huggingface_hub import login
+login(token=os.getenv('HF_TOKEN'))
+
+
+DEFAULT_PROMPT = ('Photo of a coffee shop entrance featuring a chalkboard sign reading "π-Qwen Coffee 😊 $2 per cup," '
+                  'with a neon light beside it displaying "π-通义千问". Next to it hangs a poster showing a beautiful '
+                  'Chinese woman, and beneath the poster is written "e≈2.71828-18284-59045-23536-02874-71352".')
+
+
+pipe = PiQwenImagePipeline.from_pretrained(
+    'Qwen/Qwen-Image',
+    torch_dtype=torch.bfloat16)
+pipe.load_piflow_adapter(
+    'Lakonik/pi-Qwen-Image',
+    subfolder='gmqwen_k8_piid_4step',
+    target_module_name='transformer')
+pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(  # use fixed shift=3.2
+    pipe.scheduler.config, shift=3.2, shift_terminal=None, use_dynamic_shifting=False)
+pipe = pipe.to('cuda')
+
+
+@spaces.GPU
+def generate(seed, prompt, width, height, steps):
+    return pipe(
+        prompt=prompt,
+        width=width,
+        height=height,
+        num_inference_steps=steps,
+        generator=torch.Generator().manual_seed(seed),
+    ).images[0]
+
+
+with gr.Blocks(analytics_enabled=False,
+               title='pi-Qwen Demo',
+               css='lakonlab/ui/gradio/style.css'
+               ) as demo:
+
+    md_txt = '# pi-Qwen Demo\n\n' \
+             'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](). ' \
+             '**Base model:** [Qwen-Image](https://huggingface.co/Qwen/Qwen-Image). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).'
+    gr.Markdown(md_txt)
+
+    create_interface_text_to_img(
+        generate,
+        prompt=DEFAULT_PROMPT,
+        steps=4, guidance_scale=None,
+        args=['last_seed', 'prompt', 'width', 'height', 'steps'])
+demo.queue().launch()
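
For reference, the same distilled sampler can be driven without the Gradio UI. The following is a minimal sketch that reuses only the classes, checkpoint names, and scheduler settings already shown in app.py above; the prompt and the output filename are arbitrary examples:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline

pipe = PiQwenImagePipeline.from_pretrained('Qwen/Qwen-Image', torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter('Lakonik/pi-Qwen-Image',
                         subfolder='gmqwen_k8_piid_4step',
                         target_module_name='transformer')
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(  # fixed shift=3.2, as in app.py
    pipe.scheduler.config, shift=3.2, shift_terminal=None, use_dynamic_shifting=False)
pipe = pipe.to('cuda')

# 4-step sampling with a fixed seed, mirroring the generate() function above.
image = pipe(
    prompt='A watercolor painting of a lighthouse at dawn',  # arbitrary example prompt
    width=1024, height=1024,
    num_inference_steps=4,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save('sample.png')  # arbitrary output path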
lakonlab/__init__.py
ADDED
File without changes
lakonlab/models/__init__.py
ADDED
File without changes
lakonlab/models/architecture/__init__.py
ADDED
File without changes
lakonlab/models/architecture/gmflow/__init__.py
ADDED
File without changes
lakonlab/models/architecture/gmflow/gm_output.py
ADDED
@@ -0,0 +1,24 @@
+import torch
+from dataclasses import dataclass
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class GMFlowModelOutput(BaseOutput):
+    """
+    The output of GMFlow models.
+
+    Args:
+        means (`torch.Tensor` of shape `(batch_size, num_gaussians, num_channels, height, width)` or
+            `(batch_size, num_gaussians, num_channels, frame, height, width)`):
+            Gaussian mixture means.
+        logweights (`torch.Tensor` of shape `(batch_size, num_gaussians, 1, height, width)` or
+            `(batch_size, num_gaussians, 1, frame, height, width)`):
+            Gaussian mixture log-weights (logits).
+        logstds (`torch.Tensor` of shape `(batch_size, 1, 1, 1, 1)` or `(batch_size, 1, 1, 1, 1, 1)`):
+            Gaussian mixture log-standard-deviations (logstds are shared across all Gaussians and channels).
+    """
+
+    means: torch.Tensor
+    logweights: torch.Tensor
+    logstds: torch.Tensor
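
To make the documented shapes concrete, here is a small illustration (not part of the commit) that builds a GMFlowModelOutput with arbitrary example dimensions and checks that the log-weights normalize over the mixture dimension:

import torch
from lakonlab.models.architecture.gmflow.gm_output import GMFlowModelOutput

B, K, C, H, W = 2, 8, 16, 32, 32  # arbitrary example sizes
out = GMFlowModelOutput(
    means=torch.randn(B, K, C, H, W),                          # mixture means
    logweights=torch.randn(B, K, 1, H, W).log_softmax(dim=1),  # normalized over the K components
    logstds=torch.zeros(B, 1, 1, 1, 1),                        # one shared log-std per sample
)
print(out.means.shape)
print(out.logweights.exp().sum(dim=1).mean())  # weights sum to ~1 at every pixel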
lakonlab/models/architecture/gmflow/gmflux.py
ADDED
@@ -0,0 +1,225 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from typing import Any, Dict, Optional, Tuple
+from diffusers.models import ModelMixin
+from diffusers.models.transformers.transformer_flux import (
+    FluxTransformer2DModel, FluxPosEmbed, FluxTransformerBlock, FluxSingleTransformerBlock)
+from diffusers.models.embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings)
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
+from .gm_output import GMFlowModelOutput
+
+
+class _GMFluxTransformer2DModel(FluxTransformer2DModel):
+
+    @register_to_config
+    def __init__(
+            self,
+            num_gaussians=16,
+            constant_logstd=None,
+            logstd_inner_dim=1024,
+            gm_num_logstd_layers=2,
+            logweights_channels=1,
+            in_channels: int = 64,
+            out_channels: Optional[int] = None,
+            num_layers: int = 19,
+            num_single_layers: int = 38,
+            attention_head_dim: int = 128,
+            num_attention_heads: int = 24,
+            joint_attention_dim: int = 4096,
+            pooled_projection_dim: int = 768,
+            guidance_embeds: bool = False,
+            axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)):
+        super(FluxTransformer2DModel, self).__init__()
+
+        self.num_gaussians = num_gaussians
+        self.logweights_channels = logweights_channels
+
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+        text_time_guidance_cls = (
+            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+        )
+        self.time_text_embed = text_time_guidance_cls(
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+        )
+
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                FluxTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                FluxSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
+        self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
+        self.constant_logstd = constant_logstd
+
+        if self.constant_logstd is None:
+            assert gm_num_logstd_layers >= 1
+            in_dim = self.inner_dim
+            logstd_layers = []
+            for _ in range(gm_num_logstd_layers - 1):
+                logstd_layers.extend([
+                    nn.SiLU(),
+                    nn.Linear(in_dim, logstd_inner_dim)])
+                in_dim = logstd_inner_dim
+            self.proj_out_logstds = nn.Sequential(
+                *logstd_layers,
+                nn.SiLU(),
+                nn.Linear(in_dim, 1))
+
+        self.gradient_checkpointing = False
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            pooled_projections: torch.Tensor = None,
+            timestep: torch.Tensor = None,
+            img_ids: torch.Tensor = None,
+            txt_ids: torch.Tensor = None,
+            guidance: torch.Tensor = None,
+            joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+            controlnet_block_samples=None,
+            controlnet_single_block_samples=None,
+            controlnet_blocks_repeat: bool = False):
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            scale_lora_layers(self, lora_scale)
+        else:
+            assert joint_attention_kwargs is None or joint_attention_kwargs.get('scale', None) is None
+
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = tuple([x.to(hidden_states.dtype) for x in image_rotary_emb])
+
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
+                )
+
+        hidden_states = self.norm_out(hidden_states, temb)
+
+        bs, seq_len, _ = hidden_states.size()
+        out_means = self.proj_out_means(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.out_channels)
+        out_logweights = self.proj_out_logweights(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
+        if self.constant_logstd is None:
+            out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
+        else:
+            out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
+
+        if USE_PEFT_BACKEND:
+            unscale_lora_layers(self, lora_scale)
+
+        return GMFlowModelOutput(
+            means=out_means,
+            logweights=out_logweights,
+            logstds=out_logstds)
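
The GM-specific part of this model is the output head: instead of a single velocity prediction, each token is projected to K mixture means and K log-weights (normalized over the mixture dimension), plus one shared log-std derived from the time embedding. A standalone sketch of just that tensor bookkeeping, with arbitrary dummy sizes (illustration only, not the full model):

import torch
import torch.nn as nn

bs, seq_len, inner_dim, num_gaussians, out_channels = 2, 64, 128, 16, 64
hidden = torch.randn(bs, seq_len, inner_dim)

proj_means = nn.Linear(inner_dim, num_gaussians * out_channels)
proj_logweights = nn.Linear(inner_dim, num_gaussians * 1)

means = proj_means(hidden).reshape(bs, seq_len, num_gaussians, out_channels)
logweights = proj_logweights(hidden).reshape(bs, seq_len, num_gaussians, 1).log_softmax(dim=-2)

# log_softmax over dim=-2 normalizes across the mixture components,
# so the weights at every token sum to one:
print(torch.allclose(logweights.exp().sum(dim=-2), torch.ones(bs, seq_len, 1), atol=1e-5))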
lakonlab/models/architecture/gmflow/gmqwen.py
ADDED
@@ -0,0 +1,149 @@
+import torch
+import torch.nn as nn
+
+from typing import Any, Dict, Optional, Tuple, List
+from diffusers.models import ModelMixin
+from diffusers.models.transformers.transformer_qwenimage import (
+    QwenImageTransformer2DModel, QwenEmbedRope, QwenImageTransformerBlock, QwenTimestepProjEmbeddings)
+from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers
+from .gm_output import GMFlowModelOutput
+
+
+class _GMQwenImageTransformer2DModel(QwenImageTransformer2DModel):
+
+    @register_to_config
+    def __init__(
+            self,
+            num_gaussians=16,
+            constant_logstd=None,
+            logstd_inner_dim=1024,
+            gm_num_logstd_layers=2,
+            logweights_channels=1,
+            in_channels: int = 64,
+            out_channels: Optional[int] = None,
+            num_layers: int = 60,
+            attention_head_dim: int = 128,
+            num_attention_heads: int = 24,
+            joint_attention_dim: int = 3584,
+            axes_dims_rope: Tuple[int, int, int] = (16, 56, 56)):
+        super(QwenImageTransformer2DModel, self).__init__()
+
+        self.num_gaussians = num_gaussians
+        self.logweights_channels = logweights_channels
+
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+
+        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+
+        self.img_in = nn.Linear(in_channels, self.inner_dim)
+        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out_means = nn.Linear(self.inner_dim, self.num_gaussians * self.out_channels)
+        self.proj_out_logweights = nn.Linear(self.inner_dim, self.num_gaussians * self.logweights_channels)
+        self.constant_logstd = constant_logstd
+
+        if self.constant_logstd is None:
+            assert gm_num_logstd_layers >= 1
+            in_dim = self.inner_dim
+            logstd_layers = []
+            for _ in range(gm_num_logstd_layers - 1):
+                logstd_layers.extend([
+                    nn.SiLU(),
+                    nn.Linear(in_dim, logstd_inner_dim)])
+                in_dim = logstd_inner_dim
+            self.proj_out_logstds = nn.Sequential(
+                *logstd_layers,
+                nn.SiLU(),
+                nn.Linear(in_dim, 1))
+
+        self.gradient_checkpointing = False
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            encoder_hidden_states: torch.Tensor = None,
+            encoder_hidden_states_mask: torch.Tensor = None,
+            timestep: torch.LongTensor = None,
+            img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+            txt_seq_lens: Optional[List[int]] = None,
+            attention_kwargs: Optional[Dict[str, Any]] = None):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            scale_lora_layers(self, lora_scale)
+        else:
+            assert attention_kwargs is None or attention_kwargs.get('scale', None) is None
+
+        hidden_states = self.img_in(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype)
+        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+        encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+        temb = self.time_text_embed(timestep, hidden_states)
+
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    encoder_hidden_states_mask,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=attention_kwargs,
+                )
+
+        hidden_states = self.norm_out(hidden_states, temb)
+
+        bs, seq_len, _ = hidden_states.size()
+        out_means = self.proj_out_means(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.out_channels)
+        out_logweights = self.proj_out_logweights(hidden_states).reshape(
+            bs, seq_len, self.num_gaussians, self.logweights_channels).log_softmax(dim=-2)
+        if self.constant_logstd is None:
+            out_logstds = self.proj_out_logstds(temb.detach()).reshape(bs, 1, 1, 1)
+        else:
+            out_logstds = hidden_states.new_full((bs, 1, 1, 1), float(self.constant_logstd))
+
+        if USE_PEFT_BACKEND:
+            unscale_lora_layers(self, lora_scale)
+
+        return GMFlowModelOutput(
+            means=out_means,
+            logweights=out_logweights,
+            logstds=out_logstds)
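
Both GM transformer variants build their proj_out_logstds head the same way: gm_num_logstd_layers controls how many SiLU+Linear pairs precede the final scalar projection. The same construction in isolation, with example dimensions (illustration only):

import torch.nn as nn

def build_logstd_head(inner_dim, logstd_inner_dim, gm_num_logstd_layers):
    # Mirrors the constructor logic above: (gm_num_logstd_layers - 1) hidden
    # SiLU+Linear pairs, then a final SiLU+Linear projecting to a scalar log-std.
    assert gm_num_logstd_layers >= 1
    in_dim = inner_dim
    layers = []
    for _ in range(gm_num_logstd_layers - 1):
        layers.extend([nn.SiLU(), nn.Linear(in_dim, logstd_inner_dim)])
        in_dim = logstd_inner_dim
    return nn.Sequential(*layers, nn.SiLU(), nn.Linear(in_dim, 1))

print(build_logstd_head(3072, 1024, 2))
# Sequential(SiLU, Linear(3072 -> 1024), SiLU, Linear(1024 -> 1))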
lakonlab/models/diffusions/__init__.py
ADDED
File without changes
lakonlab/models/diffusions/piflow_policies/__init__.py
ADDED
@@ -0,0 +1,8 @@
+from .dx import DXPolicy
+from .gmflow import GMFlowPolicy
+
+
+POLICY_CLASSES = dict(
+    DX=DXPolicy,
+    GMFlow=GMFlowPolicy
+)
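
This registry lets a pipeline resolve a configured policy name to its class; a short sketch of the lookup (illustration only; the constructor arguments are documented in the policy files below):

from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES

policy_cls = POLICY_CLASSES['GMFlow']  # or 'DX'
# policy = policy_cls(denoising_output, x_t_src, sigma_t_src)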
lakonlab/models/diffusions/piflow_policies/base.py
ADDED
@@ -0,0 +1,21 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePolicy(metaclass=ABCMeta):
+
+    @abstractmethod
+    def u(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        pass
+
+    @abstractmethod
+    def detach(self):
+        pass
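
A concrete policy only needs to implement u() and detach(). As an illustration (not part of the commit), a trivial policy that returns a fixed velocity field satisfies the interface:

import torch
from lakonlab.models.diffusions.piflow_policies.base import BasePolicy

class ConstantPolicy(BasePolicy):
    """Toy policy that returns a fixed velocity field (illustration only)."""

    def __init__(self, u_const: torch.Tensor):
        self.u_const = u_const

    def u(self, x_t, sigma_t):
        return self.u_const.expand_as(x_t)

    def detach(self):
        return ConstantPolicy(self.u_const.detach())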
lakonlab/models/diffusions/piflow_policies/dx.py
ADDED
@@ -0,0 +1,108 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import torch
+from .base import BasePolicy
+
+
+class DXPolicy(BasePolicy):
+    """DX policy. The number of grid points N is inferred from the denoising output.
+
+    Note: segment_size and shift are intrinsic parameters of the DX policy. For elastic inference (i.e., changing
+    the number of function evaluations or noise schedule at test time), these parameters should be kept unchanged.
+
+    Args:
+        denoising_output (torch.Tensor): The output of the denoising model. Shape (B, N, C, H, W) or (B, N, C, T, H, W).
+        x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
+        sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
+        segment_size (float): The size of each DX policy time segment. Defaults to 1.0.
+        shift (float): The shift parameter for the DX policy noise schedule. Defaults to 1.0.
+        eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
+    """
+
+    def __init__(
+            self,
+            denoising_output: torch.Tensor,
+            x_t_src: torch.Tensor,
+            sigma_t_src: torch.Tensor,
+            segment_size: float = 1.0,
+            shift: float = 1.0,
+            eps: float = 1e-4):
+        self.x_t_src = x_t_src
+        self.ndim = x_t_src.dim()
+        self.shift = shift
+        self.eps = eps
+
+        self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
+        self.raw_t_src = self._unwarp_t(self.sigma_t_src)
+        self.raw_t_dst = (self.raw_t_src - segment_size).clamp(min=0)
+        self.segment_size = (self.raw_t_src - self.raw_t_dst).clamp(min=eps)
+
+        self.denoising_output_x_0 = self._u_to_x_0(
+            denoising_output, self.x_t_src, self.sigma_t_src)
+
+    def _unwarp_t(self, sigma_t):
+        return sigma_t / (self.shift + (1 - self.shift) * sigma_t)
+
+    @staticmethod
+    def _u_to_x_0(denoising_output, x_t, sigma_t):
+        x_0 = x_t.unsqueeze(1) - sigma_t.unsqueeze(1) * denoising_output
+        return x_0
+
+    @staticmethod
+    def _interpolate(x, t):
+        """
+        Args:
+            x (torch.Tensor): (B, N, *)
+            t (torch.Tensor): (B, *) in [0, 1]
+
+        Returns:
+            torch.Tensor: (B, *)
+        """
+        n = x.size(1)
+        if n < 2:
+            return x.squeeze(1)
+        t = t.clamp(min=0, max=1) * (n - 1)
+        t0 = t.floor().to(torch.long).clamp(min=0, max=n - 2)
+        t1 = t0 + 1
+        t0t1 = torch.stack([t0, t1], dim=1)  # (B, 2, *)
+        x0x1 = torch.gather(x, dim=1, index=t0t1.expand(-1, -1, *x.shape[2:]))
+        x_interp = (t1 - t) * x0x1[:, 0] + (t - t0) * x0x1[:, 1]
+        return x_interp
+
+    def u(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
+        raw_t = self._unwarp_t(sigma_t)
+        x_0 = self._interpolate(
+            self.denoising_output_x_0, (raw_t - self.raw_t_dst) / self.segment_size)
+        u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
+        return u
+
+    def copy(self):
+        new_policy = DXPolicy.__new__(DXPolicy)
+        new_policy.x_t_src = self.x_t_src
+        new_policy.ndim = self.ndim
+        new_policy.shift = self.shift
+        new_policy.eps = self.eps
+        new_policy.sigma_t_src = self.sigma_t_src
+        new_policy.raw_t_src = self.raw_t_src
+        new_policy.raw_t_dst = self.raw_t_dst
+        new_policy.segment_size = self.segment_size
+        new_policy.denoising_output_x_0 = self.denoising_output_x_0
+        return new_policy
+
+    def detach_(self):
+        self.denoising_output_x_0 = self.denoising_output_x_0.detach()
+        return self
+
+    def detach(self):
+        new_policy = self.copy()
+        return new_policy.detach_()
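
_interpolate above is plain piecewise-linear interpolation across the N grid points along dim 1, which is what lets the DX policy be queried at arbitrary times inside a segment. A toy check of the same logic specialized to (B, N) inputs (illustration only):

import torch

def interpolate(x, t):
    # Same logic as DXPolicy._interpolate, specialized to x of shape (B, N).
    n = x.size(1)
    t = t.clamp(min=0, max=1) * (n - 1)
    t0 = t.floor().to(torch.long).clamp(min=0, max=n - 2)
    t1 = t0 + 1
    x0 = x.gather(1, t0.unsqueeze(1)).squeeze(1)
    x1 = x.gather(1, t1.unsqueeze(1)).squeeze(1)
    return (t1 - t) * x0 + (t - t0) * x1

x = torch.tensor([[0.0, 10.0, 20.0]])        # N = 3 grid points
print(interpolate(x, torch.tensor([0.25])))  # tensor([5.]) -- halfway into the first segment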
lakonlab/models/diffusions/piflow_policies/gmflow.py
ADDED
@@ -0,0 +1,175 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import math
+import torch
+
+from typing import Dict
+from .base import BasePolicy
+
+
+@torch.jit.script
+def gmflow_posterior_mean_jit(
+        sigma_t_src, sigma_t, x_t_src, x_t,
+        gm_means, gm_vars, gm_logweights,
+        eps: float, gm_dim: int = -4, channel_dim: int = -3):
+    alpha_t_src = 1 - sigma_t_src
+    alpha_t = 1 - sigma_t
+
+    sigma_t_src_sq = sigma_t_src.square()
+    sigma_t_sq = sigma_t.square()
+
+    # compute gaussian params
+    denom = (alpha_t.square() * sigma_t_src_sq - alpha_t_src.square() * sigma_t_sq).clamp(min=eps)  # ζ
+    g_mean = (alpha_t * sigma_t_src_sq * x_t - alpha_t_src * sigma_t_sq * x_t_src) / denom  # ν / ζ
+    g_var = sigma_t_sq * sigma_t_src_sq / denom
+
+    # gm_mul_iso_gaussian
+    g_mean = g_mean.unsqueeze(gm_dim)  # (bs, *, 1, out_channels, h, w)
+    g_var = g_var.unsqueeze(gm_dim)  # (bs, *, 1, 1, 1, 1)
+
+    gm_diffs = gm_means - g_mean  # (bs, *, num_gaussians, out_channels, h, w)
+    norm_factor = (g_var + gm_vars).clamp(min=eps)
+
+    out_means = (g_var * gm_means + gm_vars * g_mean) / norm_factor
+    # (bs, *, num_gaussians, 1, h, w)
+    logweights_delta = gm_diffs.square().sum(dim=channel_dim, keepdim=True) * (-0.5 / norm_factor)
+    out_weights = (gm_logweights + logweights_delta).softmax(dim=gm_dim)
+
+    out_mean = (out_means * out_weights).sum(dim=gm_dim)
+
+    return out_mean
+
+
+def gm_temperature(gm, temperature, gm_dim=-4, eps=1e-6):
+    gm = gm.copy()
+    temperature = max(temperature, eps)
+    gm['logweights'] = (gm['logweights'] / temperature).log_softmax(dim=gm_dim)
+    if 'logstds' in gm:
+        gm['logstds'] = gm['logstds'] + (0.5 * math.log(temperature))
+    if 'gm_vars' in gm:
+        gm['gm_vars'] = gm['gm_vars'] * temperature
+    return gm
+
+
+class GMFlowPolicy(BasePolicy):
+    """GMFlow policy. The number of components K is inferred from the denoising output.
+
+    Args:
+        denoising_output (dict): The output of the denoising model, containing:
+            means (torch.Tensor): The means of the Gaussian components. Shape (B, K, C, H, W) or (B, K, C, T, H, W).
+            logstds (torch.Tensor): The log standard deviations of the Gaussian components. Shape (B, K, 1, 1, 1)
+                or (B, K, 1, 1, 1, 1).
+            logweights (torch.Tensor): The log weights of the Gaussian components. Shape (B, K, 1, H, W) or
+                (B, K, 1, T, H, W).
+        x_t_src (torch.Tensor): The initial noisy sample. Shape (B, C, H, W) or (B, C, T, H, W).
+        sigma_t_src (torch.Tensor): The initial noise level. Shape (B,).
+        checkpointing (bool): Whether to use gradient checkpointing to save memory. Defaults to True.
+        eps (float): A small value to avoid numerical issues. Defaults to 1e-4.
+    """
+
+    def __init__(
+            self,
+            denoising_output: Dict[str, torch.Tensor],
+            x_t_src: torch.Tensor,
+            sigma_t_src: torch.Tensor,
+            checkpointing: bool = True,
+            eps: float = 1e-4):
+        self.x_t_src = x_t_src
+        self.ndim = x_t_src.dim()
+        self.checkpointing = checkpointing
+        self.eps = eps
+
+        self.sigma_t_src = sigma_t_src.reshape(*sigma_t_src.size(), *((self.ndim - sigma_t_src.dim()) * [1]))
+        self.denoising_output_x_0 = self._u_to_x_0(
+            denoising_output, self.x_t_src, self.sigma_t_src)
+
+    @staticmethod
+    def _u_to_x_0(denoising_output, x_t, sigma_t):
+        x_t = x_t.unsqueeze(1)
+        sigma_t = sigma_t.unsqueeze(1)
+        means_x_0 = x_t - sigma_t * denoising_output['means']
+        gm_vars = (denoising_output['logstds'] * 2).exp() * sigma_t.square()
+        return dict(
+            means=means_x_0,
+            gm_vars=gm_vars,
+            logweights=denoising_output['logweights'])
+
+    def u(self, x_t, sigma_t):
+        """Compute the flow velocity at (x_t, t).
+
+        Args:
+            x_t (torch.Tensor): Noisy input at time t.
+            sigma_t (torch.Tensor): Noise level at time t.
+
+        Returns:
+            torch.Tensor: The computed flow velocity u_t.
+        """
+        sigma_t = sigma_t.reshape(*sigma_t.size(), *((self.ndim - sigma_t.dim()) * [1]))
+        means = self.denoising_output_x_0['means']
+        gm_vars = self.denoising_output_x_0['gm_vars']
+        logweights = self.denoising_output_x_0['logweights']
+        if (sigma_t == self.sigma_t_src).all() and (x_t == self.x_t_src).all():
+            x_0 = (logweights.softmax(dim=1) * means).sum(dim=1)
+        else:
+            if self.checkpointing and torch.is_grad_enabled():
+                x_0 = torch.utils.checkpoint.checkpoint(
+                    gmflow_posterior_mean_jit,
+                    self.sigma_t_src, sigma_t, self.x_t_src, x_t,
+                    means,
+                    gm_vars,
+                    logweights,
+                    self.eps, 1, 2,
+                    use_reentrant=True)  # use_reentrant=False does not work with jit
+            else:
+                x_0 = gmflow_posterior_mean_jit(
+                    self.sigma_t_src, sigma_t, self.x_t_src, x_t,
+                    means,
+                    gm_vars,
+                    logweights,
+                    self.eps, 1, 2)
+        u = (x_t - x_0) / sigma_t.clamp(min=self.eps)
+        return u
+
+    def copy(self):
+        new_policy = GMFlowPolicy.__new__(GMFlowPolicy)
+        new_policy.x_t_src = self.x_t_src
+        new_policy.ndim = self.ndim
+        new_policy.checkpointing = self.checkpointing
+        new_policy.eps = self.eps
+        new_policy.sigma_t_src = self.sigma_t_src
+        new_policy.denoising_output_x_0 = self.denoising_output_x_0.copy()
+        return new_policy
+
+    def detach_(self):
+        self.denoising_output_x_0 = {k: v.detach() for k, v in self.denoising_output_x_0.items()}
+        return self
+
+    def detach(self):
+        new_policy = self.copy()
+        return new_policy.detach_()
+
+    def dropout_(self, p):
+        if p <= 0 or p >= 1:
+            return self
+        logweights = self.denoising_output_x_0['logweights']
+        dropout_mask = torch.rand(
+            (*logweights.shape[:2], *((self.ndim - 1) * [1])), device=logweights.device) < p
+        is_all_dropout = dropout_mask.all(dim=1, keepdim=True)
+        dropout_mask &= ~is_all_dropout
+        self.denoising_output_x_0['logweights'] = logweights.masked_fill(
+            dropout_mask, float('-inf'))
+        return self
+
+    def dropout(self, p):
+        new_policy = self.copy()
+        return new_policy.dropout_(p)
+
+    def temperature_(self, temp):
+        if temp >= 1.0:
+            return self
+        self.denoising_output_x_0 = gm_temperature(
+            self.denoising_output_x_0, temp, gm_dim=1, eps=self.eps)
+        return self
+
+    def temperature(self, temp):
+        new_policy = self.copy()
+        return new_policy.temperature_(temp)
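
The core step inside gmflow_posterior_mean_jit is the closed-form product of each mixture component N(m_k, v_k) with an isotropic Gaussian factor N(m_g, v_g): the component mean becomes (v_g*m_k + v_k*m_g)/(v_g + v_k) and its log-weight shifts by -0.5*(m_k - m_g)^2/(v_g + v_k). The dropped sqrt normalizer cancels in the softmax because GMFlow shares one variance across components. A 1-D numerical check of this identity (illustration only, not part of the commit):

import torch

# GM prior over x_0: two 1-D components; isotropic Gaussian factor N(m_g, v_g).
m = torch.tensor([-1.0, 2.0])
v = torch.tensor([0.5, 0.5])      # shared variance, as with GMFlow's shared logstds
w = torch.tensor([0.3, 0.7])
m_g, v_g = torch.tensor(0.5), torch.tensor(1.0)

# Closed form, as used by gmflow_posterior_mean_jit.
post_means = (v_g * m + v * m_g) / (v_g + v)
post_w = (w.log() - 0.5 * (m - m_g) ** 2 / (v_g + v)).softmax(dim=0)
closed_form = (post_w * post_means).sum()

# Brute-force numerical posterior mean on a dense grid.
x = torch.linspace(-10.0, 10.0, 200001)
gm_pdf = (w[:, None] * torch.exp(-0.5 * (x - m[:, None]) ** 2 / v[:, None])
          / (2 * torch.pi * v[:, None]) ** 0.5).sum(dim=0)
g_pdf = torch.exp(-0.5 * (x - m_g) ** 2 / v_g) / (2 * torch.pi * v_g) ** 0.5
post_pdf = gm_pdf * g_pdf
numeric = (x * post_pdf).sum() / post_pdf.sum()
print(closed_form.item(), numeric.item())  # both ~0.9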
lakonlab/pipelines/__init__.py
ADDED
File without changes
lakonlab/pipelines/piflow_loader.py
ADDED
@@ -0,0 +1,275 @@
+# Copyright (c) 2025 Hansheng Chen
+
+import os
+from typing import Union, Optional
+
+import torch
+import accelerate
+import diffusers
+from diffusers.models import AutoModel
+from diffusers.models.modeling_utils import (
+    load_state_dict,
+    _LOW_CPU_MEM_USAGE_DEFAULT,
+    no_init_weights,
+    ContextManagers
+)
+from diffusers.utils import (
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    _add_variant,
+    _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
+from lakonlab.models.architecture.gmflow.gmflux import _GMFluxTransformer2DModel
+from lakonlab.models.architecture.gmflow.gmqwen import _GMQwenImageTransformer2DModel
+
+
+LOCAL_CLASS_MAPPING = {
+    "GMFluxTransformer2DModel": _GMFluxTransformer2DModel,
+    "GMQwenImageTransformer2DModel": _GMQwenImageTransformer2DModel,
+}
+
+_SET_ADAPTER_SCALE_FN_MAPPING.update(
+    _GMFluxTransformer2DModel=lambda model_cls, weights: weights,
+    _GMQwenImageTransformer2DModel=lambda model_cls, weights: weights,
+)
+
+logger = logging.get_logger(__name__)
+
+
+class PiFlowLoaderMixin:
+
+    def load_piflow_adapter(
+        self,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        target_module_name: str = "transformer",
+        adapter_name: Optional[str] = None,
+        **kwargs
+    ):
+        r"""
+        Load a PiFlow adapter from a pretrained model repository into the target module.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                  the Hub.
+                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                  with [`~ModelMixin.save_pretrained`].
+
+            target_module_name (`str`, *optional*, defaults to `"transformer"`):
+                The module name in the model to load the PiFlow adapter into.
+            adapter_name (`str`, *optional*):
+                The name to assign to the loaded adapter. If not provided, it defaults to
+                `"{target_module_name}_piflow"`.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
+                model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+
+        Returns:
+            `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found.
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        disable_mmap = kwargs.pop("disable_mmap", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        user_agent = {
+            "diffusers": diffusers.__version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # 1. Determine model class from config
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+        }
+
+        config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)
+
+        orig_class_name = config["_class_name"]
+
+        if orig_class_name in LOCAL_CLASS_MAPPING:
+            model_cls = LOCAL_CLASS_MAPPING[orig_class_name]
+
+        else:
+            load_config_kwargs.update({"subfolder": subfolder})
+
+            from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+
+            model_cls, _ = get_class_obj_and_candidates(
+                library_name="diffusers",
+                class_name=orig_class_name,
+                importable_classes=ALL_IMPORTABLE_CLASSES,
+                pipelines=None,
+                is_pipeline_module=False,
+            )
+
+        if model_cls is None:
+            raise ValueError(f"Can't find a model linked to {orig_class_name}.")
+
+        # 2. Get model file
+
+        model_file = None
+
+        if use_safetensors:
+            try:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+
+            except IOError as e:
+                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                if not allow_pickle:
+                    raise
+                logger.warning(
+                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                )
+
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=_add_variant(WEIGHTS_NAME, variant),
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+
+        # 3. Initialize model
+
+        base_module = getattr(self, target_module_name)
+
+        torch_dtype = base_module.dtype
+        device = base_module.device
+        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
+
+        init_contexts = [no_init_weights()]
+
+        if low_cpu_mem_usage:
+            init_contexts.append(accelerate.init_empty_weights())
+
+        with ContextManagers(init_contexts):
+            piflow_module = model_cls.from_config(config).eval()
+
+        torch.set_default_dtype(dtype_orig)
+
+        # 4. Load model weights
+
+        if model_file is not None:
+            base_state_dict = base_module.state_dict()
+            lora_state_dict = dict()
+
+            adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
+            for k in adapter_state_dict.keys():
+                adapter_state_dict[k] = adapter_state_dict[k].to(dtype=torch_dtype, device=device)
+                if "lora" in k:
+                    lora_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
+                else:
+                    base_state_dict[k.removeprefix(f"{target_module_name}.")] = adapter_state_dict[k]
+
+            if len(lora_state_dict) == 0:
+                adapter_name = None
+
+            else:
+                if adapter_name is None:
+                    adapter_name = f"{target_module_name}_piflow"
+
+                piflow_module.load_state_dict(
+                    base_state_dict, strict=False, assign=True)
+                piflow_module.load_lora_adapter(
+                    lora_state_dict, prefix=None, adapter_name=adapter_name)
+
+                setattr(self, target_module_name, piflow_module)
+
+        else:
+            adapter_name = None
+
+        if adapter_name is None:
+            logger.warning(
+                f"No LoRA weights were found in {pretrained_model_name_or_path}."
+            )
+
+        return adapter_name
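
The weight-splitting step above routes every checkpoint key containing "lora" into the LoRA adapter and everything else (such as the new GM output heads) into the base state dict before the module swap. A self-contained sketch of that routing with dummy keys (illustration only):

target_module_name = 'transformer'
adapter_state_dict = {
    'transformer.proj_out_means.weight': 'gm-head tensor',   # new GM head -> base weights
    'transformer.x_embedder.lora_A.weight': 'lora tensor',   # LoRA factor -> adapter
}

base_state_dict, lora_state_dict = {}, {}
for k, v in adapter_state_dict.items():
    key = k.removeprefix(target_module_name + '.')
    (lora_state_dict if 'lora' in k else base_state_dict)[key] = v

print(sorted(base_state_dict))  # ['proj_out_means.weight']
print(sorted(lora_state_dict))  # ['x_embedder.lora_A.weight']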
lakonlab/pipelines/piflux_pipeline.py
ADDED
@@ -0,0 +1,491 @@
# Copyright (c) 2025 Hansheng Chen

import numpy as np
import torch

from typing import Dict, List, Optional, Union, Any, Callable
from functools import partial
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5TokenizerFast,
)
from diffusers.utils import is_torch_xla_available
from diffusers.image_processor import PipelineImageInput
from diffusers.models import AutoencoderKL, FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipeline, calculate_shift, FluxPipelineOutput, retrieve_timesteps)
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
from .piflow_loader import PiFlowLoaderMixin


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def retrieve_raw_timesteps(
    num_inference_steps: int,
    total_substeps: int,
    final_step_size_scale: float
):
    r"""
    Retrieve the raw times and the number of substeps for each inference step.

    Args:
        num_inference_steps (`int`):
            Number of inference steps.
        total_substeps (`int`):
            Total number of substeps (e.g., 128).
        final_step_size_scale (`float`):
            Scale for the final step size (e.g., 0.5).

    Returns:
        `Tuple[List[float], List[int], int]`: A tuple where the first element is the raw timestep schedule, the second
        element is the number of substeps for each inference step, and the third element is the rounded total number of
        substeps.
    """
    base_segment_size = 1 / (num_inference_steps - 1 + final_step_size_scale)
    raw_timesteps = []
    num_inference_substeps = []
    _raw_t = 1.0
    for i in range(num_inference_steps):
        if i < num_inference_steps - 1:
            segment_size = base_segment_size
        else:
            segment_size = base_segment_size * final_step_size_scale
        _num_inference_substeps = max(round(segment_size * total_substeps), 1)
        num_inference_substeps.append(_num_inference_substeps)
        raw_timesteps.extend(np.linspace(
            _raw_t, _raw_t - segment_size, _num_inference_substeps, endpoint=False).clip(min=0.0).tolist())
        _raw_t = _raw_t - segment_size
    total_substeps = sum(num_inference_substeps)
    return raw_timesteps, num_inference_substeps, total_substeps
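
# Illustrative check of the schedule above (example values, not part of the
# original file): with num_inference_steps=4, total_substeps=128 and
# final_step_size_scale=0.5, base_segment_size = 1 / 3.5 ~ 0.2857, so each of
# the first three steps gets round(0.2857 * 128) = 37 substeps and the final
# step gets round(0.5 * 0.2857 * 128) = 18, i.e. the rounded total is 129.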


class PiFluxPipeline(FluxPipeline, PiFlowLoaderMixin):
    r"""
    The policy-based Flux pipeline for text-to-image generation.

    Reference: Todo: add paper link

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
        policy_type (`str`, *optional*, defaults to `"GMFlow"`):
            The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
        policy_kwargs (`Dict`, *optional*):
            Additional keyword arguments to pass to the policy class.
    """

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        text_encoder_2: T5EncoderModel,
        tokenizer_2: T5TokenizerFast,
        transformer: FluxTransformer2DModel,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        policy_type: str = 'GMFlow',
        policy_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            scheduler,
            vae,
            text_encoder,
            tokenizer,
            text_encoder_2,
            tokenizer_2,
            transformer,
            image_encoder,
            feature_extractor
        )
        assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
        self.policy_type = policy_type
        self.policy_class = partial(
            POLICY_CLASSES[policy_type], **policy_kwargs
        ) if policy_kwargs else POLICY_CLASSES[policy_type]

    def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
        c = num_channels_latents * patch_size * patch_size
        h = (int(height) // (self.vae_scale_factor * patch_size))
        w = (int(width) // (self.vae_scale_factor * patch_size))
        bs = gm['means'].size(0)
        k = self.transformer.num_gaussians
        scale = patch_size // gm_patch_size
        gm['means'] = gm['means'].reshape(
            bs, h, w, k, c // (scale * scale), scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, c // (scale * scale), h * scale, w * scale)
        gm['logweights'] = gm['logweights'].reshape(
            bs, h, w, k, 1, scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, 1, h * scale, w * scale)
        gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
        return gm
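
    # Shape note (added for clarity): `gm['means']` arrives token-packed as
    # (B, H/2 * W/2, K * C * 4) and leaves as a spatial map (B, K, C, H, W),
    # where K = num_gaussians and (H, W) is the latent resolution;
    # `logweights` likewise becomes (B, K, 1, H, W) and `logstds` broadcasts
    # as (B, 1, 1, 1, 1).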

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size=1, target_patch_size=2):
        scale = target_patch_size // patch_size
        latents = latents.view(
            batch_size,
            num_channels_latents * patch_size * patch_size,
            height // target_patch_size, scale, width // target_patch_size, scale)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(
            batch_size,
            (height // target_patch_size) * (width // target_patch_size),
            num_channels_latents * target_patch_size * target_patch_size)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor, patch_size=2, target_patch_size=1):
        batch_size, num_patches, channels = latents.shape
        scale = patch_size // target_patch_size

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = (int(height) // (vae_scale_factor * patch_size))
        width = (int(width) // (vae_scale_factor * patch_size))

        latents = latents.view(
            batch_size, height, width, channels // (scale * scale), scale, scale)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (scale * scale), height * scale, width * scale)

        return latents

    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 4,
        total_substeps: int = 128,
        final_step_size_scale: float = 0.5,
        temperature: Union[float, str] = 'auto',
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 4):
                The number of denoising steps.
            total_substeps (`int`, *optional*, defaults to 128):
                The total number of substeps for policy-based flow integration.
            final_step_size_scale (`float`, *optional*, defaults to 0.5):
                The scale for the final step size.
            temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
                The temperature parameter for the flow policy.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
                a model to generate images more aligned with `prompt` at the expense of lower image quality.

                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during the inference. The function is
                called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
                timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
                specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare prompt embeddings
        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        raw_timesteps, num_inference_substeps, total_substeps = retrieve_raw_timesteps(
            num_inference_steps, total_substeps, final_step_size_scale)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, _ = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=raw_timesteps,
            mu=mu,
        )
        assert len(timesteps) == total_substeps
        self._num_timesteps = total_substeps

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        image_embeds = None
        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
            )

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        timestep_id = 0
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(num_inference_steps):
                if self.interrupt:
                    continue

                t_src = timesteps[timestep_id]
                sigma_t_src = t_src / self.scheduler.config.num_train_timesteps
                is_final_step = i == (num_inference_steps - 1)

                self._current_timestep = t_src
                if image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

                with self.transformer.cache_context("cond"):
                    denoising_output = self.transformer(
                        hidden_states=latents.to(dtype=self.transformer.dtype),
                        timestep=t_src.expand(latents.shape[0]) / 1000,
                        guidance=guidance,
                        pooled_projections=pooled_prompt_embeds,
                        encoder_hidden_states=prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                    )

                # unpack and create policy
                latents = self._unpack_latents(
                    latents, height, width, self.vae_scale_factor, target_patch_size=1)
                if self.policy_type == 'GMFlow':
                    denoising_output = self._unpack_gm(
                        denoising_output, height, width, num_channels_latents, gm_patch_size=1)
                    denoising_output = {k: v.to(torch.float32) for k, v in denoising_output.items()}
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                    if not is_final_step:
                        if temperature == 'auto':
                            temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
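                            # e.g. the default 4-step setting gives
                            # temperature 0.1 * 3 = 0.3, while single-step
                            # sampling clamps it to 0.0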
                        else:
                            assert isinstance(temperature, (float, int))
                        policy.temperature_(temperature)
                elif self.policy_type == 'DX':
                    denoising_output = denoising_output[0]
                    denoising_output = self._unpack_latents(
                        denoising_output, height, width, self.vae_scale_factor, target_patch_size=1)
                    denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
                    denoising_output = denoising_output.to(torch.float32)
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                else:
                    raise ValueError(f'Unknown policy type: {self.policy_type}.')

                # compute the previous noisy sample x_t -> x_t-1
                for _ in range(num_inference_substeps[i]):
                    t = timesteps[timestep_id]
                    sigma_t = t / self.scheduler.config.num_train_timesteps
                    u = policy.u(latents, sigma_t)
                    latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
                    timestep_id += 1
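
                # Note: only the single transformer call above touches the
                # network; each `policy.u` call is a closed-form velocity
                # query on the fixed per-step policy, so the substeps are cheap.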

                # repack
                latents = self._pack_latents(
                    latents, latents.size(0), num_channels_latents,
                    2 * (int(height) // (self.vae_scale_factor * 2)),
                    2 * (int(width) // (self.vae_scale_factor * 2)),
                    patch_size=1)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
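
For orientation, a minimal usage sketch of the pipeline above (the base checkpoint id and dtype here are assumptions; a distilled pi-Flow adapter would additionally be loaded through the PiFlowLoaderMixin from piflow_loader.py):

import torch
from lakonlab.pipelines.piflux_pipeline import PiFluxPipeline

pipe = PiFluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16).to('cuda')
image = pipe(
    'a cat wearing a spacesuit, photorealistic',
    num_inference_steps=4,   # one transformer evaluation per step
    total_substeps=128,      # cheap policy-integration substeps
    guidance_scale=3.5,
    generator=torch.Generator('cuda').manual_seed(42),
).images[0]
image.save('piflux_4step.png')
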
lakonlab/pipelines/piqwen_pipeline.py
ADDED
@@ -0,0 +1,429 @@
# Copyright (c) 2025 Hansheng Chen

import numpy as np
import torch

from typing import Dict, List, Optional, Union, Any, Callable
from functools import partial
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers.utils import is_torch_xla_available
from diffusers.models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from diffusers.pipelines.qwenimage.pipeline_qwenimage import (
    QwenImagePipeline, calculate_shift, retrieve_timesteps, QwenImagePipelineOutput)
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from lakonlab.models.diffusions.piflow_policies import POLICY_CLASSES
from .piflow_loader import PiFlowLoaderMixin


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def retrieve_raw_timesteps(
    num_inference_steps: int,
    total_substeps: int,
    final_step_size_scale: float
):
    r"""
    Retrieve the raw times and the number of substeps for each inference step.

    Args:
        num_inference_steps (`int`):
            Number of inference steps.
        total_substeps (`int`):
            Total number of substeps (e.g., 128).
        final_step_size_scale (`float`):
            Scale for the final step size (e.g., 0.5).

    Returns:
        `Tuple[List[float], List[int], int]`: A tuple where the first element is the raw timestep schedule, the second
        element is the number of substeps for each inference step, and the third element is the rounded total number of
        substeps.
    """
    base_segment_size = 1 / (num_inference_steps - 1 + final_step_size_scale)
    raw_timesteps = []
    num_inference_substeps = []
    _raw_t = 1.0
    for i in range(num_inference_steps):
        if i < num_inference_steps - 1:
            segment_size = base_segment_size
        else:
            segment_size = base_segment_size * final_step_size_scale
        _num_inference_substeps = max(round(segment_size * total_substeps), 1)
        num_inference_substeps.append(_num_inference_substeps)
        raw_timesteps.extend(np.linspace(
            _raw_t, _raw_t - segment_size, _num_inference_substeps, endpoint=False).clip(min=0.0).tolist())
        _raw_t = _raw_t - segment_size
    total_substeps = sum(num_inference_substeps)
    return raw_timesteps, num_inference_substeps, total_substeps


class PiQwenImagePipeline(QwenImagePipeline, PiFlowLoaderMixin):
    r"""
    The policy-based QwenImage pipeline for text-to-image generation.

    Reference: Todo: add paper link

    Args:
        transformer ([`QwenImageTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLQwenImage`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Text encoder, specifically the
            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
        tokenizer (`Qwen2Tokenizer`):
            Tokenizer of class
            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
        policy_type (`str`, *optional*, defaults to `"GMFlow"`):
            The type of flow policy to use. Currently supports `"GMFlow"` and `"DX"`.
        policy_kwargs (`Dict`, *optional*):
            Additional keyword arguments to pass to the policy class.
    """

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: AutoencoderKLQwenImage,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2Tokenizer,
        transformer: QwenImageTransformer2DModel,
        policy_type: str = 'GMFlow',
        policy_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            scheduler,
            vae,
            text_encoder,
            tokenizer,
            transformer,
        )
        assert policy_type in POLICY_CLASSES, f'Invalid policy: {policy_type}. Supported policies are {list(POLICY_CLASSES.keys())}.'
        self.policy_type = policy_type
        self.policy_class = partial(
            POLICY_CLASSES[policy_type], **policy_kwargs
        ) if policy_kwargs else POLICY_CLASSES[policy_type]

    def _unpack_gm(self, gm, height, width, num_channels_latents, patch_size=2, gm_patch_size=1):
        c = num_channels_latents * patch_size * patch_size
        h = (int(height) // (self.vae_scale_factor * patch_size))
        w = (int(width) // (self.vae_scale_factor * patch_size))
        bs = gm['means'].size(0)
        k = self.transformer.num_gaussians
        scale = patch_size // gm_patch_size
        gm['means'] = gm['means'].reshape(
            bs, h, w, k, c // (scale * scale), scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, c // (scale * scale), h * scale, w * scale)
        gm['logweights'] = gm['logweights'].reshape(
            bs, h, w, k, 1, scale, scale
        ).permute(
            0, 3, 4, 1, 5, 2, 6
        ).reshape(
            bs, k, 1, h * scale, w * scale)
        gm['logstds'] = gm['logstds'].reshape(bs, 1, 1, 1, 1)
        return gm

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size=1, target_patch_size=2):
        scale = target_patch_size // patch_size
        latents = latents.view(
            batch_size,
            num_channels_latents * patch_size * patch_size,
            height // target_patch_size, scale, width // target_patch_size, scale)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(
            batch_size,
            (height // target_patch_size) * (width // target_patch_size),
            num_channels_latents * target_patch_size * target_patch_size)

        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor, patch_size=2, target_patch_size=1):
        batch_size, num_patches, channels = latents.shape
        scale = patch_size // target_patch_size

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = (int(height) // (vae_scale_factor * patch_size))
        width = (int(width) // (vae_scale_factor * patch_size))

        latents = latents.view(
            batch_size, height, width, channels // (scale * scale), scale, scale)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (scale * scale), height * scale, width * scale)

        return latents

    @torch.inference_mode()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 4,
        total_substeps: int = 128,
        final_step_size_scale: float = 0.5,
        temperature: Union[float, str] = 'auto',
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_embeds_mask: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 4):
                The number of denoising steps.
            total_substeps (`int`, *optional*, defaults to 128):
                The total number of substeps for policy-based flow integration.
            final_step_size_scale (`float`, *optional*, defaults to 0.5):
                The scale for the final step size.
            temperature (`float` or `"auto"`, *optional*, defaults to `"auto"`):
                The temperature parameter for the flow policy.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during the inference. The function is
                called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
                timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
                specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Returns:
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: [`~pipelines.qwenimage.QwenImagePipelineOutput`]
            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Prepare prompt embeddings
        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
            prompt=prompt,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            torch.float32,
            device,
            generator,
            latents,
        )
        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

        # 5. Prepare timesteps
        raw_timesteps, num_inference_substeps, total_substeps = retrieve_raw_timesteps(
            num_inference_steps, total_substeps, final_step_size_scale)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, _ = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=raw_timesteps,
            mu=mu,
        )
        assert len(timesteps) == total_substeps
        self._num_timesteps = total_substeps

        if self.attention_kwargs is None:
            self._attention_kwargs = {}

        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
        timestep_id = 0
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i in range(num_inference_steps):
                if self.interrupt:
                    continue

                t_src = timesteps[timestep_id]
                sigma_t_src = t_src / self.scheduler.config.num_train_timesteps
                is_final_step = i == (num_inference_steps - 1)

                self._current_timestep = t_src

                with self.transformer.cache_context("cond"):
                    denoising_output = self.transformer(
                        hidden_states=latents.to(dtype=self.transformer.dtype),
                        timestep=t_src.expand(latents.shape[0]) / 1000,
                        encoder_hidden_states_mask=prompt_embeds_mask,
                        encoder_hidden_states=prompt_embeds,
                        img_shapes=img_shapes,
                        txt_seq_lens=txt_seq_lens,
                        attention_kwargs=self.attention_kwargs,
                    )

                # unpack and create policy
                latents = self._unpack_latents(
                    latents, height, width, self.vae_scale_factor, target_patch_size=1)
                if self.policy_type == 'GMFlow':
                    denoising_output = self._unpack_gm(
                        denoising_output, height, width, num_channels_latents, gm_patch_size=1)
                    denoising_output = {k: v.to(torch.float32) for k, v in denoising_output.items()}
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                    if not is_final_step:
                        if temperature == 'auto':
                            temperature = min(max(0.1 * (num_inference_steps - 1), 0), 1)
                        else:
                            assert isinstance(temperature, (float, int))
                        policy.temperature_(temperature)
                elif self.policy_type == 'DX':
                    denoising_output = denoising_output[0]
                    denoising_output = self._unpack_latents(
                        denoising_output, height, width, self.vae_scale_factor, target_patch_size=1)
                    denoising_output = denoising_output.reshape(latents.size(0), -1, *latents.shape[1:])
                    denoising_output = denoising_output.to(torch.float32)
                    policy = self.policy_class(
                        denoising_output, latents, sigma_t_src)
                else:
                    raise ValueError(f'Unknown policy type: {self.policy_type}.')

                # compute the previous noisy sample x_t -> x_t-1
                for _ in range(num_inference_substeps[i]):
                    t = timesteps[timestep_id]
                    sigma_t = t / self.scheduler.config.num_train_timesteps
                    u = policy.u(latents, sigma_t)
                    latents = self.scheduler.step(u, t, latents, return_dict=False)[0]
                    timestep_id += 1

                # repack
                latents = self._pack_latents(
                    latents, latents.size(0), num_channels_latents,
                    2 * (int(height) // (self.vae_scale_factor * 2)),
                    2 * (int(width) // (self.vae_scale_factor * 2)),
                    patch_size=1)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t_src, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)[:, :, None]
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean)
                .view(1, self.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
                latents.device, latents.dtype
            )
            latents = latents * latents_std + latents_mean
            image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0][:, :, 0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return QwenImagePipelineOutput(images=image)
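
The Qwen-Image pipeline is driven the same way; a brief sketch (the base checkpoint id is an assumption):

import torch
from lakonlab.pipelines.piqwen_pipeline import PiQwenImagePipeline

pipe = PiQwenImagePipeline.from_pretrained(
    'Qwen/Qwen-Image', torch_dtype=torch.bfloat16).to('cuda')
image = pipe('a watercolor fox in a forest', num_inference_steps=4).images[0]
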
lakonlab/ui/__init__.py
ADDED
File without changes

lakonlab/ui/gradio/__init__.py
ADDED
File without changes

lakonlab/ui/gradio/create_text_to_img.py
ADDED
@@ -0,0 +1,53 @@
import gradio as gr
from .shared_opts import create_base_opts, create_generate_bar, set_seed, create_prompt_opts


def create_interface_text_to_img(
        api, prompt='', seed=42, steps=32, min_steps=4, max_steps=50, steps_slider_step=1,
        height=768, width=1360, hw_slider_step=16,
        guidance_scale=None, temperature=None, api_name='text_to_img',
        create_negative_prompt=False, args=['last_seed', 'prompt', 'width', 'height', 'steps', 'guidance_scale']):
    var_dict = dict()
    with gr.Blocks(analytics_enabled=False) as interface:
        var_dict['output_image'] = gr.Image(
            type='pil', image_mode='RGB', label='Output image', interactive=False, elem_classes=['vh-img'])
        create_prompt_opts(var_dict, create_negative_prompt=create_negative_prompt, prompt=prompt)
        with gr.Column(variant='compact', elem_classes=['custom-spacing']):
            with gr.Row(variant='compact', elem_classes=['force-hide-container']):
                var_dict['width'] = gr.Slider(
                    label='Width', minimum=64, maximum=2048, step=hw_slider_step, value=width,
                    elem_classes=['force-hide-container'])
                var_dict['switch_hw'] = gr.Button('\U000021C6', elem_classes=['tool'])
                var_dict['height'] = gr.Slider(
                    label='Height', minimum=64, maximum=2048, step=hw_slider_step, value=height,
                    elem_classes=['force-hide-container'])
                var_dict['switch_hw'].click(
                    fn=lambda w, h: (h, w),
                    inputs=[var_dict['width'], var_dict['height']],
                    outputs=[var_dict['width'], var_dict['height']],
                    show_progress=False,
                    api_name=False)
            create_generate_bar(var_dict, text='Generate', seed=seed)
            create_base_opts(
                var_dict,
                steps=steps,
                min_steps=min_steps,
                max_steps=max_steps,
                steps_slider_step=steps_slider_step,
                guidance_scale=guidance_scale,
                temperature=temperature)

        var_dict['run_btn'].click(
            fn=set_seed,
            inputs=var_dict['seed'],
            outputs=var_dict['last_seed'],
            show_progress=False,
            api_name=False
        ).success(
            fn=api,
            inputs=[var_dict[arg] for arg in args],
            outputs=var_dict['output_image'],
            concurrency_id='default_group', api_name=api_name
        )

    return interface, var_dict
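
Hypothetical wiring of this builder (the `api` signature mirrors the default `args` list above: last_seed, prompt, width, height, steps, guidance_scale):

from PIL import Image
from lakonlab.ui.gradio.create_text_to_img import create_interface_text_to_img

def dummy_api(seed, prompt, width, height, steps, guidance_scale):
    # a real app would run the diffusion pipeline here and return a PIL image
    return Image.new('RGB', (int(width), int(height)), 'gray')

interface, var_dict = create_interface_text_to_img(
    dummy_api, prompt='a test prompt', guidance_scale=3.5)
interface.launch()
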
lakonlab/ui/gradio/shared_opts.py
ADDED
@@ -0,0 +1,64 @@
import random
import gradio as gr


def create_prompt_opts(var_dict, create_negative_prompt=True, prompt='', negative_prompt=''):
    var_dict['prompt'] = gr.Textbox(
        prompt, label='Prompt', show_label=False, lines=2, placeholder='Prompt', container=False, interactive=True)
    if create_negative_prompt:
        var_dict['negative_prompt'] = gr.Textbox(
            negative_prompt, label='Negative prompt', show_label=False, lines=2,
            placeholder='Negative prompt', container=False, interactive=True)


def create_generate_bar(var_dict, text='Generate', variant='primary', seed=-1):
    with gr.Row(equal_height=False):
        var_dict['run_btn'] = gr.Button(text, variant=variant, scale=2)
        var_dict['seed'] = gr.Number(
            label='Seed', value=seed, min_width=100, precision=0, minimum=-1, maximum=2 ** 31,
            elem_classes=['force-hide-container'])
        var_dict['random_seed'] = gr.Button('\U0001f3b2\ufe0f', elem_classes=['tool'])
        var_dict['reuse_seed'] = gr.Button('\u267b\ufe0f', elem_classes=['tool'])
    with gr.Column(visible=False):
        var_dict['last_seed'] = gr.Number(value=seed, label='Last seed')
    var_dict['reuse_seed'].click(
        fn=lambda x: x,
        inputs=var_dict['last_seed'],
        outputs=var_dict['seed'],
        show_progress=False,
        api_name=False)
    var_dict['random_seed'].click(
        fn=lambda: -1,
        outputs=var_dict['seed'],
        show_progress=False,
        api_name=False)


def create_base_opts(var_dict,
                     steps=24,
                     min_steps=4,
                     max_steps=50,
                     steps_slider_step=1,
                     guidance_scale=None,
                     temperature=None,
                     render=True):
    with gr.Column(variant='compact', elem_classes=['custom-spacing'], render=render) as base_opts:
        with gr.Row(variant='compact', elem_classes=['force-hide-container']):
            var_dict['steps'] = gr.Slider(
                min_steps, max_steps, value=steps, step=steps_slider_step, label='Sampling steps',
                elem_classes=['force-hide-container'])
        with gr.Row(variant='compact', elem_classes=['force-hide-container']):
            if guidance_scale is not None:
                var_dict['guidance_scale'] = gr.Slider(
                    0.0, 30.0, value=guidance_scale, step=0.5, label='Guidance scale',
                    elem_classes=['force-hide-container'])
            if temperature is not None:
                var_dict['temperature'] = gr.Slider(
                    0.0, 1.0, value=temperature, step=0.01, label='Temperature',
                    elem_classes=['force-hide-container'])
    return base_opts


def set_seed(seed):
    seed = random.randint(0, 2**31) if seed == -1 else seed
    return seed
lakonlab/ui/gradio/style.css
ADDED
@@ -0,0 +1,59 @@
.force-hide-container {
    margin: 0;
    box-shadow: none;
    --block-border-width: 0;
    background: transparent;
    padding: 0;
    overflow: visible;
}

.svelte-sfqy0y {
    display: flex;
    flex-direction: inherit;
    flex-wrap: wrap;
    gap: 0;
    box-shadow: none;
    border: 0;
    border-radius: 0;
    background: transparent;
    overflow-y: hidden;
}

.custom-spacing {
    padding: 10px;
    gap: 20px;
    flex-grow: 0 !important;
}

.unequal-height {
    align-items: flex-end;
}

.tool {
    max-width: 40px;
    min-width: 40px !important;
}

/* Center the component and allow it to use the full row width */
.vh-img {
    display: grid;
    justify-items: center;
}

/* Container should size to the image, but never exceed the row width */
.vh-img .image-container {
    inline-size: fit-content !important;  /* prefers image's natural width */
    max-inline-size: 100% !important;     /* ...but clamps to available width */
    margin-inline: auto;
    overflow: hidden;                     /* avoid odd overflow on iOS */
}

/* Image scales by BOTH constraints: height cap and row width */
.vh-img .image-container img {
    max-block-size: 700px !important;  /* fixed max height cap */
    max-inline-size: 100%;             /* never wider than container */
    inline-size: auto;                 /* keep aspect ratio */
    block-size: auto;
    object-fit: contain;
    display: block;
}
requirements.txt
ADDED
@@ -0,0 +1,8 @@
numpy==1.26.4
torch==2.6.0
diffusers==0.35.1
peft==0.17.0
sentencepiece
accelerate
transformers==4.54.1
gradio==4.18.0