Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
60f6b78
1
Parent(s):
90eb953
TAGv0.5
Browse files- README.md +1 -1
- app.py +129 -98
- pipelines/__init__.py +0 -0
- pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
- pipelines/__pycache__/__init__.cpython-311.pyc +0 -0
- pipelines/__pycache__/pipeline_tag_stablediffusion.cpython-310.pyc +0 -0
- pipelines/__pycache__/pipeline_tag_stablediffusion.cpython-311.pyc +0 -0
- pipelines/__pycache__/pipeline_tag_stablediffusion3.cpython-310.pyc +0 -0
- pipelines/__pycache__/pipeline_tag_stablediffusionXL.cpython-310.pyc +0 -0
- pipelines/pipeline_tag_stablediffusion.py +390 -0
- pipelines/pipeline_tag_stablediffusion3.py +509 -0
- pipelines/pipeline_tag_stablediffusionXL.py +585 -0
README.md
CHANGED
|
@@ -8,7 +8,7 @@ sdk_version: 5.44.0
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
short_description:
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
+
short_description: TAG Image Generation Demo on Unconditional Generation
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
CHANGED
|
@@ -1,154 +1,185 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import random
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
#
|
| 6 |
from diffusers import DiffusionPipeline
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
if torch.cuda.is_available():
|
| 13 |
-
torch_dtype = torch.float16
|
| 14 |
-
else:
|
| 15 |
-
torch_dtype = torch.float32
|
| 16 |
-
|
| 17 |
-
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
|
| 18 |
-
pipe = pipe.to(device)
|
| 19 |
|
|
|
|
|
|
|
| 20 |
MAX_SEED = np.iinfo(np.int32).max
|
| 21 |
-
MAX_IMAGE_SIZE = 1024
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
def infer(
|
| 26 |
-
|
| 27 |
-
negative_prompt,
|
| 28 |
seed,
|
| 29 |
randomize_seed,
|
| 30 |
width,
|
| 31 |
height,
|
| 32 |
-
guidance_scale,
|
| 33 |
num_inference_steps,
|
|
|
|
|
|
|
| 34 |
progress=gr.Progress(track_tqdm=True),
|
| 35 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
if randomize_seed:
|
| 37 |
seed = random.randint(0, MAX_SEED)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
height=height,
|
| 48 |
-
|
|
|
|
| 49 |
).images[0]
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
| 56 |
-
"An astronaut riding a green horse",
|
| 57 |
-
"A delicious ceviche cheesecake slice",
|
| 58 |
-
]
|
| 59 |
|
|
|
|
| 60 |
css = """
|
| 61 |
-
#col-container {
|
| 62 |
-
margin: 0 auto;
|
| 63 |
-
max-width: 640px;
|
| 64 |
-
}
|
| 65 |
"""
|
| 66 |
|
| 67 |
with gr.Blocks(css=css) as demo:
|
| 68 |
with gr.Column(elem_id="col-container"):
|
| 69 |
-
gr.Markdown("
|
| 70 |
-
|
|
|
|
|
|
|
| 71 |
with gr.Row():
|
| 72 |
prompt = gr.Text(
|
| 73 |
-
label="Prompt",
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
placeholder="Enter your prompt",
|
| 77 |
-
container=False,
|
| 78 |
)
|
| 79 |
-
|
| 80 |
run_button = gr.Button("Run", scale=0, variant="primary")
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
with gr.Accordion("Advanced Settings", open=False):
|
| 85 |
-
negative_prompt = gr.Text(
|
| 86 |
-
label="Negative prompt",
|
| 87 |
-
max_lines=1,
|
| 88 |
-
placeholder="Enter a negative prompt",
|
| 89 |
-
visible=False,
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
seed = gr.Slider(
|
| 93 |
-
label="Seed",
|
| 94 |
-
minimum=0,
|
| 95 |
-
maximum=MAX_SEED,
|
| 96 |
-
step=1,
|
| 97 |
-
value=0,
|
| 98 |
)
|
| 99 |
-
|
| 100 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
| 101 |
-
|
| 102 |
with gr.Row():
|
| 103 |
width = gr.Slider(
|
| 104 |
-
label="Width",
|
| 105 |
-
minimum=256,
|
| 106 |
-
maximum=MAX_IMAGE_SIZE,
|
| 107 |
-
step=32,
|
| 108 |
-
value=1024, # Replace with defaults that work for your model
|
| 109 |
)
|
| 110 |
-
|
| 111 |
height = gr.Slider(
|
| 112 |
-
label="Height",
|
| 113 |
-
minimum=256,
|
| 114 |
-
maximum=MAX_IMAGE_SIZE,
|
| 115 |
-
step=32,
|
| 116 |
-
value=1024, # Replace with defaults that work for your model
|
| 117 |
)
|
| 118 |
-
|
| 119 |
with gr.Row():
|
| 120 |
guidance_scale = gr.Slider(
|
| 121 |
-
label="
|
| 122 |
-
minimum=0.0,
|
| 123 |
-
maximum=10.0,
|
| 124 |
-
step=0.1,
|
| 125 |
-
value=0.0, # Replace with defaults that work for your model
|
| 126 |
)
|
| 127 |
-
|
| 128 |
num_inference_steps = gr.Slider(
|
| 129 |
-
label="
|
| 130 |
-
minimum=1,
|
| 131 |
-
maximum=50,
|
| 132 |
-
step=1,
|
| 133 |
-
value=2, # Replace with defaults that work for your model
|
| 134 |
)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
fn=infer,
|
| 140 |
inputs=[
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
randomize_seed,
|
| 145 |
-
width,
|
| 146 |
-
height,
|
| 147 |
-
guidance_scale,
|
| 148 |
-
num_inference_steps,
|
| 149 |
],
|
| 150 |
-
outputs=[
|
| 151 |
)
|
| 152 |
|
| 153 |
if __name__ == "__main__":
|
| 154 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import numpy as np
|
| 3 |
import random
|
| 4 |
+
import torch
|
| 5 |
+
import spaces
|
| 6 |
|
| 7 |
+
# 사용할 파이프라인들을 모두 import 합니다.
|
| 8 |
from diffusers import DiffusionPipeline
|
| 9 |
+
# 'pipelines' 폴더가 있다고 가정합니다.
|
| 10 |
+
from pipelines.pipeline_tag_stablediffusion import StableDiffusionTangentialDecomposedPipeline
|
| 11 |
+
from pipelines.pipeline_tag_stablediffusion3 import StableDiffusion3TangentialDecomposedPipeline
|
| 12 |
+
from pipelines.pipeline_tag_stablediffusionXL import StableDiffusionXLTangentialDecomposedPipeline
|
| 13 |
+
|
| 14 |
+
# --- 설정 ---
|
| 15 |
+
MODEL_MAP = {
|
| 16 |
+
"SD 1.5": "runwayml/stable-diffusion-v1-5",
|
| 17 |
+
"SD 2.1": "stabilityai/stable-diffusion-2-1",
|
| 18 |
+
"SDXL": "stabilityai/stable-diffusion-xl-base-1.0",
|
| 19 |
+
"SD 3": "stabilityai/stable-diffusion-3-medium-diffusers",
|
| 20 |
+
}
|
| 21 |
+
RESOLUTION_MAP = { "SD 1.5": 512, "SD 2.1": 768, "SDXL": 1024, "SD 3": 1024 }
|
| 22 |
+
SEED_MAP = { "SD 1.5": 850728, "SD 2.1": 944905, "SDXL": 178914170, "SD 3": 282386105 }
|
| 23 |
+
TAG_SCALE_MAP = {
|
| 24 |
+
"SD 1.5": 1.15, # 기본값
|
| 25 |
+
"SD 2.1": 1.15, # 기본값
|
| 26 |
+
"SDXL": 1.20,
|
| 27 |
+
"SD 3": 1.05
|
| 28 |
+
}
|
| 29 |
+
PIPELINE_MAP = {
|
| 30 |
+
"SD 1.5": StableDiffusionTangentialDecomposedPipeline,
|
| 31 |
+
"SD 2.1": StableDiffusionTangentialDecomposedPipeline,
|
| 32 |
+
"SDXL": StableDiffusionXLTangentialDecomposedPipeline,
|
| 33 |
+
"SD 3": StableDiffusion3TangentialDecomposedPipeline,
|
| 34 |
+
}
|
| 35 |
|
| 36 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
+
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
+
pipe = None
|
| 40 |
+
current_model_id = None
|
| 41 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 42 |
|
| 43 |
+
# --- 함수 ---
|
| 44 |
+
|
| 45 |
+
def load_pipeline(model_name, progress):
|
| 46 |
+
global pipe, current_model_id
|
| 47 |
+
model_id = MODEL_MAP[model_name]
|
| 48 |
+
pipeline_class = PIPELINE_MAP[model_name]
|
| 49 |
+
progress(0, desc=f"Loading model: {model_id} with {pipeline_class.__name__}...")
|
| 50 |
+
pipe = pipeline_class.from_pretrained(model_id, torch_dtype=torch_dtype)
|
| 51 |
+
pipe = pipe.to(device)
|
| 52 |
+
current_model_id = model_id
|
| 53 |
+
progress(1)
|
| 54 |
+
|
| 55 |
+
def update_model_defaults(model_name):
|
| 56 |
+
"""모델 선택에 따라 해상도, 시드, 랜덤 시드 체크박스, TAG Scale을 업데이트합니다."""
|
| 57 |
+
res = RESOLUTION_MAP[model_name]
|
| 58 |
+
seed_val = SEED_MAP[model_name]
|
| 59 |
+
tag_scale_val = TAG_SCALE_MAP[model_name]
|
| 60 |
+
return (
|
| 61 |
+
gr.update(value=res),
|
| 62 |
+
gr.update(value=res),
|
| 63 |
+
gr.update(value=seed_val),
|
| 64 |
+
gr.update(value=False), # 'Randomize seed' 체크 해제
|
| 65 |
+
gr.update(value=tag_scale_val), # TAG Scale 업데이트
|
| 66 |
+
)
|
| 67 |
|
| 68 |
+
@spaces.GPU
|
| 69 |
def infer(
|
| 70 |
+
model_name,
|
|
|
|
| 71 |
seed,
|
| 72 |
randomize_seed,
|
| 73 |
width,
|
| 74 |
height,
|
| 75 |
+
guidance_scale, # 사용자 지정 TAG Scale
|
| 76 |
num_inference_steps,
|
| 77 |
+
guidance_start_timestep,
|
| 78 |
+
guidance_end_timestep,
|
| 79 |
progress=gr.Progress(track_tqdm=True),
|
| 80 |
):
|
| 81 |
+
global pipe, current_model_id
|
| 82 |
+
|
| 83 |
+
model_id = MODEL_MAP[model_name]
|
| 84 |
+
if model_id != current_model_id:
|
| 85 |
+
gr.Info(f"Changing model to {model_name}. Please wait...")
|
| 86 |
+
load_pipeline(model_name, progress)
|
| 87 |
+
|
| 88 |
if randomize_seed:
|
| 89 |
seed = random.randint(0, MAX_SEED)
|
| 90 |
|
| 91 |
+
generator_custom = torch.Generator(device=device).manual_seed(int(seed))
|
| 92 |
+
generator_fixed = torch.Generator(device=device).manual_seed(int(seed))
|
| 93 |
+
|
| 94 |
+
unconditional_prompt = ""
|
| 95 |
+
|
| 96 |
+
# 첫 번째 이미지 (사용자 지정 TAG Scale) 생성
|
| 97 |
+
image_custom_scale = pipe(
|
| 98 |
+
prompt=unconditional_prompt, guidance_scale=0.,
|
| 99 |
+
num_inference_steps=num_inference_steps, width=width, height=height, generator=generator_custom,
|
| 100 |
+
sta_tpd=guidance_start_timestep, end_tpd=guidance_end_timestep,
|
| 101 |
+
t_guidance_scale=guidance_scale
|
| 102 |
).images[0]
|
| 103 |
|
| 104 |
+
fixed_tag_value = 1.0
|
| 105 |
+
image_fixed_scale = pipe(
|
| 106 |
+
prompt=unconditional_prompt, guidance_scale=0.,
|
| 107 |
+
num_inference_steps=num_inference_steps, width=width, height=height, generator=generator_fixed,
|
| 108 |
+
sta_tpd=guidance_start_timestep, end_tpd=guidance_end_timestep,
|
| 109 |
+
t_guidance_scale=fixed_tag_value
|
| 110 |
+
).images[0]
|
| 111 |
|
| 112 |
+
return [image_fixed_scale, image_custom_scale], seed
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
+
# --- UI 구성 (Gradio) ---
|
| 115 |
css = """
|
| 116 |
+
#col-container { margin: 0 auto; max-width: 720px; }
|
|
|
|
|
|
|
|
|
|
| 117 |
"""
|
| 118 |
|
| 119 |
with gr.Blocks(css=css) as demo:
|
| 120 |
with gr.Column(elem_id="col-container"):
|
| 121 |
+
gr.Markdown("# Tangential Amplifying Guidance Demo")
|
| 122 |
+
model_selector = gr.Dropdown(
|
| 123 |
+
label="Select Model", choices=list(MODEL_MAP.keys()), value="SDXL"
|
| 124 |
+
)
|
| 125 |
with gr.Row():
|
| 126 |
prompt = gr.Text(
|
| 127 |
+
label="Prompt (Disabled)", show_label=True, max_lines=1,
|
| 128 |
+
placeholder="Unconditional generation mode. This input is ignored.",
|
| 129 |
+
container=True, interactive=False,
|
|
|
|
|
|
|
| 130 |
)
|
|
|
|
| 131 |
run_button = gr.Button("Run", scale=0, variant="primary")
|
| 132 |
+
|
| 133 |
+
# --- 2. gr.ImageSlider 컴포넌트로 변경 ---
|
| 134 |
+
result_slider = gr.ImageSlider(
|
| 135 |
+
label="Result Comparison (Fixed Scale vs. Your Scale)",
|
| 136 |
+
show_label=True
|
| 137 |
+
)
|
| 138 |
|
| 139 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
seed = gr.Slider(
|
| 141 |
+
label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=SEED_MAP["SDXL"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
)
|
|
|
|
| 143 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
|
|
| 144 |
with gr.Row():
|
| 145 |
width = gr.Slider(
|
| 146 |
+
label="Width", minimum=256, maximum=1024, step=64, value=RESOLUTION_MAP["SDXL"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
|
|
|
| 148 |
height = gr.Slider(
|
| 149 |
+
label="Height", minimum=256, maximum=1024, step=64, value=RESOLUTION_MAP["SDXL"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
)
|
|
|
|
| 151 |
with gr.Row():
|
| 152 |
guidance_scale = gr.Slider(
|
| 153 |
+
label="TAG Scale", minimum=1.0, maximum=1.3, step=0.05, value=TAG_SCALE_MAP["SDXL"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
)
|
|
|
|
| 155 |
num_inference_steps = gr.Slider(
|
| 156 |
+
label="Inference Steps", minimum=20, maximum=50, step=1, value=50
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
)
|
| 158 |
+
with gr.Row():
|
| 159 |
+
guidance_start_timestep = gr.Slider(
|
| 160 |
+
label="Guidance Start Timestep", minimum=0, maximum=1000, step=1, value=999
|
| 161 |
+
)
|
| 162 |
+
guidance_end_timestep = gr.Slider(
|
| 163 |
+
label="Guidance End Timestep", minimum=0, maximum=1000, step=1, value=0
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# --- 이벤트 리스너 ---
|
| 167 |
+
model_selector.change(
|
| 168 |
+
fn=update_model_defaults,
|
| 169 |
+
inputs=[model_selector],
|
| 170 |
+
outputs=[width, height, seed, randomize_seed, guidance_scale],
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# --- 3. outputs를 ImageSlider 컴포넌트로 지정 ---
|
| 174 |
+
run_button.click(
|
| 175 |
fn=infer,
|
| 176 |
inputs=[
|
| 177 |
+
model_selector, seed, randomize_seed, width, height,
|
| 178 |
+
guidance_scale, num_inference_steps,
|
| 179 |
+
guidance_start_timestep, guidance_end_timestep,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
],
|
| 181 |
+
outputs=[result_slider, seed],
|
| 182 |
)
|
| 183 |
|
| 184 |
if __name__ == "__main__":
|
| 185 |
+
demo.launch(debug=True)
|
pipelines/__init__.py
ADDED
|
File without changes
|
pipelines/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (129 Bytes). View file
|
|
|
pipelines/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
pipelines/__pycache__/pipeline_tag_stablediffusion.cpython-310.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
pipelines/__pycache__/pipeline_tag_stablediffusion.cpython-311.pyc
ADDED
|
Binary file (18.6 kB). View file
|
|
|
pipelines/__pycache__/pipeline_tag_stablediffusion3.cpython-310.pyc
ADDED
|
Binary file (18.7 kB). View file
|
|
|
pipelines/__pycache__/pipeline_tag_stablediffusionXL.cpython-310.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
pipelines/pipeline_tag_stablediffusion.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import StableDiffusionPipeline
|
| 2 |
+
import torch
|
| 3 |
+
import inspect
|
| 4 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from packaging import version
|
| 8 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 9 |
+
|
| 10 |
+
from diffusers.utils import (
|
| 11 |
+
USE_PEFT_BACKEND,
|
| 12 |
+
deprecate,
|
| 13 |
+
logging,
|
| 14 |
+
replace_example_docstring,
|
| 15 |
+
scale_lora_layers,
|
| 16 |
+
unscale_lora_layers,
|
| 17 |
+
)
|
| 18 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
| 19 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 20 |
+
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
|
| 21 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
| 22 |
+
retrieve_timesteps,
|
| 23 |
+
rescale_noise_cfg,
|
| 24 |
+
EXAMPLE_DOC_STRING
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# 1. StableDiffusionPipeline을 상속받는 새로운 클래스 정의
|
| 28 |
+
class StableDiffusionTangentialDecomposedPipeline(StableDiffusionPipeline):
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
|
| 31 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 32 |
+
def __call__(
|
| 33 |
+
self,
|
| 34 |
+
prompt: Union[str, List[str]] = None,
|
| 35 |
+
height: Optional[int] = None,
|
| 36 |
+
width: Optional[int] = None,
|
| 37 |
+
num_inference_steps: int = 50,
|
| 38 |
+
timesteps: List[int] = None,
|
| 39 |
+
sigmas: List[float] = None,
|
| 40 |
+
guidance_scale: float = 7.5,
|
| 41 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 42 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 43 |
+
eta: float = 0.0,
|
| 44 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 45 |
+
latents: Optional[torch.Tensor] = None,
|
| 46 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 47 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 48 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 49 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
| 50 |
+
output_type: Optional[str] = "pil",
|
| 51 |
+
return_dict: bool = True,
|
| 52 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 53 |
+
guidance_rescale: float = 0.0,
|
| 54 |
+
clip_skip: Optional[int] = None,
|
| 55 |
+
callback_on_step_end: Optional[
|
| 56 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 57 |
+
] = None,
|
| 58 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 59 |
+
|
| 60 |
+
## Tangential Scailing Guidance specific parameters
|
| 61 |
+
t_guidance_scale: float = 1.0, # Scale for TGS
|
| 62 |
+
r_guidance_scale: float = 1.0, # Scale for radial guidance
|
| 63 |
+
|
| 64 |
+
## Apply range for each scaling
|
| 65 |
+
sta_tpd: int = 1000, # Start step for tangential scaling
|
| 66 |
+
end_tpd: int = 0, # End step for tangential scaling
|
| 67 |
+
|
| 68 |
+
**kwargs: Any, # Additional arguments for future compatibility
|
| 69 |
+
):
|
| 70 |
+
r"""
|
| 71 |
+
The call function to the pipeline for generation.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 75 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 76 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 77 |
+
The height in pixels of the generated image.
|
| 78 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 79 |
+
The width in pixels of the generated image.
|
| 80 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 81 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 82 |
+
expense of slower inference.
|
| 83 |
+
timesteps (`List[int]`, *optional*):
|
| 84 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 85 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 86 |
+
passed will be used. Must be in descending order.
|
| 87 |
+
sigmas (`List[float]`, *optional*):
|
| 88 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 89 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 90 |
+
will be used.
|
| 91 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 92 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 93 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 94 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 95 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 96 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 97 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 98 |
+
The number of images to generate per prompt.
|
| 99 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 100 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 101 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 102 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 103 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 104 |
+
generation deterministic.
|
| 105 |
+
latents (`torch.Tensor`, *optional*):
|
| 106 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 107 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 108 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 109 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 110 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 111 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 112 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 113 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 114 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 115 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
| 116 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
| 117 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
| 118 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
| 119 |
+
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
| 120 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 121 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 122 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 123 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 124 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 125 |
+
plain tuple.
|
| 126 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 127 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 128 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 129 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 130 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
| 131 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
| 132 |
+
using zero terminal SNR.
|
| 133 |
+
clip_skip (`int`, *optional*):
|
| 134 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 135 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 136 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 137 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 138 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 139 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 140 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 141 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 142 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 143 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 144 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 145 |
+
|
| 146 |
+
Examples:
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 150 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 151 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 152 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 153 |
+
"not-safe-for-work" (nsfw) content.
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
callback = kwargs.pop("callback", None)
|
| 157 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 158 |
+
|
| 159 |
+
if callback is not None:
|
| 160 |
+
deprecate(
|
| 161 |
+
"callback",
|
| 162 |
+
"1.0.0",
|
| 163 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 164 |
+
)
|
| 165 |
+
if callback_steps is not None:
|
| 166 |
+
deprecate(
|
| 167 |
+
"callback_steps",
|
| 168 |
+
"1.0.0",
|
| 169 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 173 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 174 |
+
|
| 175 |
+
# 0. Default height and width to unet
|
| 176 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 177 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 178 |
+
# to deal with lora scaling and other possible forward hooks
|
| 179 |
+
|
| 180 |
+
# 1. Check inputs. Raise error if not correct
|
| 181 |
+
self.check_inputs(
|
| 182 |
+
prompt,
|
| 183 |
+
height,
|
| 184 |
+
width,
|
| 185 |
+
callback_steps,
|
| 186 |
+
negative_prompt,
|
| 187 |
+
prompt_embeds,
|
| 188 |
+
negative_prompt_embeds,
|
| 189 |
+
ip_adapter_image,
|
| 190 |
+
ip_adapter_image_embeds,
|
| 191 |
+
callback_on_step_end_tensor_inputs,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self._guidance_scale = guidance_scale
|
| 195 |
+
self._guidance_rescale = guidance_rescale
|
| 196 |
+
self._clip_skip = clip_skip
|
| 197 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 198 |
+
self._interrupt = False
|
| 199 |
+
|
| 200 |
+
# 2. Define call parameters
|
| 201 |
+
if prompt is not None and isinstance(prompt, str):
|
| 202 |
+
batch_size = 1
|
| 203 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 204 |
+
batch_size = len(prompt)
|
| 205 |
+
else:
|
| 206 |
+
batch_size = prompt_embeds.shape[0]
|
| 207 |
+
|
| 208 |
+
device = self._execution_device
|
| 209 |
+
|
| 210 |
+
# 3. Encode input prompt
|
| 211 |
+
lora_scale = (
|
| 212 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 216 |
+
prompt,
|
| 217 |
+
device,
|
| 218 |
+
num_images_per_prompt,
|
| 219 |
+
self.do_classifier_free_guidance,
|
| 220 |
+
negative_prompt,
|
| 221 |
+
prompt_embeds=prompt_embeds,
|
| 222 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 223 |
+
lora_scale=lora_scale,
|
| 224 |
+
clip_skip=self.clip_skip,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 228 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 229 |
+
# to avoid doing two forward passes
|
| 230 |
+
if self.do_classifier_free_guidance:
|
| 231 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 232 |
+
|
| 233 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
| 234 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
| 235 |
+
ip_adapter_image,
|
| 236 |
+
ip_adapter_image_embeds,
|
| 237 |
+
device,
|
| 238 |
+
batch_size * num_images_per_prompt,
|
| 239 |
+
self.do_classifier_free_guidance,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# 4. Prepare timesteps
|
| 243 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 244 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# 5. Prepare latent variables
|
| 248 |
+
num_channels_latents = self.unet.config.in_channels
|
| 249 |
+
latents = self.prepare_latents(
|
| 250 |
+
batch_size * num_images_per_prompt,
|
| 251 |
+
num_channels_latents,
|
| 252 |
+
height,
|
| 253 |
+
width,
|
| 254 |
+
prompt_embeds.dtype,
|
| 255 |
+
device,
|
| 256 |
+
generator,
|
| 257 |
+
latents,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 261 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 262 |
+
|
| 263 |
+
# 6.1 Add image embeds for IP-Adapter
|
| 264 |
+
added_cond_kwargs = (
|
| 265 |
+
{"image_embeds": image_embeds}
|
| 266 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
|
| 267 |
+
else None
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# 6.2 Optionally get Guidance Scale Embedding
|
| 271 |
+
timestep_cond = None
|
| 272 |
+
if self.unet.config.time_cond_proj_dim is not None:
|
| 273 |
+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 274 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 275 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 276 |
+
).to(device=device, dtype=latents.dtype)
|
| 277 |
+
|
| 278 |
+
# 7. Denoising loop
|
| 279 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 280 |
+
self._num_timesteps = len(timesteps)
|
| 281 |
+
|
| 282 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 283 |
+
for i, t in enumerate(timesteps):
|
| 284 |
+
if self.interrupt:
|
| 285 |
+
continue
|
| 286 |
+
|
| 287 |
+
# expand the latents if we are doing classifier free guidance
|
| 288 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 289 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 290 |
+
|
| 291 |
+
# predict the noise residual
|
| 292 |
+
noise_pred = self.unet(
|
| 293 |
+
latent_model_input,
|
| 294 |
+
t,
|
| 295 |
+
encoder_hidden_states=prompt_embeds,
|
| 296 |
+
timestep_cond=timestep_cond,
|
| 297 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 298 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 299 |
+
return_dict=False,
|
| 300 |
+
)[0]
|
| 301 |
+
|
| 302 |
+
if self.do_classifier_free_guidance:
|
| 303 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 304 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 305 |
+
|
| 306 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 307 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
| 308 |
+
|
| 309 |
+
_output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
|
| 310 |
+
|
| 311 |
+
# [NOTE] 7.1.1 Get the current step index
|
| 312 |
+
pred_latent_sample = None
|
| 313 |
+
# [NOTE] 7.2 Get the unit vector of previous latents
|
| 314 |
+
post_latents = latents
|
| 315 |
+
v_t_2d = post_latents / (post_latents.norm(p=2,dim=(1,2,3), keepdim=True) + 1e-8) # Normalize v_t_2d
|
| 316 |
+
|
| 317 |
+
# [NOTE] 7.3 Get the latents and predicted latent sample
|
| 318 |
+
latents = _output.prev_sample
|
| 319 |
+
pred_latent_sample = _output.pred_original_sample if hasattr(_output, 'pred_original_sample') else None
|
| 320 |
+
del _output
|
| 321 |
+
|
| 322 |
+
# [NOTE] 7.4 Retrieve the difference of latents
|
| 323 |
+
delta_latents = latents - post_latents
|
| 324 |
+
delta_unit = (delta_latents * v_t_2d).sum(dim=(1,2,3), keepdim=True) # [B, 1, 1, 1]
|
| 325 |
+
|
| 326 |
+
# [NOTE] 7.5 Calculate the normal and tangential update vectors
|
| 327 |
+
normal_update_vector = delta_unit * v_t_2d
|
| 328 |
+
tangential_update_vector = delta_latents - normal_update_vector
|
| 329 |
+
|
| 330 |
+
eta_v = t_guidance_scale
|
| 331 |
+
eta_r = r_guidance_scale
|
| 332 |
+
|
| 333 |
+
# [NOTE] 7.6 Apply the tangential and normal updates to the latents
|
| 334 |
+
if t <= sta_tpd and t >= end_tpd and t_guidance_scale != 1.:
|
| 335 |
+
pass
|
| 336 |
+
else:
|
| 337 |
+
eta_v = 1.0
|
| 338 |
+
eta_r = 1.0
|
| 339 |
+
|
| 340 |
+
latents = post_latents + \
|
| 341 |
+
eta_r * normal_update_vector + \
|
| 342 |
+
eta_v * tangential_update_vector
|
| 343 |
+
|
| 344 |
+
if callback_on_step_end is not None:
|
| 345 |
+
callback_kwargs = {}
|
| 346 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 347 |
+
callback_kwargs[k] = locals()[k]
|
| 348 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 349 |
+
|
| 350 |
+
latents = callback_outputs.pop("latents", latents)
|
| 351 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 352 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 353 |
+
|
| 354 |
+
# call the callback, if provided
|
| 355 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 356 |
+
progress_bar.update()
|
| 357 |
+
if callback is not None and i % callback_steps == 0:
|
| 358 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 359 |
+
callback(step_idx, t, latents)
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if not output_type == "latent":
|
| 369 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
| 370 |
+
0
|
| 371 |
+
]
|
| 372 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
| 373 |
+
else:
|
| 374 |
+
image = latents
|
| 375 |
+
has_nsfw_concept = None
|
| 376 |
+
|
| 377 |
+
if has_nsfw_concept is None:
|
| 378 |
+
do_denormalize = [True] * image.shape[0]
|
| 379 |
+
else:
|
| 380 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
| 381 |
+
|
| 382 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
| 383 |
+
|
| 384 |
+
# Offload all models
|
| 385 |
+
self.maybe_free_model_hooks()
|
| 386 |
+
|
| 387 |
+
if not return_dict:
|
| 388 |
+
return (image, has_nsfw_concept)
|
| 389 |
+
|
| 390 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
pipelines/pipeline_tag_stablediffusion3.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import StableDiffusion3Pipeline
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import (
|
| 9 |
+
CLIPTextModelWithProjection,
|
| 10 |
+
CLIPTokenizer,
|
| 11 |
+
SiglipImageProcessor,
|
| 12 |
+
SiglipVisionModel,
|
| 13 |
+
T5EncoderModel,
|
| 14 |
+
T5TokenizerFast,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 18 |
+
from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
|
| 19 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
| 20 |
+
from diffusers.models.transformers import SD3Transformer2DModel
|
| 21 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 22 |
+
from diffusers.utils import (
|
| 23 |
+
USE_PEFT_BACKEND,
|
| 24 |
+
is_torch_xla_available,
|
| 25 |
+
logging,
|
| 26 |
+
replace_example_docstring,
|
| 27 |
+
scale_lora_layers,
|
| 28 |
+
unscale_lora_layers,
|
| 29 |
+
)
|
| 30 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 31 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 32 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
| 33 |
+
|
| 34 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
|
| 35 |
+
logger,
|
| 36 |
+
EXAMPLE_DOC_STRING,
|
| 37 |
+
calculate_shift,
|
| 38 |
+
retrieve_timesteps,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
class HalCalculatorTensor:
|
| 42 |
+
"""
|
| 43 |
+
Incrementally calculates the average of squared deviations from the mean
|
| 44 |
+
for a tensor of any shape. Each new tensor must have the same shape.
|
| 45 |
+
|
| 46 |
+
Hal(x) is computed elementwise:
|
| 47 |
+
Hal(x) = (1 / n) * sum((x_i - mean)^2)
|
| 48 |
+
which is effectively the population variance at each element.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, shape, timestep_range: int = 200):
|
| 52 |
+
"""
|
| 53 |
+
Initialize with a specific tensor shape.
|
| 54 |
+
|
| 55 |
+
:param shape: tuple describing the shape of each input tensor.
|
| 56 |
+
"""
|
| 57 |
+
self.n = 0
|
| 58 |
+
self.mean = torch.zeros(shape) # Running mean (same shape)
|
| 59 |
+
self.m2 = torch.zeros(shape) # Sum of squared deviations (same shape)
|
| 60 |
+
self.timestep_range = timestep_range
|
| 61 |
+
|
| 62 |
+
# 새로운 입력 간 차이(trajectory)를 저장할 리스트 (CPU에 저장)
|
| 63 |
+
self.trajectory = []
|
| 64 |
+
# 이전 입력 tensor를 저장하기 위한 변수
|
| 65 |
+
self.prev_x = None
|
| 66 |
+
|
| 67 |
+
def add(self, x: torch.Tensor) -> torch.Tensor:
|
| 68 |
+
"""
|
| 69 |
+
Add a new tensor and update the running statistics.
|
| 70 |
+
|
| 71 |
+
:param x: A new tensor of the same shape as specified in __init__.
|
| 72 |
+
:return: The current Hal(x) (elementwise average of squared deviations).
|
| 73 |
+
"""
|
| 74 |
+
# 이전 입력이 존재하면, 차이를 계산하여 trajectory에 저장 (항상 CPU에)
|
| 75 |
+
if self.prev_x is not None:
|
| 76 |
+
diff = (x - self.prev_x).detach().cpu()
|
| 77 |
+
self.trajectory.append(diff)
|
| 78 |
+
# 현재 입력을 prev_x로 업데이트 (복사본 생성)
|
| 79 |
+
self.prev_x = x.clone()
|
| 80 |
+
|
| 81 |
+
if self.n == 0:
|
| 82 |
+
# First tensor: mean = x, m2 stays zero
|
| 83 |
+
self.mean = x.clone()
|
| 84 |
+
self.n = 1
|
| 85 |
+
return self.item() # This will be all zeros if n=1
|
| 86 |
+
|
| 87 |
+
self.n += 1
|
| 88 |
+
# Welford's online update
|
| 89 |
+
delta = x - self.mean
|
| 90 |
+
self.mean += delta / self.n
|
| 91 |
+
delta2 = x - self.mean
|
| 92 |
+
self.m2 += delta * delta2
|
| 93 |
+
|
| 94 |
+
return self.item()
|
| 95 |
+
|
| 96 |
+
def item(self) -> torch.Tensor:
|
| 97 |
+
"""
|
| 98 |
+
Return the current elementwise average of squared deviations
|
| 99 |
+
(which is effectively the population variance for each element).
|
| 100 |
+
|
| 101 |
+
:return: A tensor of the same shape, containing Hal(x) values.
|
| 102 |
+
"""
|
| 103 |
+
if self.n < 2:
|
| 104 |
+
# With fewer than 2 values, variance is zero
|
| 105 |
+
return torch.zeros_like(self.mean)
|
| 106 |
+
return self.m2 / self.timestep_range
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# 1. StableDiffusionPipeline을 상속받는 새로운 클래스 정의
|
| 110 |
+
class StableDiffusion3TangentialDecomposedPipeline(StableDiffusion3Pipeline):
|
| 111 |
+
|
| 112 |
+
@torch.no_grad()
|
| 113 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 114 |
+
def __call__(
|
| 115 |
+
self,
|
| 116 |
+
prompt: Union[str, List[str]] = None,
|
| 117 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 118 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
| 119 |
+
height: Optional[int] = None,
|
| 120 |
+
width: Optional[int] = None,
|
| 121 |
+
num_inference_steps: int = 28,
|
| 122 |
+
sigmas: Optional[List[float]] = None,
|
| 123 |
+
guidance_scale: float = 7.0,
|
| 124 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 125 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 126 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
| 127 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 128 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 129 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 130 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 131 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 132 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 133 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 134 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
| 135 |
+
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
|
| 136 |
+
output_type: Optional[str] = "pil",
|
| 137 |
+
return_dict: bool = True,
|
| 138 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 139 |
+
clip_skip: Optional[int] = None,
|
| 140 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 141 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 142 |
+
max_sequence_length: int = 256,
|
| 143 |
+
skip_guidance_layers: List[int] = None,
|
| 144 |
+
skip_layer_guidance_scale: float = 2.8,
|
| 145 |
+
skip_layer_guidance_stop: float = 0.2,
|
| 146 |
+
skip_layer_guidance_start: float = 0.01,
|
| 147 |
+
mu: Optional[float] = None,
|
| 148 |
+
|
| 149 |
+
## Tangential Scailing Guidance specific parameters
|
| 150 |
+
t_guidance_scale: float = 1.0, # Scale for TGS
|
| 151 |
+
r_guidance_scale: float = 1.0, # Scale for radial guidance
|
| 152 |
+
|
| 153 |
+
## Apply range for each scaling
|
| 154 |
+
sta_tpd: int = 1000, # Start step for tangential scaling
|
| 155 |
+
end_tpd: int = 0, # End step for tangential scaling
|
| 156 |
+
):
|
| 157 |
+
r"""
|
| 158 |
+
Function invoked when calling the pipeline for generation.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 162 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 163 |
+
instead.
|
| 164 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 165 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 166 |
+
will be used instead
|
| 167 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
| 168 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
| 169 |
+
will be used instead
|
| 170 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 171 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 172 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 173 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 174 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 175 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 176 |
+
expense of slower inference.
|
| 177 |
+
sigmas (`List[float]`, *optional*):
|
| 178 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 179 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 180 |
+
will be used.
|
| 181 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 182 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 183 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 184 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 185 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 186 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 187 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 188 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 189 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 190 |
+
less than `1`).
|
| 191 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 192 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 193 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
| 194 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
| 195 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
| 196 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
| 197 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 198 |
+
The number of images to generate per prompt.
|
| 199 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 200 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 201 |
+
to make generation deterministic.
|
| 202 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 203 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 204 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 205 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 206 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 207 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 208 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 209 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 210 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 211 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 212 |
+
argument.
|
| 213 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 214 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 215 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 216 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 217 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 218 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 219 |
+
input argument.
|
| 220 |
+
ip_adapter_image (`PipelineImageInput`, *optional*):
|
| 221 |
+
Optional image input to work with IP Adapters.
|
| 222 |
+
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
|
| 223 |
+
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
|
| 224 |
+
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
|
| 225 |
+
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
|
| 226 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 227 |
+
The output format of the generate image. Choose between
|
| 228 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 229 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 230 |
+
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
|
| 231 |
+
a plain tuple.
|
| 232 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 233 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 234 |
+
`self.processor` in
|
| 235 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 236 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 237 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 238 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 239 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 240 |
+
`callback_on_step_end_tensor_inputs`.
|
| 241 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 242 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 243 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 244 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 245 |
+
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
| 246 |
+
skip_guidance_layers (`List[int]`, *optional*):
|
| 247 |
+
A list of integers that specify layers to skip during guidance. If not provided, all layers will be
|
| 248 |
+
used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
|
| 249 |
+
Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
|
| 250 |
+
skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
|
| 251 |
+
`skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
|
| 252 |
+
with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
|
| 253 |
+
with a scale of `1`.
|
| 254 |
+
skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 255 |
+
`skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
|
| 256 |
+
`skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
|
| 257 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
|
| 258 |
+
skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
|
| 259 |
+
`skip_guidance_layers` will start. The guidance will be applied to the layers specified in
|
| 260 |
+
`skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
|
| 261 |
+
StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
|
| 262 |
+
mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
|
| 263 |
+
|
| 264 |
+
Examples:
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
| 268 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
| 269 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 270 |
+
"""
|

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            prompt_3,
            height,
            width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._skip_layer_guidance_scale = skip_layer_guidance_scale
        self._clip_skip = clip_skip
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_3=prompt_3,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            device=device,
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
            if skip_guidance_layers is not None:
                original_prompt_embeds = prompt_embeds
                original_pooled_prompt_embeds = pooled_prompt_embeds
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        # 4. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, height, width = latents.shape
            image_seq_len = (height // self.transformer.config.patch_size) * (
                width // self.transformer.config.patch_size
            )
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16),
            )
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            **scheduler_kwargs,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # 6. Prepare image embeddings
        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
            else:
                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

        # 7. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                    should_skip_layers = (
                        True
                        if i > num_inference_steps * skip_layer_guidance_start
                        and i < num_inference_steps * skip_layer_guidance_stop
                        else False
                    )
                    if skip_guidance_layers is not None and should_skip_layers:
                        timestep = t.expand(latents.shape[0])
                        latent_model_input = latents
                        noise_pred_skip_layers = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=original_prompt_embeds,
                            pooled_projections=original_pooled_prompt_embeds,
                            joint_attention_kwargs=self.joint_attention_kwargs,
                            return_dict=False,
                            skip_layers=skip_guidance_layers,
                        )[0]
                        noise_pred = (
                            noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
                        )

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                output = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if t <= sta_tpd and t >= end_tpd:
                    # [NOTE] TAG: decompose the scheduler update into components tangential and
                    # normal to the current latent direction, then rescale each one separately.
                    post_latents = latents
                    # unit vector along the current latent x_t (per sample)
                    v_t_2d = post_latents / (post_latents.norm(p=2, dim=(1, 2, 3), keepdim=True) + 1e-8)

                    latents = output

                    # raw update produced by the scheduler step
                    delta_latents = latents - post_latents
                    # signed magnitude of the update along v_t (normal / radial part)
                    delta_unit = (delta_latents * v_t_2d).sum(dim=(1, 2, 3), keepdim=True)

                    normal_update_vector = delta_unit * v_t_2d
                    tangential_update_vector = delta_latents - normal_update_vector

                    eta_v = t_guidance_scale
                    eta_n = r_guidance_scale

                    # recombine with independent scales for the two components
                    latents = post_latents + \
                        eta_v * tangential_update_vector + \
                        eta_n * normal_update_vector
                else:  # [NOTE] Simple path (equal to the original scheduler update)
                    latents = output

                # [NOTE] Apple MPS dtype workaround
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                # if XLA_AVAILABLE:
                #     xm.mark_step()

        if output_type == "latent":
            image = latents

        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusion3PipelineOutput(images=image)
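The tangential/normal split used in the denoising loop above (and again in the SDXL pipeline below) can be read in isolation. The following is a minimal, self-contained sketch of that per-step update rule; the function name, tensor shapes, and demo values are illustrative assumptions for this sketch, not part of the shipped pipelines.

```python
# Sketch of the TAG-style update between two scheduler states (assumptions noted above).
import torch

def tag_update(latents: torch.Tensor,
               scheduler_output: torch.Tensor,
               t_guidance_scale: float = 1.0,
               r_guidance_scale: float = 1.0) -> torch.Tensor:
    """Split the scheduler update into components tangential and normal to the
    current latent direction, then rescale each component independently."""
    # unit vector along the current latent (per sample), as in the pipelines
    v_t = latents / (latents.norm(p=2, dim=(1, 2, 3), keepdim=True) + 1e-8)
    delta = scheduler_output - latents                              # raw scheduler update
    normal = (delta * v_t).sum(dim=(1, 2, 3), keepdim=True) * v_t   # radial (normal) part
    tangential = delta - normal                                     # tangential part
    return latents + t_guidance_scale * tangential + r_guidance_scale * normal

if __name__ == "__main__":
    x = torch.randn(1, 4, 64, 64)            # stand-in latent, not a real model output
    x_next = x + 0.1 * torch.randn_like(x)   # stand-in scheduler output
    out = tag_update(x, x_next, t_guidance_scale=1.2, r_guidance_scale=1.0)
    print(out.shape)  # torch.Size([1, 4, 64, 64])
```

With `t_guidance_scale == r_guidance_scale == 1.0` the update collapses to the plain scheduler step, which is exactly the "simple path" branch in both pipelines.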
pipelines/pipeline_tag_stablediffusionXL.py
ADDED
@@ -0,0 +1,585 @@
from diffusers import StableDiffusionXLPipeline

import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_invisible_watermark_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    is_invisible_watermark_available,
    is_torch_xla_available,
    logger,
    EXAMPLE_DOC_STRING,
    rescale_noise_cfg,
    retrieve_timesteps,
)
# from custom_hal import HalCalculatorTensor

# 1. Define a new pipeline class that inherits from StableDiffusionXLPipeline
class StableDiffusionXLTangentialDecomposedPipeline(StableDiffusionXLPipeline):

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        pooled_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],

        ## Representation Guidance specific parameters
        f_phi: Callable[[torch.Tensor], torch.Tensor] = None,  # Feature extractor
        V: torch.Tensor = None,  # Representative vectors
        c: int = 0,  # Class label for which to apply RepG
        rep_guidance_scale: float = 1.0,  # Scale for RepG
        distance_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
        ################################################

        ## Tangential Scaling Guidance specific parameters
        t_guidance_scale: float = 1.0,  # Scale for TGS
        r_guidance_scale: float = 1.0,  # Scale for radial guidance

        ## Apply range for each scaling
        sta_tpd: int = 1000,  # Start step for tangential scaling
        end_tpd: int = 0,  # End step for tangential scaling
        # r_start_step: int = 999,  # Start step for radial scaling
        # r_end_step: int = 0,  # End step for radial scaling

        **kwargs: Any,  # Additional arguments for future compatibility
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
                Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                Guidance rescale factor should fix overexposure when using zero terminal SNR.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a target image resolution. It should be the
                same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2
                of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=self.clip_skip,
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps, sigmas
        )

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 8. Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if (
            self.denoising_end is not None
            and isinstance(self.denoising_end, float)
            and self.denoising_end > 0
            and self.denoising_end < 1
        ):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # 9. Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                    added_cond_kwargs["image_embeds"] = image_embeds
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype

                # [NOTE] Apply the Tangential Scaling Guidance
                output = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if t <= sta_tpd and t >= end_tpd:
                    # Apply Tangential Scaling Guidance: decompose the scheduler update into
                    # components tangential and normal to the current latent direction.
                    post_latents = latents
                    # unit vector along the current latent x_t (per sample)
                    v_t_2d = post_latents / (post_latents.norm(p=2, dim=(1, 2, 3), keepdim=True) + 1e-8)

                    latents = output

                    # raw update produced by the scheduler step
                    delta_latents = latents - post_latents
                    # signed magnitude of the update along v_t (normal / radial part)
                    delta_unit = (delta_latents * v_t_2d).sum(dim=(1, 2, 3), keepdim=True)

                    normal_update_vector = delta_unit * v_t_2d
                    tangential_update_vector = delta_latents - normal_update_vector

                    eta_v = t_guidance_scale
                    eta_n = r_guidance_scale

                    # recombine with independent scales for the two components
                    latents = post_latents + \
                        eta_v * tangential_update_vector + \
                        eta_n * normal_update_vector
                else:  # [NOTE] Simple path (equal to the original scheduler update)
                    latents = output

                # [NOTE] Apple MPS dtype workaround
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                    negative_pooled_prompt_embeds = callback_outputs.pop(
                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                    )
                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

                # [NOTE] Disabled XLA (TPU) support for now
                # if XLA_AVAILABLE:
                #     xm.mark_step()

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
            elif latents.dtype != self.vae.dtype:
                if torch.backends.mps.is_available():
                    # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    self.vae = self.vae.to(latents.dtype)

            # unscale/denormalize the latents
            # denormalize with the mean and std if available and not None
            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
            if has_latents_mean and has_latents_std:
                latents_mean = (
                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
            else:
                latents = latents / self.vae.config.scaling_factor

            image = self.vae.decode(latents, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
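For reference, a minimal usage sketch of the SDXL pipeline above, assuming the class is importable from this repo's `pipelines` package and that an SDXL checkpoint is available locally or on the Hub; the checkpoint id, guidance scales, and step range below are illustrative choices, not values prescribed by this commit.

```python
# Usage sketch (assumptions noted above); a CUDA GPU is assumed for fp16 inference.
import torch
from pipelines.pipeline_tag_stablediffusionXL import StableDiffusionXLTangentialDecomposedPipeline

pipe = StableDiffusionXLTangentialDecomposedPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=30,
    guidance_scale=5.0,
    t_guidance_scale=1.2,   # scale for the tangential component of each latent update
    r_guidance_scale=1.0,   # scale for the normal (radial) component
    sta_tpd=1000,           # apply TAG from the noisiest timesteps...
    end_tpd=0,              # ...down to the final timestep
).images[0]
image.save("tag_sdxl_sample.png")
```

Setting both scales to 1.0 reproduces the stock StableDiffusionXLPipeline behavior, which makes A/B comparisons against the unmodified sampler straightforward.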