Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import torch | |
| import spaces | |
# Import every custom pipeline class this demo can use.
| from diffusers import DiffusionPipeline | |
# Assumes a local 'pipelines' package (folder) is present next to this file.
| from pipelines.pipeline_tag_stablediffusion import StableDiffusionTangentialDecomposedPipeline | |
| from pipelines.pipeline_tag_stablediffusion3 import StableDiffusion3TangentialDecomposedPipeline | |
| from pipelines.pipeline_tag_stablediffusionXL import StableDiffusionXLTangentialDecomposedPipeline | |
# --- Configuration ---
# Every per-model table below is keyed by the same display names that populate
# the model dropdown; keep their keys in sync when adding or removing a model.
MODEL_MAP = {
    # The original "runwayml/stable-diffusion-v1-5" Hub repository was taken
    # down; this is the maintained mirror of the same weights.
    "SD 1.5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "SD 2.1": "stabilityai/stable-diffusion-2-1",
    "SDXL": "stabilityai/stable-diffusion-xl-base-1.0",
    "SD 3": "stabilityai/stable-diffusion-3-medium-diffusers",
}

# Default square output resolution (pixels) applied to the width/height sliders.
RESOLUTION_MAP = {"SD 1.5": 512, "SD 2.1": 768, "SDXL": 1024, "SD 3": 1024}

# Default seed per model, applied to the seed slider when the model changes.
SEED_MAP = {"SD 1.5": 850728, "SD 2.1": 944905, "SDXL": 450040818, "SD 3": 282386105}

# Default TAG scale per model, applied to the "TAG Scale" slider.
TAG_SCALE_MAP = {
    "SD 1.5": 1.15,  # default value
    "SD 2.1": 1.15,  # default value
    "SDXL": 1.20,
    "SD 3": 1.08,
}

# Pipeline class used to load each model; SD 1.5 and SD 2.1 share one class.
PIPELINE_MAP = {
    "SD 1.5": StableDiffusionTangentialDecomposedPipeline,
    "SD 2.1": StableDiffusionTangentialDecomposedPipeline,
    "SDXL": StableDiffusionXLTangentialDecomposedPipeline,
    "SD 3": StableDiffusion3TangentialDecomposedPipeline,
}

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when a GPU is available, full precision on CPU.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# The currently loaded pipeline and the Hub id it was loaded from.
# Both are (re)assigned by load_pipeline().
pipe = None
current_model_id = None

# Upper bound for random seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# --- Functions ---
def load_pipeline(model_name, progress):
    """Load the pipeline for ``model_name`` and make it the active global ``pipe``.

    The previously loaded pipeline is released before the new one is loaded,
    so its memory can be reclaimed first. ``pipe`` and ``current_model_id``
    are published together, only after the new pipeline has been moved to
    ``device``. If loading fails part-way, both are left as ``None`` and the
    next call retries the load. They can never disagree about which model is
    active.

    Args:
        model_name: Key into MODEL_MAP / PIPELINE_MAP (e.g. "SDXL").
        progress: Gradio progress callable, invoked as ``progress(fraction, desc=...)``.

    Raises:
        KeyError: If ``model_name`` is not a configured model. This is raised
            before the current pipeline is touched.
    """
    import gc

    global pipe, current_model_id
    model_id = MODEL_MAP[model_name]
    pipeline_class = PIPELINE_MAP[model_name]
    progress(0, desc=f"Loading model: {model_id} with {pipeline_class.__name__}...")

    # Drop every reference to the old pipeline before loading the new weights.
    pipe = None
    current_model_id = None
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    load_kwargs = {"torch_dtype": torch_dtype}
    if model_name == "SD 3":
        # Load SD3 without its third text encoder/tokenizer
        # (presumably to save memory; this demo only uses empty prompts).
        load_kwargs.update(text_encoder_3=None, tokenizer_3=None)
    new_pipe = pipeline_class.from_pretrained(model_id, **load_kwargs).to(device)

    # Publish the pipeline and its id together so they always match.
    pipe = new_pipe
    current_model_id = model_id
    progress(1)
def update_model_defaults(model_name):
    """Build Gradio updates that reset the generation controls to a model's defaults.

    The returned tuple is ordered as: width, height, seed, the "Randomize seed"
    checkbox (always cleared) and the TAG scale.
    """
    resolution = RESOLUTION_MAP[model_name]
    default_seed = SEED_MAP[model_name]
    default_tag_scale = TAG_SCALE_MAP[model_name]

    width_update = gr.update(value=resolution)
    height_update = gr.update(value=resolution)
    seed_update = gr.update(value=default_seed)
    randomize_update = gr.update(value=False)  # untick "Randomize seed"
    tag_scale_update = gr.update(value=default_tag_scale)
    return width_update, height_update, seed_update, randomize_update, tag_scale_update
# On ZeroGPU hardware a GPU is attached only while a @spaces.GPU-decorated
# function runs (outside ZeroGPU the decorator has no effect).
# NOTE(review): a model switch plus two generations may exceed the default
# GPU time budget; pass duration=... here if requests time out.
@spaces.GPU
def infer(
    model_name,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,  # user-specified TAG scale
    num_inference_steps,
    guidance_start_timestep,
    guidance_end_timestep,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate two unconditional images from one seed for side-by-side comparison.

    Loads ``model_name`` first if it is not the active pipeline. Both images use
    an empty prompt, ``guidance_scale=0`` and identically seeded generators.
    They differ only in the TAG scale: 1.0 for the baseline image and
    ``guidance_scale`` for the other.

    Returns:
        ``((baseline_image, custom_scale_image), seed)``. The seed is re-drawn
        when ``randomize_seed`` is set, so it is returned to update the seed slider.
    """
    model_id = MODEL_MAP[model_name]
    if model_id != current_model_id:
        gr.Info(f"Changing model to {model_name}. Please wait...")
        load_pipeline(model_name, progress)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    def _generate(tag_scale):
        """Run one unconditional generation with a freshly seeded generator."""
        # A fresh generator per image keeps both images on the same noise.
        generator = torch.Generator(device=device).manual_seed(seed)
        return pipe(
            prompt="",
            guidance_scale=0.,
            # Slider values may arrive as floats; the pipeline expects ints.
            num_inference_steps=int(num_inference_steps),
            width=int(width),
            height=int(height),
            generator=generator,
            sta_tpd=guidance_start_timestep,
            end_tpd=guidance_end_timestep,
            t_guidance_scale=tag_scale,
        ).images[0]

    image_custom_scale = _generate(guidance_scale)
    image_fixed_scale = _generate(1.0)
    # ImageSlider's value type is a pair (tuple) of images.
    return (image_fixed_scale, image_custom_scale), seed
# --- UI layout (Gradio) ---
css = """
#col-container { margin: 0 auto; max-width: 720px; }
"""

# Model whose defaults pre-populate the controls on first page load.
DEFAULT_MODEL = "SDXL"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Tangential Amplifying Guidance Demo")
        model_selector = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_MAP),
            value=DEFAULT_MODEL,
        )

        with gr.Row():
            # Display-only prompt box: generation is always unconditional.
            prompt = gr.Text(
                label="Prompt (Disabled)",
                show_label=True,
                max_lines=1,
                placeholder="Unconditional generation mode. This input is ignored.",
                container=True,
                interactive=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        # Shows the image pair returned by infer: fixed-scale result first,
        # user-scale result second.
        result_slider = gr.ImageSlider(
            label="Result Comparison (Fixed Scale vs. Your Scale)",
            show_label=True,
        )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=SEED_MAP[DEFAULT_MODEL],
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=RESOLUTION_MAP[DEFAULT_MODEL],
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=1024,
                    step=64,
                    value=RESOLUTION_MAP[DEFAULT_MODEL],
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="TAG Scale",
                    minimum=1.0,
                    maximum=1.3,
                    step=0.01,
                    value=TAG_SCALE_MAP[DEFAULT_MODEL],
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=20,
                    maximum=50,
                    step=1,
                    value=50,
                )

            with gr.Row():
                guidance_start_timestep = gr.Slider(
                    label="Guidance Start Timestep",
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=999,
                )
                guidance_end_timestep = gr.Slider(
                    label="Guidance End Timestep",
                    minimum=0,
                    maximum=1000,
                    step=1,
                    value=0,
                )

    # --- Event listeners ---
    # Changing the model resets resolution, seed, the randomize toggle and TAG scale.
    model_selector.change(
        fn=update_model_defaults,
        inputs=[model_selector],
        outputs=[width, height, seed, randomize_seed, guidance_scale],
    )

    # infer returns (image pair, seed): the pair fills the ImageSlider and the
    # seed actually used is written back to the seed slider.
    generation_inputs = [
        model_selector,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        guidance_start_timestep,
        guidance_end_timestep,
    ]
    run_button.click(
        fn=infer,
        inputs=generation_inputs,
        outputs=[result_slider, seed],
    )

if __name__ == "__main__":
    demo.launch(debug=True)