hyeoncho01's picture
rdn_box_false
3d1192b
raw
history blame
7.04 kB
import gradio as gr
import numpy as np
import random
import torch
import spaces
# Import every pipeline class used by this demo.
from diffusers import DiffusionPipeline
# Assumes a local 'pipelines' package (folder) is present next to this file.
from pipelines.pipeline_tag_stablediffusion import StableDiffusionTangentialDecomposedPipeline
from pipelines.pipeline_tag_stablediffusion3 import StableDiffusion3TangentialDecomposedPipeline
from pipelines.pipeline_tag_stablediffusionXL import StableDiffusionXLTangentialDecomposedPipeline
# --- Configuration ---
# Hugging Face Hub repo id for each selectable model.
# NOTE: "runwayml/stable-diffusion-v1-5" is no longer hosted on the Hub; the
# "stable-diffusion-v1-5" organization hosts the maintained mirror of the same weights.
MODEL_MAP = {
    "SD 1.5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "SD 2.1": "stabilityai/stable-diffusion-2-1",
    "SDXL": "stabilityai/stable-diffusion-xl-base-1.0",
    "SD 3": "stabilityai/stable-diffusion-3-medium-diffusers",
}
# Default square resolution (used for both width and height) per model.
RESOLUTION_MAP = {"SD 1.5": 512, "SD 2.1": 768, "SDXL": 1024, "SD 3": 1024}
# Default seed per model, applied when the model is selected in the UI.
SEED_MAP = {"SD 1.5": 850728, "SD 2.1": 944905, "SDXL": 450040818, "SD 3": 282386105}
# Default TAG scale per model (forwarded to the pipeline as `t_guidance_scale`).
TAG_SCALE_MAP = {
    "SD 1.5": 1.15,  # default
    "SD 2.1": 1.15,  # default
    "SDXL": 1.20,
    "SD 3": 1.08,
}
# Pipeline class used to load each model; SD 1.5 and SD 2.1 share one class.
PIPELINE_MAP = {
    "SD 1.5": StableDiffusionTangentialDecomposedPipeline,
    "SD 2.1": StableDiffusionTangentialDecomposedPipeline,
    "SDXL": StableDiffusionXLTangentialDecomposedPipeline,
    "SD 3": StableDiffusion3TangentialDecomposedPipeline,
}
# Use the GPU with half precision when CUDA is available, otherwise CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Currently loaded pipeline and the repo id it came from; set lazily by load_pipeline().
pipe = None
current_model_id = None
# Largest seed the UI offers: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# --- Functions ---
def load_pipeline(model_name, progress):
    """Load the pipeline for ``model_name`` into the module-level ``pipe``.

    Any previously loaded pipeline is released before the new one is loaded,
    so two models are never held in memory at the same time.

    Args:
        model_name: Key into MODEL_MAP / PIPELINE_MAP (e.g. "SDXL").
        progress: Gradio progress callback, called as ``progress(frac, desc=...)``.

    Raises:
        KeyError: if ``model_name`` is not a known model.
    """
    import gc

    global pipe, current_model_id
    model_id = MODEL_MAP[model_name]
    pipeline_class = PIPELINE_MAP[model_name]
    progress(0, desc=f"Loading model: {model_id} with {pipeline_class.__name__}...")

    # Drop the old pipeline first so its weights can be freed before the new
    # model is allocated. Also clear current_model_id: if loading fails below,
    # the next request retries the load instead of calling a stale/None pipe.
    pipe = None
    current_model_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    load_kwargs = {"torch_dtype": torch_dtype}
    if model_name == "SD 3":
        # Load SD3 without its third text encoder/tokenizer (the T5 encoder in
        # diffusers' SD3 pipeline), which reduces memory use.
        load_kwargs.update(text_encoder_3=None, tokenizer_3=None)
    pipe = pipeline_class.from_pretrained(model_id, **load_kwargs).to(device)
    current_model_id = model_id
    progress(1)
def update_model_defaults(model_name):
    """Return UI updates that reset the controls to ``model_name``'s defaults.

    The returned 5-tuple of ``gr.update`` objects targets, in order: width,
    height, seed, the 'Randomize seed' checkbox (always unchecked), and the
    TAG scale slider.
    """
    side = RESOLUTION_MAP[model_name]
    new_values = (
        side,                         # width
        side,                         # height
        SEED_MAP[model_name],         # seed
        False,                        # uncheck 'Randomize seed'
        TAG_SCALE_MAP[model_name],    # TAG scale
    )
    return tuple(gr.update(value=v) for v in new_values)
@spaces.GPU
def infer(
    model_name,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,  # user-selected TAG scale
    num_inference_steps,
    guidance_start_timestep,
    guidance_end_timestep,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a pair of unconditional images for side-by-side comparison.

    Both images use the same seed and settings. They differ only in the TAG
    scale (``t_guidance_scale``): one uses a fixed value of 1.0, the other the
    user's ``guidance_scale``. The pipeline for ``model_name`` is (re)loaded
    first if it is not the one currently in memory.

    Returns:
        ``([image_fixed_scale, image_custom_scale], seed)``. ``seed`` is the
        integer seed actually used, which is freshly drawn when
        ``randomize_seed`` is set.
    """
    if MODEL_MAP[model_name] != current_model_id:
        gr.Info(f"Changing model to {model_name}. Please wait...")
        load_pipeline(model_name, progress)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Slider values may arrive as floats; the pipeline needs integer sizes and
    # step counts (seed was already cast for the same reason).
    shared_kwargs = dict(
        prompt="",  # empty prompt: unconditional generation
        guidance_scale=0.,  # presumably disables CFG in these pipelines; only TAG applies
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        sta_tpd=guidance_start_timestep,
        end_tpd=guidance_end_timestep,
    )

    def _generate(tag_scale):
        """Run the pipeline once with a fresh generator seeded with ``seed``."""
        generator = torch.Generator(device=device).manual_seed(seed)
        return pipe(generator=generator, t_guidance_scale=tag_scale, **shared_kwargs).images[0]

    # Baseline TAG value used for the comparison image.
    fixed_tag_value = 1.0
    image_custom_scale = _generate(guidance_scale)
    image_fixed_scale = _generate(fixed_tag_value)
    return [image_fixed_scale, image_custom_scale], seed
# --- UI layout (Gradio) ---
css = """
#col-container { margin: 0 auto; max-width: 720px; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Tangential Amplifying Guidance Demo")
        model_selector = gr.Dropdown(
            choices=list(MODEL_MAP),
            value="SDXL",
            label="Select Model",
        )

        with gr.Row():
            # Display-only prompt box: generation is always unconditional.
            prompt = gr.Text(
                label="Prompt (Disabled)",
                placeholder="Unconditional generation mode. This input is ignored.",
                interactive=False,
                container=True,
                show_label=True,
                max_lines=1,
            )
            run_button = gr.Button("Run", variant="primary", scale=0)

        # Two-image comparison slider; infer() returns (fixed-scale, user-scale).
        result_slider = gr.ImageSlider(
            show_label=True,
            label="Result Comparison (Fixed Scale vs. Your Scale)",
        )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=SEED_MAP["SDXL"],
                label="Seed",
            )
            randomize_seed = gr.Checkbox(value=False, label="Randomize seed")

            with gr.Row():
                width = gr.Slider(
                    minimum=256, maximum=1024, step=64,
                    value=RESOLUTION_MAP["SDXL"], label="Width",
                )
                height = gr.Slider(
                    minimum=256, maximum=1024, step=64,
                    value=RESOLUTION_MAP["SDXL"], label="Height",
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=1.3, step=0.01,
                    value=TAG_SCALE_MAP["SDXL"], label="TAG Scale",
                )
                num_inference_steps = gr.Slider(
                    minimum=20, maximum=50, step=1,
                    value=50, label="Inference Steps",
                )

            with gr.Row():
                guidance_start_timestep = gr.Slider(
                    minimum=0, maximum=1000, step=1,
                    value=999, label="Guidance Start Timestep",
                )
                guidance_end_timestep = gr.Slider(
                    minimum=0, maximum=1000, step=1,
                    value=0, label="Guidance End Timestep",
                )

    # --- Event listeners ---
    # Changing the model resets resolution, seed, randomize flag and TAG scale.
    model_selector.change(
        fn=update_model_defaults,
        inputs=[model_selector],
        outputs=[width, height, seed, randomize_seed, guidance_scale],
    )

    # Order must match infer()'s positional parameters.
    infer_inputs = [
        model_selector,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        guidance_start_timestep,
        guidance_end_timestep,
    ]
    # infer() returns (image pair for the slider, seed actually used).
    run_button.click(fn=infer, inputs=infer_inputs, outputs=[result_slider, seed])

if __name__ == "__main__":
    demo.launch(debug=True)