removed src folder
This view is limited to 50 files because it contains too many changes.
- src/__init__.py +0 -0
- src/app.py +0 -554
- src/app_settings.py +0 -124
- src/backend/__init__.py +0 -0
- src/backend/annotators/canny_control.py +0 -15
- src/backend/annotators/control_interface.py +0 -12
- src/backend/annotators/depth_control.py +0 -15
- src/backend/annotators/image_control_factory.py +0 -31
- src/backend/annotators/lineart_control.py +0 -11
- src/backend/annotators/mlsd_control.py +0 -10
- src/backend/annotators/normal_control.py +0 -10
- src/backend/annotators/pose_control.py +0 -10
- src/backend/annotators/shuffle_control.py +0 -10
- src/backend/annotators/softedge_control.py +0 -10
- src/backend/api/mcp_server.py +0 -97
- src/backend/api/models/response.py +0 -16
- src/backend/api/web.py +0 -112
- src/backend/base64_image.py +0 -21
- src/backend/controlnet.py +0 -90
- src/backend/device.py +0 -23
- src/backend/gguf/gguf_diffusion.py +0 -319
- src/backend/gguf/sdcpp_types.py +0 -104
- src/backend/image_saver.py +0 -75
- src/backend/lcm_text_to_image.py +0 -577
- src/backend/lora.py +0 -136
- src/backend/models/device.py +0 -9
- src/backend/models/gen_images.py +0 -17
- src/backend/models/lcmdiffusion_setting.py +0 -76
- src/backend/models/upscale.py +0 -9
- src/backend/openvino/custom_ov_model_vae_decoder.py +0 -21
- src/backend/openvino/flux_pipeline.py +0 -36
- src/backend/openvino/ov_hc_stablediffusion_pipeline.py +0 -93
- src/backend/openvino/ovflux.py +0 -675
- src/backend/openvino/pipelines.py +0 -75
- src/backend/openvino/stable_diffusion_engine.py +0 -1817
- src/backend/pipelines/lcm.py +0 -122
- src/backend/pipelines/lcm_lora.py +0 -81
- src/backend/tiny_decoder.py +0 -32
- src/backend/upscale/aura_sr.py +0 -1004
- src/backend/upscale/aura_sr_upscale.py +0 -9
- src/backend/upscale/edsr_upscale_onnx.py +0 -37
- src/backend/upscale/tiled_upscale.py +0 -237
- src/backend/upscale/upscaler.py +0 -52
- src/constants.py +0 -25
- src/context.py +0 -85
- src/frontend/cli_interactive.py +0 -661
- src/frontend/gui/app_window.py +0 -595
- src/frontend/gui/base_widget.py +0 -199
- src/frontend/gui/image_generator_worker.py +0 -37
- src/frontend/gui/image_variations_widget.py +0 -35
 
    	
src/__init__.py DELETED
File without changes

src/app.py DELETED
@@ -1,554 +0,0 @@
-import json
-from argparse import ArgumentParser
-
-from PIL import Image
-
-import constants
-from backend.controlnet import controlnet_settings_from_dict
-from backend.device import get_device_name
-from backend.models.gen_images import ImageFormat
-from backend.models.lcmdiffusion_setting import DiffusionTask
-from backend.upscale.tiled_upscale import generate_upscaled_image
-from constants import APP_VERSION, DEVICE
-from frontend.webui.image_variations_ui import generate_image_variations
-from models.interface_types import InterfaceType
-from paths import FastStableDiffusionPaths, ensure_path
-from state import get_context, get_settings
-from utils import show_system_info
-
-parser = ArgumentParser(description=f"FAST SD CPU {constants.APP_VERSION}")
-parser.add_argument(
-    "-s",
-    "--share",
-    action="store_true",
-    help="Create sharable link(Web UI)",
-    required=False,
-)
-group = parser.add_mutually_exclusive_group(required=False)
-group.add_argument(
-    "-g",
-    "--gui",
-    action="store_true",
-    help="Start desktop GUI",
-)
-group.add_argument(
-    "-w",
-    "--webui",
-    action="store_true",
-    help="Start Web UI",
-)
-group.add_argument(
-    "-a",
-    "--api",
-    action="store_true",
-    help="Start Web API server",
-)
-group.add_argument(
-    "-m",
-    "--mcp",
-    action="store_true",
-    help="Start MCP(Model Context Protocol) server",
-)
-group.add_argument(
-    "-r",
-    "--realtime",
-    action="store_true",
-    help="Start realtime inference UI(experimental)",
-)
-group.add_argument(
-    "-v",
-    "--version",
-    action="store_true",
-    help="Version",
-)
-
-parser.add_argument(
-    "-b",
-    "--benchmark",
-    action="store_true",
-    help="Run inference benchmark on the selected device",
-)
-parser.add_argument(
-    "--lcm_model_id",
-    type=str,
-    help="Model ID or path,Default stabilityai/sd-turbo",
-    default="stabilityai/sd-turbo",
-)
-parser.add_argument(
-    "--openvino_lcm_model_id",
-    type=str,
-    help="OpenVINO Model ID or path,Default rupeshs/sd-turbo-openvino",
-    default="rupeshs/sd-turbo-openvino",
-)
-parser.add_argument(
-    "--prompt",
-    type=str,
-    help="Describe the image you want to generate",
-    default="",
-)
-parser.add_argument(
-    "--negative_prompt",
-    type=str,
-    help="Describe what you want to exclude from the generation",
-    default="",
-)
-parser.add_argument(
-    "--image_height",
-    type=int,
-    help="Height of the image",
-    default=512,
-)
-parser.add_argument(
-    "--image_width",
-    type=int,
-    help="Width of the image",
-    default=512,
-)
-parser.add_argument(
-    "--inference_steps",
-    type=int,
-    help="Number of steps,default : 1",
-    default=1,
-)
-parser.add_argument(
-    "--guidance_scale",
-    type=float,
-    help="Guidance scale,default : 1.0",
-    default=1.0,
-)
-
-parser.add_argument(
-    "--number_of_images",
-    type=int,
-    help="Number of images to generate ,default : 1",
-    default=1,
-)
-parser.add_argument(
-    "--seed",
-    type=int,
-    help="Seed,default : -1 (disabled) ",
-    default=-1,
-)
-parser.add_argument(
-    "--use_openvino",
-    action="store_true",
-    help="Use OpenVINO model",
-)
-
-parser.add_argument(
-    "--use_offline_model",
-    action="store_true",
-    help="Use offline model",
-)
-parser.add_argument(
-    "--clip_skip",
-    type=int,
-    help="CLIP Skip (1-12), default : 1 (disabled) ",
-    default=1,
-)
-parser.add_argument(
-    "--token_merging",
-    type=float,
-    help="Token merging scale, 0.0 - 1.0, default : 0.0",
-    default=0.0,
-)
-
-parser.add_argument(
-    "--use_safety_checker",
-    action="store_true",
-    help="Use safety checker",
-)
-parser.add_argument(
-    "--use_lcm_lora",
-    action="store_true",
-    help="Use LCM-LoRA",
-)
-parser.add_argument(
-    "--base_model_id",
-    type=str,
-    help="LCM LoRA base model ID,Default Lykon/dreamshaper-8",
-    default="Lykon/dreamshaper-8",
-)
-parser.add_argument(
-    "--lcm_lora_id",
-    type=str,
-    help="LCM LoRA model ID,Default latent-consistency/lcm-lora-sdv1-5",
-    default="latent-consistency/lcm-lora-sdv1-5",
-)
-parser.add_argument(
-    "-i",
-    "--interactive",
-    action="store_true",
-    help="Interactive CLI mode",
-)
-parser.add_argument(
-    "-t",
-    "--use_tiny_auto_encoder",
-    action="store_true",
-    help="Use tiny auto encoder for SD (TAESD)",
-)
-parser.add_argument(
-    "-f",
-    "--file",
-    type=str,
-    help="Input image for img2img mode",
-    default="",
-)
-parser.add_argument(
-    "--img2img",
-    action="store_true",
-    help="img2img mode; requires input file via -f argument",
-)
-parser.add_argument(
-    "--batch_count",
-    type=int,
-    help="Number of sequential generations",
-    default=1,
-)
-parser.add_argument(
-    "--strength",
-    type=float,
-    help="Denoising strength for img2img and Image variations",
-    default=0.3,
-)
-parser.add_argument(
-    "--sdupscale",
-    action="store_true",
-    help="Tiled SD upscale,works only for the resolution 512x512,(2x upscale)",
-)
-parser.add_argument(
-    "--upscale",
-    action="store_true",
-    help="EDSR SD upscale ",
-)
-parser.add_argument(
-    "--custom_settings",
-    type=str,
-    help="JSON file containing custom generation settings",
-    default=None,
-)
-parser.add_argument(
-    "--usejpeg",
-    action="store_true",
-    help="Images will be saved as JPEG format",
-)
-parser.add_argument(
-    "--noimagesave",
-    action="store_true",
-    help="Disable image saving",
-)
-parser.add_argument(
-    "--imagequality", type=int, help="Output image quality [0 to 100]", default=90
-)
-parser.add_argument(
-    "--lora",
-    type=str,
-    help="LoRA model full path e.g D:\lora_models\CuteCartoon15V-LiberteRedmodModel-Cartoon-CuteCartoonAF.safetensors",
-    default=None,
-)
-parser.add_argument(
-    "--lora_weight",
-    type=float,
-    help="LoRA adapter weight [0 to 1.0]",
-    default=0.5,
-)
-parser.add_argument(
-    "--port",
-    type=int,
-    help="Web server port",
-    default=8000,
-)
-
-args = parser.parse_args()
-
-if args.version:
-    print(APP_VERSION)
-    exit()
-
-# parser.print_help()
-print("FastSD CPU - ", APP_VERSION)
-show_system_info()
-print(f"Using device : {constants.DEVICE}")
-
-
-if args.webui:
-    app_settings = get_settings()
-else:
-    app_settings = get_settings()
-
-print(f"Output path : {app_settings.settings.generated_images.path}")
-ensure_path(app_settings.settings.generated_images.path)
-
-print(f"Found {len(app_settings.lcm_models)} LCM models in config/lcm-models.txt")
-print(
-    f"Found {len(app_settings.stable_diffsuion_models)} stable diffusion models in config/stable-diffusion-models.txt"
-)
-print(
-    f"Found {len(app_settings.lcm_lora_models)} LCM-LoRA models in config/lcm-lora-models.txt"
-)
-print(
-    f"Found {len(app_settings.openvino_lcm_models)} OpenVINO LCM models in config/openvino-lcm-models.txt"
-)
-
-if args.noimagesave:
-    app_settings.settings.generated_images.save_image = False
-else:
-    app_settings.settings.generated_images.save_image = True
-
-app_settings.settings.generated_images.save_image_quality = args.imagequality
-
-if not args.realtime:
-    # To minimize realtime mode dependencies
-    from backend.upscale.upscaler import upscale_image
-    from frontend.cli_interactive import interactive_mode
-
-if args.gui:
-    from frontend.gui.ui import start_gui
-
-    print("Starting desktop GUI mode(Qt)")
-    start_gui(
-        [],
-        app_settings,
-    )
-elif args.webui:
-    from frontend.webui.ui import start_webui
-
-    print("Starting web UI mode")
-    start_webui(
-        args.share,
-    )
-elif args.realtime:
-    from frontend.webui.realtime_ui import start_realtime_text_to_image
-
-    print("Starting realtime text to image(EXPERIMENTAL)")
-    start_realtime_text_to_image(args.share)
-elif args.api:
-    from backend.api.web import start_web_server
-
-    start_web_server(args.port)
-elif args.mcp:
-    from backend.api.mcp_server import start_mcp_server
-
-    start_mcp_server(args.port)
-else:
-    context = get_context(InterfaceType.CLI)
-    config = app_settings.settings
-
-    if args.use_openvino:
-        config.lcm_diffusion_setting.openvino_lcm_model_id = args.openvino_lcm_model_id
-    else:
-        config.lcm_diffusion_setting.lcm_model_id = args.lcm_model_id
-
-    config.lcm_diffusion_setting.prompt = args.prompt
-    config.lcm_diffusion_setting.negative_prompt = args.negative_prompt
-    config.lcm_diffusion_setting.image_height = args.image_height
-    config.lcm_diffusion_setting.image_width = args.image_width
-    config.lcm_diffusion_setting.guidance_scale = args.guidance_scale
-    config.lcm_diffusion_setting.number_of_images = args.number_of_images
-    config.lcm_diffusion_setting.inference_steps = args.inference_steps
-    config.lcm_diffusion_setting.strength = args.strength
-    config.lcm_diffusion_setting.seed = args.seed
-    config.lcm_diffusion_setting.use_openvino = args.use_openvino
-    config.lcm_diffusion_setting.use_tiny_auto_encoder = args.use_tiny_auto_encoder
-    config.lcm_diffusion_setting.use_lcm_lora = args.use_lcm_lora
-    config.lcm_diffusion_setting.lcm_lora.base_model_id = args.base_model_id
-    config.lcm_diffusion_setting.lcm_lora.lcm_lora_id = args.lcm_lora_id
-    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
-    config.lcm_diffusion_setting.lora.enabled = False
-    config.lcm_diffusion_setting.lora.path = args.lora
-    config.lcm_diffusion_setting.lora.weight = args.lora_weight
-    config.lcm_diffusion_setting.lora.fuse = True
-    if config.lcm_diffusion_setting.lora.path:
-        config.lcm_diffusion_setting.lora.enabled = True
-    if args.usejpeg:
-        config.generated_images.format = ImageFormat.JPEG.value.upper()
-    if args.seed > -1:
-        config.lcm_diffusion_setting.use_seed = True
-    else:
-        config.lcm_diffusion_setting.use_seed = False
-    config.lcm_diffusion_setting.use_offline_model = args.use_offline_model
-    config.lcm_diffusion_setting.clip_skip = args.clip_skip
-    config.lcm_diffusion_setting.token_merging = args.token_merging
-    config.lcm_diffusion_setting.use_safety_checker = args.use_safety_checker
-
-    # Read custom settings from JSON file
-    custom_settings = {}
-    if args.custom_settings:
-        with open(args.custom_settings) as f:
-            custom_settings = json.load(f)
-
-    # Basic ControlNet settings; if ControlNet is enabled, an image is
-    # required even in txt2img mode
-    config.lcm_diffusion_setting.controlnet = None
-    controlnet_settings_from_dict(
-        config.lcm_diffusion_setting,
-        custom_settings,
-    )
-
-    # Interactive mode
-    if args.interactive:
-        # wrapper(interactive_mode, config, context)
-        config.lcm_diffusion_setting.lora.fuse = False
-        interactive_mode(config, context)
-
-    # Start of non-interactive CLI image generation
-    if args.img2img and args.file != "":
-        config.lcm_diffusion_setting.init_image = Image.open(args.file)
-        config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
-    elif args.img2img and args.file == "":
-        print("Error : You need to specify a file in img2img mode")
-        exit()
-    elif args.upscale and args.file == "" and args.custom_settings == None:
-        print("Error : You need to specify a file in SD upscale mode")
-        exit()
-    elif (
-        args.prompt == ""
-        and args.file == ""
-        and args.custom_settings == None
-        and not args.benchmark
-    ):
-        print("Error : You need to provide a prompt")
-        exit()
-
-    if args.upscale:
-        # image = Image.open(args.file)
-        output_path = FastStableDiffusionPaths.get_upscale_filepath(
-            args.file,
-            2,
-            config.generated_images.format,
-        )
-        result = upscale_image(
-            context,
-            args.file,
-            output_path,
-            2,
-        )
-    # Perform Tiled SD upscale (EXPERIMENTAL)
-    elif args.sdupscale:
-        if args.use_openvino:
-            config.lcm_diffusion_setting.strength = 0.3
-        upscale_settings = None
-        if custom_settings != {}:
-            upscale_settings = custom_settings
-        filepath = args.file
-        output_format = config.generated_images.format
-        if upscale_settings:
-            filepath = upscale_settings["source_file"]
-            output_format = upscale_settings["output_format"].upper()
-        output_path = FastStableDiffusionPaths.get_upscale_filepath(
-            filepath,
-            2,
-            output_format,
-        )
-
-        generate_upscaled_image(
-            config,
-            filepath,
-            config.lcm_diffusion_setting.strength,
-            upscale_settings=upscale_settings,
-            context=context,
-            tile_overlap=32 if config.lcm_diffusion_setting.use_openvino else 16,
-            output_path=output_path,
-            image_format=output_format,
-        )
-        exit()
-    # If img2img argument is set and prompt is empty, use image variations mode
-    elif args.img2img and args.prompt == "":
-        for i in range(0, args.batch_count):
-            generate_image_variations(
-                config.lcm_diffusion_setting.init_image, args.strength
-            )
-    else:
-        if args.benchmark:
-            print("Initializing benchmark...")
-            bench_lcm_setting = config.lcm_diffusion_setting
-            bench_lcm_setting.prompt = "a cat"
-            bench_lcm_setting.use_tiny_auto_encoder = False
-            context.generate_text_to_image(
-                settings=config,
-                device=DEVICE,
-            )
-
-            latencies = []
-
-            print("Starting benchmark please wait...")
-            for _ in range(3):
-                context.generate_text_to_image(
-                    settings=config,
-                    device=DEVICE,
-                )
-                latencies.append(context.latency)
-
-            avg_latency = sum(latencies) / 3
-
-            bench_lcm_setting.use_tiny_auto_encoder = True
-
-            context.generate_text_to_image(
-                settings=config,
-                device=DEVICE,
-            )
-            latencies = []
-            for _ in range(3):
-                context.generate_text_to_image(
-                    settings=config,
-                    device=DEVICE,
-                )
-                latencies.append(context.latency)
-
-            avg_latency_taesd = sum(latencies) / 3
-
-            benchmark_name = ""
-
-            if config.lcm_diffusion_setting.use_openvino:
-                benchmark_name = "OpenVINO"
-            else:
-                benchmark_name = "PyTorch"
-
-            bench_model_id = ""
-            if bench_lcm_setting.use_openvino:
-                bench_model_id = bench_lcm_setting.openvino_lcm_model_id
-            elif bench_lcm_setting.use_lcm_lora:
-                bench_model_id = bench_lcm_setting.lcm_lora.base_model_id
-            else:
-                bench_model_id = bench_lcm_setting.lcm_model_id
-
-            benchmark_result = [
-                ["Device", f"{DEVICE.upper()},{get_device_name()}"],
-                ["Stable Diffusion Model", bench_model_id],
-                [
-                    "Image Size ",
-                    f"{bench_lcm_setting.image_width}x{bench_lcm_setting.image_height}",
-                ],
-                [
-                    "Inference Steps",
-                    f"{bench_lcm_setting.inference_steps}",
-                ],
-                [
-                    "Benchmark Passes",
-                    3,
-                ],
-                [
-                    "Average Latency",
-                    f"{round(avg_latency, 3)} sec",
-                ],
-                [
-                    "Average Latency(TAESD* enabled)",
-                    f"{round(avg_latency_taesd, 3)} sec",
-                ],
-            ]
-            print()
-            print(
-                f"                          FastSD Benchmark - {benchmark_name:8}                         "
-            )
-            print(f"-" * 80)
-            for benchmark in benchmark_result:
-                print(f"{benchmark[0]:35} - {benchmark[1]}")
-            print(f"-" * 80)
-            print("*TAESD - Tiny AutoEncoder for Stable Diffusion")
-
-        else:
-            for i in range(0, args.batch_count):
-                context.generate_text_to_image(
-                    settings=config,
-                    device=DEVICE,
-                )
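
Note: the --sdupscale branch in the deleted app.py reads exactly two keys, "source_file" and "output_format", from the JSON file passed via --custom_settings; the same dict is also forwarded to controlnet_settings_from_dict for optional ControlNet fields. A minimal sketch of that loading step, with a hypothetical payload containing only the two keys the deleted code reads directly:

import json

# Hypothetical minimal --custom_settings payload; only the two keys the
# deleted --sdupscale branch reads directly are shown, not the full schema.
settings_text = '{"source_file": "input.png", "output_format": "png"}'

custom_settings = json.loads(settings_text)
filepath = custom_settings["source_file"]
output_format = custom_settings["output_format"].upper()  # app.py upper-cased the format
print(filepath, output_format)  # -> input.png PNG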
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
src/app_settings.py
DELETED
@@ -1,124 +0,0 @@
-from copy import deepcopy
-from os import makedirs, path
-
-import yaml
-from constants import (
-    LCM_LORA_MODELS_FILE,
-    LCM_MODELS_FILE,
-    OPENVINO_LCM_MODELS_FILE,
-    SD_MODELS_FILE,
-)
-from paths import FastStableDiffusionPaths, join_paths
-from utils import get_files_in_dir, get_models_from_text_file
-
-from models.settings import Settings
-
-
-class AppSettings:
-    def __init__(self):
-        self.config_path = FastStableDiffusionPaths().get_app_settings_path()
-        self._stable_diffsuion_models = get_models_from_text_file(
-            FastStableDiffusionPaths().get_models_config_path(SD_MODELS_FILE)
-        )
-        self._lcm_lora_models = get_models_from_text_file(
-            FastStableDiffusionPaths().get_models_config_path(LCM_LORA_MODELS_FILE)
-        )
-        self._openvino_lcm_models = get_models_from_text_file(
-            FastStableDiffusionPaths().get_models_config_path(OPENVINO_LCM_MODELS_FILE)
-        )
-        self._lcm_models = get_models_from_text_file(
-            FastStableDiffusionPaths().get_models_config_path(LCM_MODELS_FILE)
-        )
-        self._gguf_diffusion_models = get_files_in_dir(
-            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "diffusion")
-        )
-        self._gguf_clip_models = get_files_in_dir(
-            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "clip")
-        )
-        self._gguf_vae_models = get_files_in_dir(
-            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "vae")
-        )
-        self._gguf_t5xxl_models = get_files_in_dir(
-            join_paths(FastStableDiffusionPaths().get_gguf_models_path(), "t5xxl")
-        )
-        self._config = None
-
-    @property
-    def settings(self):
-        return self._config
-
-    @property
-    def stable_diffsuion_models(self):
-        return self._stable_diffsuion_models
-
-    @property
-    def openvino_lcm_models(self):
-        return self._openvino_lcm_models
-
-    @property
-    def lcm_models(self):
-        return self._lcm_models
-
-    @property
-    def lcm_lora_models(self):
-        return self._lcm_lora_models
-
-    @property
-    def gguf_diffusion_models(self):
-        return self._gguf_diffusion_models
-
-    @property
-    def gguf_clip_models(self):
-        return self._gguf_clip_models
-
-    @property
-    def gguf_vae_models(self):
-        return self._gguf_vae_models
-
-    @property
-    def gguf_t5xxl_models(self):
-        return self._gguf_t5xxl_models
-
-    def load(self, skip_file=False):
-        if skip_file:
-            print("Skipping config file")
-            settings_dict = self._load_default()
-            self._config = Settings.model_validate(settings_dict)
-        else:
-            if not path.exists(self.config_path):
-                base_dir = path.dirname(self.config_path)
-                if not path.exists(base_dir):
-                    makedirs(base_dir)
-                try:
-                    print("Settings not found creating default settings")
-                    with open(self.config_path, "w") as file:
-                        yaml.dump(
-                            self._load_default(),
-                            file,
-                        )
-                except Exception as ex:
-                    print(f"Error in creating settings : {ex}")
-                    exit()
-            try:
-                with open(self.config_path) as file:
-                    settings_dict = yaml.safe_load(file)
-                    self._config = Settings.model_validate(settings_dict)
-            except Exception as ex:
-                print(f"Error in loading settings : {ex}")
-
-    def save(self):
-        try:
-            with open(self.config_path, "w") as file:
-                tmp_cfg = deepcopy(self._config)
-                tmp_cfg.lcm_diffusion_setting.init_image = None
-                configurations = tmp_cfg.model_dump(
-                    exclude=["init_image"],
-                )
-                if configurations:
-                    yaml.dump(configurations, file)
-        except Exception as ex:
-            print(f"Error in saving settings : {ex}")
-
-    def _load_default(self) -> dict:
-        default_config = Settings()
-        return default_config.model_dump()
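For reference, a minimal sketch of how the removed AppSettings class was typically driven (hypothetical caller code; the real entry points lived in src/app.py and the API modules):

# Hypothetical usage of the removed AppSettings (assumes the old src/ tree
# is still on PYTHONPATH). load() validates the YAML config against the
# pydantic Settings model, creating defaults when no file exists; save()
# clears the non-serializable init_image before dumping back to YAML.
from app_settings import AppSettings

app_settings = AppSettings()
app_settings.load()                # read the YAML config, or create defaults
print(app_settings.lcm_models)     # model lists parsed from the config files
app_settings.settings.lcm_diffusion_setting.prompt = "a cup of coffee"
app_settings.save()                # persist the updated settings
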
src/backend/__init__.py
DELETED
File without changes
src/backend/annotators/canny_control.py
DELETED
@@ -1,15 +0,0 @@
-import numpy as np
-from backend.annotators.control_interface import ControlInterface
-from cv2 import Canny
-from PIL import Image
-
-
-class CannyControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        low_threshold = 100
-        high_threshold = 200
-        image = np.array(image)
-        image = Canny(image, low_threshold, high_threshold)
-        image = image[:, :, None]
-        image = np.concatenate([image, image, image], axis=2)
-        return Image.fromarray(image)
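A quick smoke test of the removed CannyControl, for anyone reproducing the annotator outside this repo (hypothetical input path; assumes opencv-python and Pillow are installed and the old src/ tree is importable):

# The class replicates the standard ControlNet Canny preprocessing:
# a single-channel edge map stacked into three identical channels.
from PIL import Image
from backend.annotators.canny_control import CannyControl

control = CannyControl().get_control_image(Image.open("input.png").convert("RGB"))
control.save("canny_control.png")  # white edges on black, RGB shape
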
src/backend/annotators/control_interface.py
DELETED
@@ -1,12 +0,0 @@
-from abc import ABC, abstractmethod
-
-from PIL import Image
-
-
-class ControlInterface(ABC):
-    @abstractmethod
-    def get_control_image(
-        self,
-        image: Image,
-    ) -> Image:
-        pass
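Every annotator in this package was expected to subclass this ABC; a minimal sketch of a conforming implementation (hypothetical GrayscaleControl, never part of the repo):

# A toy ControlInterface implementation: grayscale converted back to RGB
# so the output shape matches what the ControlNet pipelines expect.
from PIL import Image
from backend.annotators.control_interface import ControlInterface


class GrayscaleControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
        return image.convert("L").convert("RGB")
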
src/backend/annotators/depth_control.py
DELETED
@@ -1,15 +0,0 @@
-import numpy as np
-from backend.annotators.control_interface import ControlInterface
-from PIL import Image
-from transformers import pipeline
-
-
-class DepthControl(ControlInterface):
    def get_control_image(self, image: Image) -> Image:
-        depth_estimator = pipeline("depth-estimation")
-        image = depth_estimator(image)["depth"]
-        image = np.array(image)
-        image = image[:, :, None]
-        image = np.concatenate([image, image, image], axis=2)
-        image = Image.fromarray(image)
-        return image
src/backend/annotators/image_control_factory.py
DELETED
@@ -1,31 +0,0 @@
-from backend.annotators.canny_control import CannyControl
-from backend.annotators.depth_control import DepthControl
-from backend.annotators.lineart_control import LineArtControl
-from backend.annotators.mlsd_control import MlsdControl
-from backend.annotators.normal_control import NormalControl
-from backend.annotators.pose_control import PoseControl
-from backend.annotators.shuffle_control import ShuffleControl
-from backend.annotators.softedge_control import SoftEdgeControl
-
-
-class ImageControlFactory:
-    def create_control(self, controlnet_type: str):
-        if controlnet_type == "Canny":
-            return CannyControl()
-        elif controlnet_type == "Pose":
-            return PoseControl()
-        elif controlnet_type == "MLSD":
-            return MlsdControl()
-        elif controlnet_type == "Depth":
-            return DepthControl()
-        elif controlnet_type == "LineArt":
-            return LineArtControl()
-        elif controlnet_type == "Shuffle":
-            return ShuffleControl()
-        elif controlnet_type == "NormalBAE":
-            return NormalControl()
-        elif controlnet_type == "SoftEdge":
-            return SoftEdgeControl()
-        else:
-            print("Error: Control type not implemented!")
-            raise Exception("Error: Control type not implemented!")
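Hypothetical usage of the removed factory; the accepted keys ("Canny", "Pose", "MLSD", "Depth", "LineArt", "Shuffle", "NormalBAE", "SoftEdge") come straight from the if/elif chain above:

from PIL import Image
from backend.annotators.image_control_factory import ImageControlFactory

# pick an annotator by name, then derive the control image from a photo
control = ImageControlFactory().create_control("Canny")
control_image = control.get_control_image(Image.open("photo.jpg").convert("RGB"))
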
src/backend/annotators/lineart_control.py
DELETED
@@ -1,11 +0,0 @@
-import numpy as np
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import LineartDetector
-from PIL import Image
-
-
-class LineArtControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
-        control_image = processor(image)
-        return control_image
src/backend/annotators/mlsd_control.py
DELETED
@@ -1,10 +0,0 @@
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import MLSDdetector
-from PIL import Image
-
-
-class MlsdControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        mlsd = MLSDdetector.from_pretrained("lllyasviel/ControlNet")
-        image = mlsd(image)
-        return image
src/backend/annotators/normal_control.py
DELETED
@@ -1,10 +0,0 @@
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import NormalBaeDetector
-from PIL import Image
-
-
-class NormalControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        control_image = processor(image)
-        return control_image
src/backend/annotators/pose_control.py
DELETED
@@ -1,10 +0,0 @@
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import OpenposeDetector
-from PIL import Image
-
-
-class PoseControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
-        image = openpose(image)
-        return image
src/backend/annotators/shuffle_control.py
DELETED
@@ -1,10 +0,0 @@
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import ContentShuffleDetector
-from PIL import Image
-
-
-class ShuffleControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        shuffle_processor = ContentShuffleDetector()
-        image = shuffle_processor(image)
-        return image
src/backend/annotators/softedge_control.py
DELETED
@@ -1,10 +0,0 @@
-from backend.annotators.control_interface import ControlInterface
-from controlnet_aux import PidiNetDetector
-from PIL import Image
-
-
-class SoftEdgeControl(ControlInterface):
-    def get_control_image(self, image: Image) -> Image:
-        processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
-        control_image = processor(image)
-        return control_image
src/backend/api/mcp_server.py
DELETED
@@ -1,97 +0,0 @@
-import platform
-
-import uvicorn
-from backend.device import get_device_name
-from backend.models.device import DeviceInfo
-from constants import APP_VERSION, DEVICE
-from context import Context
-from fastapi import FastAPI, Request
-from fastapi_mcp import FastApiMCP
-from state import get_settings
-from fastapi.middleware.cors import CORSMiddleware
-from models.interface_types import InterfaceType
-from fastapi.staticfiles import StaticFiles
-
-app_settings = get_settings()
-app = FastAPI(
-    title="FastSD CPU",
-    description="Fast stable diffusion on CPU",
-    version=APP_VERSION,
-    license_info={
-        "name": "MIT",
-        "identifier": "MIT",
-    },
-    describe_all_responses=True,
-    describe_full_response_schema=True,
-)
-origins = ["*"]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-print(app_settings.settings.lcm_diffusion_setting)
-
-context = Context(InterfaceType.API_SERVER)
-app.mount("/results", StaticFiles(directory="results"), name="results")
-
-
-@app.get(
-    "/info",
-    description="Get system information",
-    summary="Get system information",
-    operation_id="get_system_info",
-)
-async def info() -> dict:
-    device_info = DeviceInfo(
-        device_type=DEVICE,
-        device_name=get_device_name(),
-        os=platform.system(),
-        platform=platform.platform(),
-        processor=platform.processor(),
-    )
-    return device_info.model_dump()
-
-
-@app.post(
-    "/generate",
-    description="Generate image from text prompt",
-    summary="Text to image generation",
-    operation_id="generate",
-)
-async def generate(
-    prompt: str,
-    request: Request,
-) -> str:
-    """
-    Returns URL of the generated image for text prompt
-    """
-
-    app_settings.settings.lcm_diffusion_setting.prompt = prompt
-    images = context.generate_text_to_image(app_settings.settings)
-    image_names = context.save_images(
-        images,
-        app_settings.settings,
-    )
-    url = request.url_for("results", path=image_names[0])
-    image_url = f"The generated image available at the URL {url}"
-    return image_url
-
-
-def start_mcp_server(port: int = 8000):
-    mcp = FastApiMCP(
-        app,
-        name="FastSDCPU MCP",
-        description="MCP server for FastSD CPU API",
-        base_url=f"http://localhost:{port}",
-    )
-
-    mcp.mount()
-    uvicorn.run(
-        app,
-        host="0.0.0.0",
-        port=port,
-    )
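A sketch of calling the removed endpoint from a client (assumes the server was started locally with start_mcp_server(8000); prompt is a query parameter, per the handler signature above):

import requests

# /generate returned a plain string containing the generated image URL
resp = requests.post(
    "http://localhost:8000/generate",
    params={"prompt": "a cat wearing sunglasses"},
)
print(resp.json())
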
src/backend/api/models/response.py
DELETED
@@ -1,16 +0,0 @@
-from typing import List
-
-from pydantic import BaseModel
-
-
-class StableDiffusionResponse(BaseModel):
-    """
-    Stable diffusion response model
-
-    Attributes:
-        images (List[str]): List of JPEG image as base64 encoded
-        latency (float): Latency in seconds
-    """
-
-    images: List[str]
-    latency: float
        src/backend/api/web.py
    DELETED
    
    | 
         @@ -1,112 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            import platform
         
     | 
| 2 | 
         
            -
             
     | 
| 3 | 
         
            -
            import uvicorn
         
     | 
| 4 | 
         
            -
            from fastapi import FastAPI
         
     | 
| 5 | 
         
            -
            from fastapi.middleware.cors import CORSMiddleware
         
     | 
| 6 | 
         
            -
             
     | 
| 7 | 
         
            -
            from backend.api.models.response import StableDiffusionResponse
         
     | 
| 8 | 
         
            -
            from backend.base64_image import base64_image_to_pil, pil_image_to_base64_str
         
     | 
| 9 | 
         
            -
            from backend.device import get_device_name
         
     | 
| 10 | 
         
            -
            from backend.models.device import DeviceInfo
         
     | 
| 11 | 
         
            -
            from backend.models.lcmdiffusion_setting import DiffusionTask, LCMDiffusionSetting
         
     | 
| 12 | 
         
            -
            from constants import APP_VERSION, DEVICE
         
     | 
| 13 | 
         
            -
            from context import Context
         
     | 
| 14 | 
         
            -
            from models.interface_types import InterfaceType
         
     | 
| 15 | 
         
            -
            from state import get_settings
         
     | 
| 16 | 
         
            -
             
     | 
| 17 | 
         
            -
            app_settings = get_settings()
         
     | 
| 18 | 
         
            -
            app = FastAPI(
         
     | 
| 19 | 
         
            -
                title="FastSD CPU",
         
     | 
| 20 | 
         
            -
                description="Fast stable diffusion on CPU",
         
     | 
| 21 | 
         
            -
                version=APP_VERSION,
         
     | 
| 22 | 
         
            -
                license_info={
         
     | 
| 23 | 
         
            -
                    "name": "MIT",
         
     | 
| 24 | 
         
            -
                    "identifier": "MIT",
         
     | 
| 25 | 
         
            -
                },
         
     | 
| 26 | 
         
            -
                docs_url="/api/docs",
         
     | 
| 27 | 
         
            -
                redoc_url="/api/redoc",
         
     | 
| 28 | 
         
            -
                openapi_url="/api/openapi.json",
         
     | 
| 29 | 
         
            -
            )
         
     | 
| 30 | 
         
            -
            print(app_settings.settings.lcm_diffusion_setting)
         
     | 
| 31 | 
         
            -
            origins = ["*"]
         
     | 
| 32 | 
         
            -
            app.add_middleware(
         
     | 
| 33 | 
         
            -
                CORSMiddleware,
         
     | 
| 34 | 
         
            -
                allow_origins=origins,
         
     | 
| 35 | 
         
            -
                allow_credentials=True,
         
     | 
| 36 | 
         
            -
                allow_methods=["*"],
         
     | 
| 37 | 
         
            -
                allow_headers=["*"],
         
     | 
| 38 | 
         
            -
            )
         
     | 
| 39 | 
         
            -
            context = Context(InterfaceType.API_SERVER)
         
     | 
| 40 | 
         
            -
             
     | 
| 41 | 
         
            -
             
     | 
| 42 | 
         
            -
            @app.get("/api/")
         
     | 
| 43 | 
         
            -
            async def root():
         
     | 
| 44 | 
         
            -
                return {"message": "Welcome to FastSD CPU API"}
         
     | 
| 45 | 
         
            -
             
     | 
| 46 | 
         
            -
             
     | 
| 47 | 
         
            -
            @app.get(
         
     | 
| 48 | 
         
            -
                "/api/info",
         
     | 
| 49 | 
         
            -
                description="Get system information",
         
     | 
| 50 | 
         
            -
                summary="Get system information",
         
     | 
| 51 | 
         
            -
            )
         
     | 
| 52 | 
         
            -
            async def info():
         
     | 
| 53 | 
         
            -
                device_info = DeviceInfo(
         
     | 
| 54 | 
         
            -
                    device_type=DEVICE,
         
     | 
| 55 | 
         
            -
                    device_name=get_device_name(),
         
     | 
| 56 | 
         
            -
                    os=platform.system(),
         
     | 
| 57 | 
         
            -
                    platform=platform.platform(),
         
     | 
| 58 | 
         
            -
                    processor=platform.processor(),
         
     | 
| 59 | 
         
-    )
-    return device_info.model_dump()
-
-
-@app.get(
-    "/api/config",
-    description="Get current configuration",
-    summary="Get configurations",
-)
-async def config():
-    return app_settings.settings
-
-
-@app.get(
-    "/api/models",
-    description="Get available models",
-    summary="Get available models",
-)
-async def models():
-    return {
-        "lcm_lora_models": app_settings.lcm_lora_models,
-        "stable_diffusion": app_settings.stable_diffsuion_models,
-        "openvino_models": app_settings.openvino_lcm_models,
-        "lcm_models": app_settings.lcm_models,
-    }
-
-
-@app.post(
-    "/api/generate",
-    description="Generate image(Text to image,Image to Image)",
-    summary="Generate image(Text to image,Image to Image)",
-)
-async def generate(diffusion_config: LCMDiffusionSetting) -> StableDiffusionResponse:
-    app_settings.settings.lcm_diffusion_setting = diffusion_config
-    if diffusion_config.diffusion_task == DiffusionTask.image_to_image:
-        app_settings.settings.lcm_diffusion_setting.init_image = base64_image_to_pil(
-            diffusion_config.init_image
-        )
-
-    images = context.generate_text_to_image(app_settings.settings)
-
-    images_base64 = [pil_image_to_base64_str(img) for img in images]
-    return StableDiffusionResponse(
-        latency=round(context.latency, 2),
-        images=images_base64,
-    )
-
-
-def start_web_server(port: int = 8000):
-    uvicorn.run(
-        app,
-        host="0.0.0.0",
-        port=port,
-    )
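For reference, the removed /api/generate endpoint accepted an LCMDiffusionSetting payload and returned a StableDiffusionResponse carrying a latency figure and base64-encoded images. A minimal client sketch; the payload keys are assumptions, since the authoritative schema was the (also deleted) LCMDiffusionSetting model:

# Minimal client sketch for the removed /api/generate endpoint.
# The payload keys below are assumptions; the authoritative schema was
# src/backend/models/lcmdiffusion_setting.py, which this commit also deletes.
import base64

import requests

payload = {
    "diffusion_task": "text_to_image",  # assumed serialized enum value
    "prompt": "a cup of coffee on a wooden table",
}
response = requests.post("http://localhost:8000/api/generate", json=payload)
response.raise_for_status()
body = response.json()
print(f"latency: {body['latency']}s")  # StableDiffusionResponse.latency
for i, image_b64 in enumerate(body["images"]):  # base64-encoded JPEGs by default
    with open(f"result_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(image_b64))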
         
src/backend/base64_image.py
DELETED
@@ -1,21 +0,0 @@
-from io import BytesIO
-from base64 import b64encode, b64decode
-from PIL import Image
-
-
-def pil_image_to_base64_str(
-    image: Image,
-    format: str = "JPEG",
-) -> str:
-    buffer = BytesIO()
-    image.save(buffer, format=format)
-    buffer.seek(0)
-    img_base64 = b64encode(buffer.getvalue()).decode("utf-8")
-    return img_base64
-
-
-def base64_image_to_pil(base64_str) -> Image:
-    image_data = b64decode(base64_str)
-    image_buffer = BytesIO(image_data)
-    image = Image.open(image_buffer)
-    return image
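These two helpers are inverses, so a round trip should reproduce the image exactly when a lossless format is used (the JPEG default would recompress). A self-contained check, assuming both functions are in scope:

# Round-trip check for the two helpers above; PNG keeps the cycle
# lossless, unlike the JPEG default.
from PIL import Image

original = Image.new("RGB", (64, 64), color=(200, 30, 30))
b64_str = pil_image_to_base64_str(original, format="PNG")
restored = base64_image_to_pil(b64_str)
assert restored.size == original.size
assert list(restored.getdata()) == list(original.getdata())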
         
src/backend/controlnet.py
DELETED
@@ -1,90 +0,0 @@
-import logging
-from PIL import Image
-from diffusers import ControlNetModel
-from backend.models.lcmdiffusion_setting import (
-    DiffusionTask,
-    ControlNetSetting,
-)
-
-
-# Prepares ControlNet adapters for use with FastSD CPU
-#
-# This function loads the ControlNet adapters defined by the
-# _lcm_diffusion_setting.controlnet_ object and returns a dictionary
-# with the pipeline arguments required to use the loaded adapters
-def load_controlnet_adapters(lcm_diffusion_setting) -> dict:
-    controlnet_args = {}
-    if (
-        lcm_diffusion_setting.controlnet is None
-        or not lcm_diffusion_setting.controlnet.enabled
-    ):
-        return controlnet_args
-
-    logging.info("Loading ControlNet adapter")
-    controlnet_adapter = ControlNetModel.from_single_file(
-        lcm_diffusion_setting.controlnet.adapter_path,
-        # local_files_only=True,
-        use_safetensors=True,
-    )
-    controlnet_args["controlnet"] = controlnet_adapter
-    return controlnet_args
-
-
-# Updates the ControlNet pipeline arguments to use for image generation
-#
-# This function uses the contents of the _lcm_diffusion_setting.controlnet_
-# object to generate a dictionary with the corresponding pipeline arguments
-# to be used for image generation; in particular, it sets the ControlNet control
-# image and conditioning scale
-def update_controlnet_arguments(lcm_diffusion_setting) -> dict:
-    controlnet_args = {}
-    if (
-        lcm_diffusion_setting.controlnet is None
-        or not lcm_diffusion_setting.controlnet.enabled
-    ):
-        return controlnet_args
-
-    controlnet_args["controlnet_conditioning_scale"] = (
-        lcm_diffusion_setting.controlnet.conditioning_scale
-    )
-    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
-        controlnet_args["image"] = lcm_diffusion_setting.controlnet._control_image
-    elif lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
-        controlnet_args["control_image"] = (
-            lcm_diffusion_setting.controlnet._control_image
-        )
-    return controlnet_args
-
-
-# Helper function to adjust ControlNet settings from a dictionary
-def controlnet_settings_from_dict(
-    lcm_diffusion_setting,
-    dictionary,
-) -> None:
-    if lcm_diffusion_setting is None or dictionary is None:
-        logging.error("Invalid arguments!")
-        return
-    if (
-        "controlnet" not in dictionary
-        or dictionary["controlnet"] is None
-        or len(dictionary["controlnet"]) == 0
-    ):
-        logging.warning("ControlNet settings not found, ControlNet will be disabled")
-        lcm_diffusion_setting.controlnet = None
-        return
-
-    controlnet = ControlNetSetting()
-    controlnet.enabled = dictionary["controlnet"][0]["enabled"]
-    controlnet.conditioning_scale = dictionary["controlnet"][0]["conditioning_scale"]
-    controlnet.adapter_path = dictionary["controlnet"][0]["adapter_path"]
-    controlnet._control_image = None
-    image_path = dictionary["controlnet"][0]["control_image"]
-    if controlnet.enabled:
-        try:
-            controlnet._control_image = Image.open(image_path)
-        except (AttributeError, FileNotFoundError) as err:
-            print(err)
-        if controlnet._control_image is None:
-            logging.error("Wrong ControlNet control image! Disabling ControlNet")
-            controlnet.enabled = False
-    lcm_diffusion_setting.controlnet = controlnet
     | 
src/backend/device.py
DELETED
@@ -1,23 +0,0 @@
-import platform
-from constants import DEVICE
-import torch
-import openvino as ov
-
-core = ov.Core()
-
-
-def is_openvino_device() -> bool:
-    if DEVICE.lower() == "cpu" or DEVICE.lower()[0] == "g" or DEVICE.lower()[0] == "n":
-        return True
-    else:
-        return False
-
-
-def get_device_name() -> str:
-    if DEVICE == "cuda" or DEVICE == "mps":
-        default_gpu_index = torch.cuda.current_device()
-        return torch.cuda.get_device_name(default_gpu_index)
-    elif platform.system().lower() == "darwin":
-        return platform.processor()
-    elif is_openvino_device():
-        return core.get_property(DEVICE.upper(), "FULL_DEVICE_NAME")
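Two quirks worth noting: is_openvino_device() treats "cpu" and any device string starting with "g" or "n" (e.g. "GPU", "NPU") as an OpenVINO target, and get_device_name() falls through with an implicit None when no branch matches. A usage sketch, assuming DEVICE is configured in constants.py:

# Usage sketch for the helpers above; DEVICE comes from constants.py, so
# this only runs meaningfully inside the FastSD CPU environment.
from backend.device import get_device_name, is_openvino_device

backend = "OpenVINO" if is_openvino_device() else "Torch"
name = get_device_name()  # may be None if no branch matches
print(f"{backend} device: {name}")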
         
src/backend/gguf/gguf_diffusion.py
DELETED
@@ -1,319 +0,0 @@
-"""
-Wrapper class to call the stablediffusion.cpp shared library for GGUF support
-"""
-
-import ctypes
-import platform
-from ctypes import (
-    POINTER,
-    c_bool,
-    c_char_p,
-    c_float,
-    c_int,
-    c_int64,
-    c_void_p,
-)
-from dataclasses import dataclass
-from os import path
-from typing import List, Any
-
-import numpy as np
-from PIL import Image
-
-from backend.gguf.sdcpp_types import (
-    RngType,
-    SampleMethod,
-    Schedule,
-    SDCPPLogLevel,
-    SDImage,
-    SdType,
-)
-
-
-@dataclass
-class ModelConfig:
-    model_path: str = ""
-    clip_l_path: str = ""
-    t5xxl_path: str = ""
-    diffusion_model_path: str = ""
-    vae_path: str = ""
-    taesd_path: str = ""
-    control_net_path: str = ""
-    lora_model_dir: str = ""
-    embed_dir: str = ""
-    stacked_id_embed_dir: str = ""
-    vae_decode_only: bool = True
-    vae_tiling: bool = False
-    free_params_immediately: bool = False
-    n_threads: int = 4
-    wtype: SdType = SdType.SD_TYPE_Q4_0
-    rng_type: RngType = RngType.CUDA_RNG
-    schedule: Schedule = Schedule.DEFAULT
-    keep_clip_on_cpu: bool = False
-    keep_control_net_cpu: bool = False
-    keep_vae_on_cpu: bool = False
-
-
-@dataclass
-class Txt2ImgConfig:
-    prompt: str = "a man wearing sun glasses, highly detailed"
-    negative_prompt: str = ""
-    clip_skip: int = -1
-    cfg_scale: float = 2.0
-    guidance: float = 3.5
-    width: int = 512
-    height: int = 512
-    sample_method: SampleMethod = SampleMethod.EULER_A
-    sample_steps: int = 1
-    seed: int = -1
-    batch_count: int = 2
-    control_cond: Image = None
-    control_strength: float = 0.90
-    style_strength: float = 0.5
-    normalize_input: bool = False
-    input_id_images_path: bytes = b""
-
-
-class GGUFDiffusion:
-    """GGUF Diffusion
-    To support GGUF diffusion model based on stablediffusion.cpp
-    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
-    Implemented based on stablediffusion.h
-    """
-
-    def __init__(
-        self,
-        libpath: str,
-        config: ModelConfig,
-        logging_enabled: bool = False,
-    ):
-        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
-        try:
-            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
-        except OSError as e:
-            print(f"Failed to load library {sdcpp_shared_lib_path}")
-            raise ValueError(f"Error: {e}")
-
-        if not config.clip_l_path or not path.exists(config.clip_l_path):
-            raise ValueError(
-                "CLIP model file not found,please check readme.md for GGUF model usage"
-            )
-
-        if not config.t5xxl_path or not path.exists(config.t5xxl_path):
-            raise ValueError(
-                "T5XXL model file not found,please check readme.md for GGUF model usage"
-            )
-
-        if not config.diffusion_model_path or not path.exists(
-            config.diffusion_model_path
-        ):
-            raise ValueError(
-                "Diffusion model file not found,please check readme.md for GGUF model usage"
-            )
-
-        if not config.vae_path or not path.exists(config.vae_path):
-            raise ValueError(
-                "VAE model file not found,please check readme.md for GGUF model usage"
-            )
-
-        self.model_config = config
-
-        self.libsdcpp.new_sd_ctx.argtypes = [
-            c_char_p,  # const char* model_path
-            c_char_p,  # const char* clip_l_path
-            c_char_p,  # const char* t5xxl_path
-            c_char_p,  # const char* diffusion_model_path
-            c_char_p,  # const char* vae_path
-            c_char_p,  # const char* taesd_path
-            c_char_p,  # const char* control_net_path_c_str
-            c_char_p,  # const char* lora_model_dir
-            c_char_p,  # const char* embed_dir_c_str
-            c_char_p,  # const char* stacked_id_embed_dir_c_str
-            c_bool,  # bool vae_decode_only
-            c_bool,  # bool vae_tiling
-            c_bool,  # bool free_params_immediately
-            c_int,  # int n_threads
-            SdType,  # enum sd_type_t wtype
-            RngType,  # enum rng_type_t rng_type
-            Schedule,  # enum schedule_t s
-            c_bool,  # bool keep_clip_on_cpu
-            c_bool,  # bool keep_control_net_cpu
-            c_bool,  # bool keep_vae_on_cpu
-        ]
-
-        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)
-
-        self.sd_ctx = self.libsdcpp.new_sd_ctx(
-            self._str_to_bytes(self.model_config.model_path),
-            self._str_to_bytes(self.model_config.clip_l_path),
-            self._str_to_bytes(self.model_config.t5xxl_path),
-            self._str_to_bytes(self.model_config.diffusion_model_path),
-            self._str_to_bytes(self.model_config.vae_path),
-            self._str_to_bytes(self.model_config.taesd_path),
-            self._str_to_bytes(self.model_config.control_net_path),
-            self._str_to_bytes(self.model_config.lora_model_dir),
-            self._str_to_bytes(self.model_config.embed_dir),
-            self._str_to_bytes(self.model_config.stacked_id_embed_dir),
-            self.model_config.vae_decode_only,
-            self.model_config.vae_tiling,
-            self.model_config.free_params_immediately,
-            self.model_config.n_threads,
-            self.model_config.wtype,
-            self.model_config.rng_type,
-            self.model_config.schedule,
-            self.model_config.keep_clip_on_cpu,
-            self.model_config.keep_control_net_cpu,
-            self.model_config.keep_vae_on_cpu,
-        )
-
-        if logging_enabled:
-            self._set_logcallback()
-
-    def _set_logcallback(self):
-        print("Setting logging callback")
-        # Define function callback
-        SdLogCallbackType = ctypes.CFUNCTYPE(
-            None,
-            SDCPPLogLevel,
-            ctypes.c_char_p,
-            ctypes.c_void_p,
-        )
-
-        self.libsdcpp.sd_set_log_callback.argtypes = [
-            SdLogCallbackType,
-            ctypes.c_void_p,
-        ]
-        self.libsdcpp.sd_set_log_callback.restype = None
-        # Convert the Python callback to a C func pointer
-        self.c_log_callback = SdLogCallbackType(
-            self.log_callback
-        )  # prevent GC, keep callback as member variable
-        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)
-
-    def _get_sdcpp_shared_lib_path(
-        self,
-        root_path: str,
-    ) -> str:
-        system_name = platform.system()
-        print(f"GGUF Diffusion on {system_name}")
-        lib_name = "stable-diffusion.dll"
-        sdcpp_lib_path = ""
-
-        if system_name == "Windows":
-            sdcpp_lib_path = path.join(root_path, lib_name)
-        elif system_name == "Linux":
-            lib_name = "libstable-diffusion.so"
-            sdcpp_lib_path = path.join(root_path, lib_name)
-        elif system_name == "Darwin":
-            lib_name = "libstable-diffusion.dylib"
-            sdcpp_lib_path = path.join(root_path, lib_name)
-        else:
-            print("Unknown platform.")
-
-        return sdcpp_lib_path
-
-    @staticmethod
-    def log_callback(
-        level,
-        text,
-        data,
-    ):
-        print(f"{text.decode('utf-8')}", end="")
-
-    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
-        if in_str:
-            return in_str.encode(encoding)
-        else:
-            return b""
-
-    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
-        self.libsdcpp.txt2img.restype = POINTER(SDImage)
-        self.libsdcpp.txt2img.argtypes = [
-            c_void_p,  # sd_ctx_t* sd_ctx (pointer to context object)
-            c_char_p,  # const char* prompt
-            c_char_p,  # const char* negative_prompt
-            c_int,  # int clip_skip
-            c_float,  # float cfg_scale
-            c_float,  # float guidance
-            c_int,  # int width
-            c_int,  # int height
-            SampleMethod,  # enum sample_method_t sample_method
-            c_int,  # int sample_steps
-            c_int64,  # int64_t seed
-            c_int,  # int batch_count
-            POINTER(SDImage),  # const sd_image_t* control_cond (pointer to SDImage)
-            c_float,  # float control_strength
-            c_float,  # float style_strength
-            c_bool,  # bool normalize_input
-            c_char_p,  # const char* input_id_images_path
-        ]
-
-        image_buffer = self.libsdcpp.txt2img(
-            self.sd_ctx,
-            self._str_to_bytes(txt2img_cfg.prompt),
-            self._str_to_bytes(txt2img_cfg.negative_prompt),
-            txt2img_cfg.clip_skip,
-            txt2img_cfg.cfg_scale,
-            txt2img_cfg.guidance,
-            txt2img_cfg.width,
-            txt2img_cfg.height,
-            txt2img_cfg.sample_method,
-            txt2img_cfg.sample_steps,
-            txt2img_cfg.seed,
-            txt2img_cfg.batch_count,
-            txt2img_cfg.control_cond,
-            txt2img_cfg.control_strength,
-            txt2img_cfg.style_strength,
-            txt2img_cfg.normalize_input,
-            txt2img_cfg.input_id_images_path,
-        )
-
-        images = self._get_sd_images_from_buffer(
-            image_buffer,
-            txt2img_cfg.batch_count,
-        )
-
-        return images
-
-    def _get_sd_images_from_buffer(
-        self,
-        image_buffer: Any,
-        batch_count: int,
-    ) -> List[Any]:
-        images = []
-        if image_buffer:
-            for i in range(batch_count):
-                image = image_buffer[i]
-                print(
-                    f"Generated image: {image.width}x{image.height} with {image.channel} channels"
-                )
-
-                width = image.width
-                height = image.height
-                channels = image.channel
-                pixel_data = np.ctypeslib.as_array(
-                    image.data, shape=(height, width, channels)
-                )
-
-                if channels == 1:
-                    pil_image = Image.fromarray(pixel_data.squeeze(), mode="L")
-                elif channels == 3:
-                    pil_image = Image.fromarray(pixel_data, mode="RGB")
-                elif channels == 4:
-                    pil_image = Image.fromarray(pixel_data, mode="RGBA")
-                else:
-                    raise ValueError(f"Unsupported number of channels: {channels}")
-
-                images.append(pil_image)
-        return images
-
-    def terminate(self):
-        if self.libsdcpp:
-            if self.sd_ctx:
-                self.libsdcpp.free_sd_ctx.argtypes = [c_void_p]
-                self.libsdcpp.free_sd_ctx.restype = None
-                self.libsdcpp.free_sd_ctx(self.sd_ctx)
-                del self.sd_ctx
-                self.sd_ctx = None
-                del self.libsdcpp
-                self.libsdcpp = None
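Putting the wrapper together: build a ModelConfig pointing at the model files, instantiate GGUFDiffusion with the directory holding the stablediffusion.cpp shared library, generate, then call terminate() to free the native context. A sketch in which every path is a placeholder:

# End-to-end sketch for the GGUFDiffusion wrapper above; all paths are
# placeholders and must point at real files for new_sd_ctx to succeed.
config = ModelConfig(
    diffusion_model_path="models/flux1-schnell-q4_0.gguf",
    clip_l_path="models/clip_l.safetensors",
    t5xxl_path="models/t5xxl_fp16.safetensors",
    vae_path="models/ae.safetensors",
    n_threads=8,
)
sd = GGUFDiffusion("libs/", config, logging_enabled=True)
images = sd.generate_text2mg(
    Txt2ImgConfig(prompt="a lighthouse at dusk", sample_steps=4)
)
for i, img in enumerate(images):  # batch_count defaults to 2
    img.save(f"gguf_out_{i}.png")
sd.terminate()  # frees the native sd_ctx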
         
src/backend/gguf/sdcpp_types.py
DELETED
@@ -1,104 +0,0 @@
-"""
-Ctypes for stablediffusion.cpp shared library
-This is as per the stablediffusion.h file
-"""
-
-from enum import IntEnum
-from ctypes import (
-    c_int,
-    c_uint32,
-    c_uint8,
-    POINTER,
-    Structure,
-)
-
-
-class CtypesEnum(IntEnum):
-    """A ctypes-compatible IntEnum superclass."""
-
-    @classmethod
-    def from_param(cls, obj):
-        return int(obj)
-
-
-class RngType(CtypesEnum):
-    STD_DEFAULT_RNG = 0
-    CUDA_RNG = 1
-
-
-class SampleMethod(CtypesEnum):
-    EULER_A = 0
-    EULER = 1
-    HEUN = 2
-    DPM2 = 3
-    DPMPP2S_A = 4
-    DPMPP2M = 5
-    DPMPP2Mv2 = 6
-    IPNDM = 7
-    IPNDM_V = 7
-    LCM = 8
-    N_SAMPLE_METHODS = 9
-
-
-class Schedule(CtypesEnum):
-    DEFAULT = 0
-    DISCRETE = 1
-    KARRAS = 2
-    EXPONENTIAL = 3
-    AYS = 4
-    GITS = 5
-    N_SCHEDULES = 5
-
-
-class SdType(CtypesEnum):
-    SD_TYPE_F32 = 0
-    SD_TYPE_F16 = 1
-    SD_TYPE_Q4_0 = 2
-    SD_TYPE_Q4_1 = 3
-    # SD_TYPE_Q4_2 = 4, support has been removed
-    # SD_TYPE_Q4_3 = 5, support has been removed
-    SD_TYPE_Q5_0 = 6
-    SD_TYPE_Q5_1 = 7
-    SD_TYPE_Q8_0 = 8
-    SD_TYPE_Q8_1 = 9
-    SD_TYPE_Q2_K = 10
-    SD_TYPE_Q3_K = 11
-    SD_TYPE_Q4_K = 12
-    SD_TYPE_Q5_K = 13
-    SD_TYPE_Q6_K = 14
-    SD_TYPE_Q8_K = 15
-    SD_TYPE_IQ2_XXS = 16
-    SD_TYPE_IQ2_XS = 17
-    SD_TYPE_IQ3_XXS = 18
-    SD_TYPE_IQ1_S = 19
-    SD_TYPE_IQ4_NL = 20
-    SD_TYPE_IQ3_S = 21
-    SD_TYPE_IQ2_S = 22
-    SD_TYPE_IQ4_XS = 23
-    SD_TYPE_I8 = 24
-    SD_TYPE_I16 = 25
-    SD_TYPE_I32 = 26
-    SD_TYPE_I64 = 27
-    SD_TYPE_F64 = 28
-    SD_TYPE_IQ1_M = 29
-    SD_TYPE_BF16 = 30
-    SD_TYPE_Q4_0_4_4 = 31
-    SD_TYPE_Q4_0_4_8 = 32
-    SD_TYPE_Q4_0_8_8 = 33
-    SD_TYPE_COUNT = 34
-
-
-class SDImage(Structure):
-    _fields_ = [
-        ("width", c_uint32),
-        ("height", c_uint32),
-        ("channel", c_uint32),
-        ("data", POINTER(c_uint8)),
-    ]
-
-
-class SDCPPLogLevel(c_int):
-    SD_LOG_LEVEL_DEBUG = 0
-    SD_LOG_LEVEL_INFO = 1
-    SD_LOG_LEVEL_WARNING = 2
-    SD_LOG_LEVEL_ERROR = 3
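The deleted module above is a plain ctypes mirror of stablediffusion.h, so consuming code only needs stock ctypes to move pixel data across the C boundary. As a minimal illustrative sketch (not code from this repo), an SDImage buffer returned by the shared library could be copied into a PIL image like this; sd_image_to_pil is a hypothetical helper and assumes channel == 3 (RGB):

import ctypes

from PIL import Image


def sd_image_to_pil(img):
    # SDImage.data points at width * height * channel raw bytes.
    n_bytes = img.width * img.height * img.channel
    # string_at copies n_bytes from the C pointer into a Python bytes object
    raw = ctypes.string_at(img.data, n_bytes)
    return Image.frombytes("RGB", (img.width, img.height), raw)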
src/backend/image_saver.py
DELETED
@@ -1,75 +0,0 @@
-import json
-from os import path, mkdir
-from typing import Any
-from uuid import uuid4
-from backend.models.lcmdiffusion_setting import LCMDiffusionSetting
-from utils import get_image_file_extension
-
-
-def get_exclude_keys():
-    exclude_keys = {
-        "init_image": True,
-        "generated_images": True,
-        "lora": {
-            "models_dir": True,
-            "path": True,
-        },
-        "dirs": True,
-        "controlnet": {
-            "adapter_path": True,
-        },
-    }
-    return exclude_keys
-
-
-class ImageSaver:
-    @staticmethod
-    def save_images(
-        output_path: str,
-        images: Any,
-        folder_name: str = "",
-        format: str = "PNG",
-        jpeg_quality: int = 90,
-        lcm_diffusion_setting: LCMDiffusionSetting = None,
-    ) -> list[str]:
-        gen_id = uuid4()
-        image_ids = []
-
-        if images:
-            image_seeds = []
-
-            for index, image in enumerate(images):
-
-                image_seed = image.info.get('image_seed')
-                if image_seed is not None:
-                    image_seeds.append(image_seed)
-
-                if not path.exists(output_path):
-                    mkdir(output_path)
-
-                if folder_name:
-                    out_path = path.join(
-                        output_path,
-                        folder_name,
-                    )
-                else:
-                    out_path = output_path
-
-                if not path.exists(out_path):
-                    mkdir(out_path)
-                image_extension = get_image_file_extension(format)
-                image_file_name = f"{gen_id}-{index+1}{image_extension}"
-                image_ids.append(image_file_name)
-                image.save(path.join(out_path, image_file_name), quality = jpeg_quality)
-            if lcm_diffusion_setting:
-                data = lcm_diffusion_setting.model_dump(exclude=get_exclude_keys())
-                if image_seeds:
-                    data['image_seeds'] = image_seeds
-                with open(path.join(out_path, f"{gen_id}.json"), "w") as json_file:
-                    json.dump(
-                        data,
-                        json_file,
-                        indent=4,
-                    )
-        return image_ids
-
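For reference, the removed ImageSaver.save_images generated one UUID per call, saved each PIL image as <uuid>-<n>.<ext>, and dumped the generation settings (minus the excluded keys above) to <uuid>.json alongside them. A minimal usage sketch, with placeholder paths and stand-in images:

from PIL import Image

from backend.image_saver import ImageSaver
from backend.models.lcmdiffusion_setting import LCMDiffusionSetting

# Hypothetical inputs: two blank frames standing in for generated images.
images = [Image.new("RGB", (512, 512)) for _ in range(2)]

saved_names = ImageSaver.save_images(
    "results",                    # base output directory, created if missing
    images,
    folder_name="session-1",      # optional subfolder under the base directory
    format="PNG",
    lcm_diffusion_setting=LCMDiffusionSetting(),  # also emits <uuid>.json metadata
)
print(saved_names)  # e.g. ['<uuid>-1.png', '<uuid>-2.png']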
src/backend/lcm_text_to_image.py
DELETED
@@ -1,577 +0,0 @@
-import gc
-from math import ceil
-from typing import Any, List
-import random
-
-import numpy as np
-import torch
-from backend.device import is_openvino_device
-from backend.controlnet import (
-    load_controlnet_adapters,
-    update_controlnet_arguments,
-)
-from backend.models.lcmdiffusion_setting import (
-    DiffusionTask,
-    LCMDiffusionSetting,
-    LCMLora,
-)
-from backend.openvino.pipelines import (
-    get_ov_image_to_image_pipeline,
-    get_ov_text_to_image_pipeline,
-    ov_load_taesd,
-)
-from backend.pipelines.lcm import (
-    get_image_to_image_pipeline,
-    get_lcm_model_pipeline,
-    load_taesd,
-)
-from backend.pipelines.lcm_lora import get_lcm_lora_pipeline
-from constants import DEVICE, GGUF_THREADS
-from diffusers import LCMScheduler
-from image_ops import resize_pil_image
-from backend.openvino.flux_pipeline import get_flux_pipeline
-from backend.openvino.ov_hc_stablediffusion_pipeline import OvHcLatentConsistency
-from backend.gguf.gguf_diffusion import (
-    GGUFDiffusion,
-    ModelConfig,
-    Txt2ImgConfig,
-    SampleMethod,
-)
-from paths import get_app_path
-from pprint import pprint
-
-try:
-    # support for token merging; keeping it optional for now
-    import tomesd
-except ImportError:
-    print("tomesd library unavailable; disabling token merging support")
-    tomesd = None
-
-
-class LCMTextToImage:
-    def __init__(
-        self,
-        device: str = "cpu",
-    ) -> None:
-        self.pipeline = None
-        self.use_openvino = False
-        self.device = ""
-        self.previous_model_id = None
-        self.previous_use_tae_sd = False
-        self.previous_use_lcm_lora = False
-        self.previous_ov_model_id = ""
-        self.previous_token_merging = 0.0
-        self.previous_safety_checker = False
-        self.previous_use_openvino = False
-        self.img_to_img_pipeline = None
-        self.is_openvino_init = False
-        self.previous_lora = None
-        self.task_type = DiffusionTask.text_to_image
-        self.previous_use_gguf_model = False
-        self.previous_gguf_model = None
-        self.torch_data_type = (
-            torch.float32 if is_openvino_device() or DEVICE == "mps" else torch.float16
-        )
-        self.ov_model_id = None
-        print(f"Torch datatype : {self.torch_data_type}")
-
-    def _pipeline_to_device(self):
-        print(f"Pipeline device : {DEVICE}")
-        print(f"Pipeline dtype : {self.torch_data_type}")
-        self.pipeline.to(
-            torch_device=DEVICE,
-            torch_dtype=self.torch_data_type,
-        )
-
-    def _add_freeu(self):
-        pipeline_class = self.pipeline.__class__.__name__
-        if isinstance(self.pipeline.scheduler, LCMScheduler):
-            if pipeline_class == "StableDiffusionPipeline":
-                print("Add FreeU - SD")
-                self.pipeline.enable_freeu(
-                    s1=0.9,
-                    s2=0.2,
-                    b1=1.2,
-                    b2=1.4,
-                )
-            elif pipeline_class == "StableDiffusionXLPipeline":
-                print("Add FreeU - SDXL")
-                self.pipeline.enable_freeu(
-                    s1=0.6,
-                    s2=0.4,
-                    b1=1.1,
-                    b2=1.2,
-                )
-
-    def _enable_vae_tiling(self):
-        self.pipeline.vae.enable_tiling()
-
-    def _update_lcm_scheduler_params(self):
-        if isinstance(self.pipeline.scheduler, LCMScheduler):
-            self.pipeline.scheduler = LCMScheduler.from_config(
-                self.pipeline.scheduler.config,
-                beta_start=0.001,
-                beta_end=0.01,
-            )
-
-    def _is_hetero_pipeline(self) -> bool:
-        return "square" in self.ov_model_id.lower()
-
-    def _load_ov_hetero_pipeline(self):
-        print("Loading Heterogeneous Compute pipeline")
-        if DEVICE.upper() == "NPU":
-            device = ["NPU", "NPU", "NPU"]
-            self.pipeline = OvHcLatentConsistency(self.ov_model_id, device)
-        else:
-            self.pipeline = OvHcLatentConsistency(self.ov_model_id)
-
-    def _generate_images_hetero_compute(
-        self,
-        lcm_diffusion_setting: LCMDiffusionSetting,
-    ):
-        print("Using OpenVINO ")
-        if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
-            return [
-                self.pipeline.generate(
-                    prompt=lcm_diffusion_setting.prompt,
-                    neg_prompt=lcm_diffusion_setting.negative_prompt,
-                    init_image=None,
-                    strength=1.0,
-                    num_inference_steps=lcm_diffusion_setting.inference_steps,
-                )
-            ]
-        else:
-            return [
-                self.pipeline.generate(
-                    prompt=lcm_diffusion_setting.prompt,
-                    neg_prompt=lcm_diffusion_setting.negative_prompt,
-                    init_image=lcm_diffusion_setting.init_image,
-                    strength=lcm_diffusion_setting.strength,
-                    num_inference_steps=lcm_diffusion_setting.inference_steps,
-                )
-            ]
-
-    def _is_valid_mode(
-        self,
-        modes: List,
-    ) -> bool:
-        return modes.count(True) == 1 or modes.count(False) == 3
-
-    def _validate_mode(
-        self,
-        modes: List,
-    ) -> None:
-        if not self._is_valid_mode(modes):
-            raise ValueError("Invalid mode,delete configs/settings.yaml and retry!")
-
-    def init(
-        self,
-        device: str = "cpu",
-        lcm_diffusion_setting: LCMDiffusionSetting = LCMDiffusionSetting(),
-    ) -> None:
-        # Mode validation either LCM LoRA or OpenVINO or GGUF
-
-        modes = [
-            lcm_diffusion_setting.use_gguf_model,
-            lcm_diffusion_setting.use_openvino,
-            lcm_diffusion_setting.use_lcm_lora,
-        ]
-        self._validate_mode(modes)
-        self.device = device
-        self.use_openvino = lcm_diffusion_setting.use_openvino
-        model_id = lcm_diffusion_setting.lcm_model_id
-        use_local_model = lcm_diffusion_setting.use_offline_model
-        use_tiny_auto_encoder = lcm_diffusion_setting.use_tiny_auto_encoder
-        use_lora = lcm_diffusion_setting.use_lcm_lora
-        lcm_lora: LCMLora = lcm_diffusion_setting.lcm_lora
-        token_merging = lcm_diffusion_setting.token_merging
-        self.ov_model_id = lcm_diffusion_setting.openvino_lcm_model_id
-
-        if lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value:
-            lcm_diffusion_setting.init_image = resize_pil_image(
-                lcm_diffusion_setting.init_image,
-                lcm_diffusion_setting.image_width,
-                lcm_diffusion_setting.image_height,
-            )
-
-        if (
-            self.pipeline is None
-            or self.previous_model_id != model_id
-            or self.previous_use_tae_sd != use_tiny_auto_encoder
-            or self.previous_lcm_lora_base_id != lcm_lora.base_model_id
-            or self.previous_lcm_lora_id != lcm_lora.lcm_lora_id
-            or self.previous_use_lcm_lora != use_lora
-            or self.previous_ov_model_id != self.ov_model_id
-            or self.previous_token_merging != token_merging
-            or self.previous_safety_checker != lcm_diffusion_setting.use_safety_checker
-            or self.previous_use_openvino != lcm_diffusion_setting.use_openvino
-            or self.previous_use_gguf_model != lcm_diffusion_setting.use_gguf_model
-            or self.previous_gguf_model != lcm_diffusion_setting.gguf_model
-            or (
-                self.use_openvino
-                and (
-                    self.previous_task_type != lcm_diffusion_setting.diffusion_task
-                    or self.previous_lora != lcm_diffusion_setting.lora
-                )
-            )
-            or lcm_diffusion_setting.rebuild_pipeline
-        ):
-            if self.use_openvino and is_openvino_device():
-                if self.pipeline:
-                    del self.pipeline
-                    self.pipeline = None
-                    gc.collect()
-                self.is_openvino_init = True
-                if (
-                    lcm_diffusion_setting.diffusion_task
-                    == DiffusionTask.text_to_image.value
-                ):
-                    print(
-                        f"***** Init Text to image (OpenVINO) - {self.ov_model_id} *****"
-                    )
-                    if "flux" in self.ov_model_id.lower():
-                        print("Loading OpenVINO Flux pipeline")
-                        self.pipeline = get_flux_pipeline(
-                            self.ov_model_id,
-                            lcm_diffusion_setting.use_tiny_auto_encoder,
-                        )
-                    elif self._is_hetero_pipeline():
-                        self._load_ov_hetero_pipeline()
-                    else:
-                        self.pipeline = get_ov_text_to_image_pipeline(
-                            self.ov_model_id,
-                            use_local_model,
-                        )
-                elif (
-                    lcm_diffusion_setting.diffusion_task
-                    == DiffusionTask.image_to_image.value
-                ):
-                    if not self.pipeline and self._is_hetero_pipeline():
-                        self._load_ov_hetero_pipeline()
-                    else:
-                        print(
-                            f"***** Image to image (OpenVINO) - {self.ov_model_id} *****"
-                        )
-                        self.pipeline = get_ov_image_to_image_pipeline(
-                            self.ov_model_id,
-                            use_local_model,
-                        )
-            elif lcm_diffusion_setting.use_gguf_model:
-                model = lcm_diffusion_setting.gguf_model.diffusion_path
-                print(f"***** Init Text to image (GGUF) - {model} *****")
-                # if self.pipeline:
-                #     self.pipeline.terminate()
-                #     del self.pipeline
-                #     self.pipeline = None
-                self._init_gguf_diffusion(lcm_diffusion_setting)
-            else:
-                if self.pipeline or self.img_to_img_pipeline:
-                    self.pipeline = None
-                    self.img_to_img_pipeline = None
-                    gc.collect()
-
-                controlnet_args = load_controlnet_adapters(lcm_diffusion_setting)
-                if use_lora:
-                    print(
-                        f"***** Init LCM-LoRA pipeline - {lcm_lora.base_model_id} *****"
-                    )
-                    self.pipeline = get_lcm_lora_pipeline(
-                        lcm_lora.base_model_id,
-                        lcm_lora.lcm_lora_id,
-                        use_local_model,
-                        torch_data_type=self.torch_data_type,
-                        pipeline_args=controlnet_args,
-                    )
-
-                else:
-                    print(f"***** Init LCM Model pipeline - {model_id} *****")
-                    self.pipeline = get_lcm_model_pipeline(
-                        model_id,
-                        use_local_model,
-                        controlnet_args,
-                    )
-
-                self.img_to_img_pipeline = get_image_to_image_pipeline(self.pipeline)
-
-                if tomesd and token_merging > 0.001:
-                    print(f"***** Token Merging: {token_merging} *****")
-                    tomesd.apply_patch(self.pipeline, ratio=token_merging)
-                    tomesd.apply_patch(self.img_to_img_pipeline, ratio=token_merging)
-
-            if use_tiny_auto_encoder:
-                if self.use_openvino and is_openvino_device():
-                    if self.pipeline.__class__.__name__ != "OVFluxPipeline":
-                        print("Using Tiny Auto Encoder (OpenVINO)")
-                        ov_load_taesd(
-                            self.pipeline,
-                            use_local_model,
-                        )
-                else:
-                    print("Using Tiny Auto Encoder")
-                    load_taesd(
-                        self.pipeline,
-                        use_local_model,
-                        self.torch_data_type,
-                    )
-                    load_taesd(
-                        self.img_to_img_pipeline,
-                        use_local_model,
-                        self.torch_data_type,
-                    )
-
-            if not self.use_openvino and not is_openvino_device():
-                self._pipeline_to_device()
-
-            if not self._is_hetero_pipeline():
-                if (
-                    lcm_diffusion_setting.diffusion_task
-                    == DiffusionTask.image_to_image.value
-                    and lcm_diffusion_setting.use_openvino
-                ):
-                    self.pipeline.scheduler = LCMScheduler.from_config(
-                        self.pipeline.scheduler.config,
-                    )
-                else:
-                    if not lcm_diffusion_setting.use_gguf_model:
-                        self._update_lcm_scheduler_params()
-
-            if use_lora:
-                self._add_freeu()
-
-            self.previous_model_id = model_id
-            self.previous_ov_model_id = self.ov_model_id
-            self.previous_use_tae_sd = use_tiny_auto_encoder
-            self.previous_lcm_lora_base_id = lcm_lora.base_model_id
-            self.previous_lcm_lora_id = lcm_lora.lcm_lora_id
-            self.previous_use_lcm_lora = use_lora
-            self.previous_token_merging = lcm_diffusion_setting.token_merging
-            self.previous_safety_checker = lcm_diffusion_setting.use_safety_checker
-            self.previous_use_openvino = lcm_diffusion_setting.use_openvino
-            self.previous_task_type = lcm_diffusion_setting.diffusion_task
-            self.previous_lora = lcm_diffusion_setting.lora.model_copy(deep=True)
-            self.previous_use_gguf_model = lcm_diffusion_setting.use_gguf_model
-            self.previous_gguf_model = lcm_diffusion_setting.gguf_model.model_copy(
-                deep=True
-            )
-            lcm_diffusion_setting.rebuild_pipeline = False
-            if (
-                lcm_diffusion_setting.diffusion_task
-                == DiffusionTask.text_to_image.value
-            ):
-                print(f"Pipeline : {self.pipeline}")
-            elif (
-                lcm_diffusion_setting.diffusion_task
-                == DiffusionTask.image_to_image.value
-            ):
-                if self.use_openvino and is_openvino_device():
-                    print(f"Pipeline : {self.pipeline}")
-                else:
-                    print(f"Pipeline : {self.img_to_img_pipeline}")
-            if self.use_openvino:
-                if lcm_diffusion_setting.lora.enabled:
-                    print("Warning: Lora models not supported on OpenVINO mode")
-            elif not lcm_diffusion_setting.use_gguf_model:
-                adapters = self.pipeline.get_active_adapters()
-                print(f"Active adapters : {adapters}")
-
-    def _get_timesteps(self):
-        time_steps = self.pipeline.scheduler.config.get("timesteps")
-        time_steps_value = [int(time_steps)] if time_steps else None
-        return time_steps_value
-
-    def generate(
-        self,
-        lcm_diffusion_setting: LCMDiffusionSetting,
-        reshape: bool = False,
-    ) -> Any:
-        guidance_scale = lcm_diffusion_setting.guidance_scale
-        img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
-        check_step_value = int(
-            lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
-        )
-        if (
-            lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
-            and check_step_value < 1
-        ):
-            img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
-            print(
-                f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
-            )
-
-        pipeline_extra_args = {}
-
-        if lcm_diffusion_setting.use_seed:
-            cur_seed = lcm_diffusion_setting.seed
-            # for multiple images with a fixed seed, use sequential seeds
-            seeds = [
-                (cur_seed + i) for i in range(lcm_diffusion_setting.number_of_images)
-            ]
-        else:
-            seeds = [
-                random.randint(0, 999999999)
-                for i in range(lcm_diffusion_setting.number_of_images)
-            ]
-
-        if self.use_openvino:
-            # no support for generators; try at least to ensure reproducible results for single images
-            np.random.seed(seeds[0])
-            if self._is_hetero_pipeline():
-                torch.manual_seed(seeds[0])
-                lcm_diffusion_setting.seed = seeds[0]
-        else:
-            pipeline_extra_args["generator"] = [
-                torch.Generator(device=self.device).manual_seed(s) for s in seeds
-            ]
-
-        is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
-        if is_openvino_pipe and not self._is_hetero_pipeline():
-            print("Using OpenVINO")
-            if reshape and not self.is_openvino_init:
-                print("Reshape and compile")
-                self.pipeline.reshape(
-                    batch_size=-1,
-                    height=lcm_diffusion_setting.image_height,
-                    width=lcm_diffusion_setting.image_width,
-                    num_images_per_prompt=lcm_diffusion_setting.number_of_images,
-                )
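The backend-selection invariant enforced near the top of the deleted init() is self-contained and worth noting: of the three flags (GGUF, OpenVINO, LCM-LoRA), exactly one may be enabled, or none at all (plain LCM pipeline). A standalone restatement of the _is_valid_mode check, with illustrative assertions:

def is_valid_mode(modes: list) -> bool:
    # Valid: exactly one backend enabled, or all three disabled.
    return modes.count(True) == 1 or modes.count(False) == 3


assert is_valid_mode([True, False, False])      # e.g. GGUF only
assert is_valid_mode([False, False, False])     # plain LCM pipeline
assert not is_valid_mode([True, True, False])   # two backends at once: rejected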
| 388 | 
         
            -
                    img_to_img_inference_steps = lcm_diffusion_setting.inference_steps
         
     | 
| 389 | 
         
            -
                    check_step_value = int(
         
     | 
| 390 | 
         
            -
                        lcm_diffusion_setting.inference_steps * lcm_diffusion_setting.strength
         
     | 
| 391 | 
         
            -
                    )
         
     | 
| 392 | 
         
            -
                    if (
         
     | 
| 393 | 
         
            -
                        lcm_diffusion_setting.diffusion_task == DiffusionTask.image_to_image.value
         
     | 
| 394 | 
         
            -
                        and check_step_value < 1
         
     | 
| 395 | 
         
            -
                    ):
         
     | 
| 396 | 
         
            -
                        img_to_img_inference_steps = ceil(1 / lcm_diffusion_setting.strength)
         
     | 
| 397 | 
         
            -
                        print(
         
     | 
| 398 | 
         
            -
                            f"Strength: {lcm_diffusion_setting.strength},{img_to_img_inference_steps}"
         
     | 
| 399 | 
         
            -
                        )
         
     | 
| 400 | 
         
            -
             
     | 
| 401 | 
         
            -
                    pipeline_extra_args = {}
         
     | 
| 402 | 
         
            -
             
     | 
| 403 | 
         
            -
                    if lcm_diffusion_setting.use_seed:
         
     | 
| 404 | 
         
            -
                        cur_seed = lcm_diffusion_setting.seed
         
     | 
| 405 | 
         
            -
                        # for multiple images with a fixed seed, use sequential seeds
         
     | 
| 406 | 
         
            -
                        seeds = [
         
     | 
| 407 | 
         
            -
                            (cur_seed + i) for i in range(lcm_diffusion_setting.number_of_images)
         
     | 
| 408 | 
         
            -
                        ]
         
     | 
| 409 | 
         
            -
                    else:
         
     | 
| 410 | 
         
            -
                        seeds = [
         
     | 
| 411 | 
         
            -
                            random.randint(0, 999999999)
         
     | 
| 412 | 
         
            -
                            for i in range(lcm_diffusion_setting.number_of_images)
         
     | 
| 413 | 
         
            -
                        ]
         
     | 
| 414 | 
         
            -
             
     | 
| 415 | 
         
            -
                    if self.use_openvino:
         
     | 
| 416 | 
         
            -
                        # no support for generators; try at least to ensure reproducible results for single images
         
     | 
| 417 | 
         
            -
                        np.random.seed(seeds[0])
         
     | 
| 418 | 
         
            -
                        if self._is_hetero_pipeline():
         
     | 
| 419 | 
         
            -
                            torch.manual_seed(seeds[0])
         
     | 
| 420 | 
         
            -
                            lcm_diffusion_setting.seed = seeds[0]
         
     | 
| 421 | 
         
            -
                    else:
         
     | 
| 422 | 
         
            -
                        pipeline_extra_args["generator"] = [
         
     | 
| 423 | 
         
            -
                            torch.Generator(device=self.device).manual_seed(s) for s in seeds
         
     | 
| 424 | 
         
            -
                        ]
         
     | 
| 425 | 
         
            -
             
     | 
| 426 | 
         
            -
                    is_openvino_pipe = lcm_diffusion_setting.use_openvino and is_openvino_device()
         
     | 
| 427 | 
         
            -
                    if is_openvino_pipe and not self._is_hetero_pipeline():
         
     | 
| 428 | 
         
            -
                        print("Using OpenVINO")
         
     | 
| 429 | 
         
            -
                        if reshape and not self.is_openvino_init:
         
     | 
| 430 | 
         
            -
                            print("Reshape and compile")
         
     | 
| 431 | 
         
            -
                            self.pipeline.reshape(
         
     | 
| 432 | 
         
            -
                                batch_size=-1,
         
     | 
| 433 | 
         
            -
                                height=lcm_diffusion_setting.image_height,
         
     | 
| 434 | 
         
            -
                                width=lcm_diffusion_setting.image_width,
         
     | 
| 435 | 
         
            -
                                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
         
     | 
| 436 | 
         
            -
                            )
         
     | 
| 437 | 
         
            -
                            self.pipeline.compile()
         
     | 
| 438 | 
         
            -
             
     | 
| 439 | 
         
            -
                        if self.is_openvino_init:
         
     | 
| 440 | 
         
            -
                            self.is_openvino_init = False
         
     | 
| 441 | 
         
            -
             
     | 
| 442 | 
         
            -
                    if is_openvino_pipe and self._is_hetero_pipeline():
         
     | 
| 443 | 
         
            -
                        return self._generate_images_hetero_compute(lcm_diffusion_setting)
         
     | 
| 444 | 
         
            -
                    elif lcm_diffusion_setting.use_gguf_model:
         
     | 
| 445 | 
         
            -
                        return self._generate_images_gguf(lcm_diffusion_setting)
         
     | 
| 446 | 
         
            -
             
     | 
| 447 | 
         
            -
                    if lcm_diffusion_setting.clip_skip > 1:
         
     | 
| 448 | 
         
            -
                        # We follow the convention that "CLIP Skip == 2" means "skip
         
     | 
| 449 | 
         
            -
                        # the last layer", so "CLIP Skip == 1" means "no skipping"
         
     | 
| 450 | 
         
            -
                        pipeline_extra_args["clip_skip"] = lcm_diffusion_setting.clip_skip - 1
         
     | 
| 451 | 
         
            -
             
     | 
| 452 | 
         
            -
                    if not lcm_diffusion_setting.use_safety_checker:
         
     | 
| 453 | 
         
            -
                        self.pipeline.safety_checker = None
         
     | 
| 454 | 
         
            -
                        if (
         
     | 
| 455 | 
         
            -
                            lcm_diffusion_setting.diffusion_task
         
     | 
| 456 | 
         
            -
                            == DiffusionTask.image_to_image.value
         
     | 
| 457 | 
         
            -
                            and not is_openvino_pipe
         
     | 
| 458 | 
         
            -
                        ):
         
     | 
| 459 | 
         
            -
                            self.img_to_img_pipeline.safety_checker = None
         
     | 
| 460 | 
         
            -
             
     | 
| 461 | 
         
            -
                    if (
         
     | 
| 462 | 
         
            -
                        not lcm_diffusion_setting.use_lcm_lora
         
     | 
| 463 | 
         
            -
                        and not lcm_diffusion_setting.use_openvino
         
     | 
| 464 | 
         
            -
                        and lcm_diffusion_setting.guidance_scale != 1.0
         
     | 
| 465 | 
         
            -
                    ):
         
     | 
| 466 | 
         
            -
                        print("Not using LCM-LoRA so setting guidance_scale 1.0")
         
     | 
| 467 | 
         
            -
                        guidance_scale = 1.0
         
     | 
| 468 | 
         
            -
             
     | 
| 469 | 
         
            -
                    controlnet_args = update_controlnet_arguments(lcm_diffusion_setting)
         
     | 
| 470 | 
         
            -
                    if lcm_diffusion_setting.use_openvino:
         
     | 
| 471 | 
         
            -
                        if (
         
     | 
| 472 | 
         
            -
                            lcm_diffusion_setting.diffusion_task
         
     | 
| 473 | 
         
            -
                            == DiffusionTask.text_to_image.value
         
     | 
| 474 | 
         
            -
                        ):
         
     | 
| 475 | 
         
            -
                            result_images = self.pipeline(
         
     | 
| 476 | 
         
            -
                                prompt=lcm_diffusion_setting.prompt,
         
     | 
| 477 | 
         
            -
                                negative_prompt=lcm_diffusion_setting.negative_prompt,
         
     | 
| 478 | 
         
            -
                                num_inference_steps=lcm_diffusion_setting.inference_steps,
         
     | 
| 479 | 
         
            -
                                guidance_scale=guidance_scale,
         
     | 
| 480 | 
         
            -
                                width=lcm_diffusion_setting.image_width,
         
     | 
| 481 | 
         
            -
                                height=lcm_diffusion_setting.image_height,
         
     | 
| 482 | 
         
            -
                                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
         
     | 
| 483 | 
         
            -
                            ).images
         
     | 
| 484 | 
         
            -
                        elif (
         
     | 
| 485 | 
         
            -
                            lcm_diffusion_setting.diffusion_task
         
     | 
| 486 | 
         
            -
                            == DiffusionTask.image_to_image.value
         
     | 
| 487 | 
         
            -
                        ):
         
     | 
| 488 | 
         
            -
                            result_images = self.pipeline(
         
     | 
| 489 | 
         
            -
                                image=lcm_diffusion_setting.init_image,
         
     | 
| 490 | 
         
            -
                                strength=lcm_diffusion_setting.strength,
         
     | 
| 491 | 
         
            -
                                prompt=lcm_diffusion_setting.prompt,
         
     | 
| 492 | 
         
            -
                                negative_prompt=lcm_diffusion_setting.negative_prompt,
         
     | 
| 493 | 
         
            -
                                num_inference_steps=img_to_img_inference_steps * 3,
         
     | 
| 494 | 
         
            -
                                guidance_scale=guidance_scale,
         
     | 
| 495 | 
         
            -
                                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
         
     | 
| 496 | 
         
            -
                            ).images
         
     | 
| 497 | 
         
            -
             
     | 
| 498 | 
         
            -
                    else:
         
     | 
| 499 | 
         
            -
                        if (
         
     | 
| 500 | 
         
            -
                            lcm_diffusion_setting.diffusion_task
         
     | 
| 501 | 
         
            -
                            == DiffusionTask.text_to_image.value
         
     | 
| 502 | 
         
            -
                        ):
         
     | 
| 503 | 
         
            -
                            result_images = self.pipeline(
         
     | 
| 504 | 
         
            -
                                prompt=lcm_diffusion_setting.prompt,
         
     | 
| 505 | 
         
            -
                                negative_prompt=lcm_diffusion_setting.negative_prompt,
         
     | 
| 506 | 
         
            -
                                num_inference_steps=lcm_diffusion_setting.inference_steps,
         
     | 
| 507 | 
         
            -
                                guidance_scale=guidance_scale,
         
     | 
| 508 | 
         
            -
                                width=lcm_diffusion_setting.image_width,
         
     | 
| 509 | 
         
            -
                                height=lcm_diffusion_setting.image_height,
         
     | 
| 510 | 
         
            -
                                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
         
     | 
| 511 | 
         
            -
                                timesteps=self._get_timesteps(),
         
     | 
| 512 | 
         
            -
                                **pipeline_extra_args,
         
     | 
| 513 | 
         
            -
                                **controlnet_args,
         
     | 
| 514 | 
         
            -
                            ).images
         
     | 
| 515 | 
         
            -
             
     | 
| 516 | 
         
            -
                        elif (
         
     | 
| 517 | 
         
            -
                            lcm_diffusion_setting.diffusion_task
         
     | 
| 518 | 
         
            -
                            == DiffusionTask.image_to_image.value
         
     | 
| 519 | 
         
            -
                        ):
         
     | 
| 520 | 
         
            -
                            result_images = self.img_to_img_pipeline(
         
     | 
| 521 | 
         
            -
                                image=lcm_diffusion_setting.init_image,
         
     | 
| 522 | 
         
            -
                                strength=lcm_diffusion_setting.strength,
         
     | 
| 523 | 
         
            -
                                prompt=lcm_diffusion_setting.prompt,
         
     | 
| 524 | 
         
            -
                                negative_prompt=lcm_diffusion_setting.negative_prompt,
         
     | 
| 525 | 
         
            -
                                num_inference_steps=img_to_img_inference_steps,
         
     | 
| 526 | 
         
            -
                                guidance_scale=guidance_scale,
         
     | 
| 527 | 
         
            -
                                width=lcm_diffusion_setting.image_width,
         
     | 
| 528 | 
         
            -
                                height=lcm_diffusion_setting.image_height,
         
     | 
| 529 | 
         
            -
                                num_images_per_prompt=lcm_diffusion_setting.number_of_images,
         
     | 
| 530 | 
         
            -
                                **pipeline_extra_args,
         
     | 
| 531 | 
         
            -
                                **controlnet_args,
         
     | 
| 532 | 
         
            -
                            ).images
         
     | 
| 533 | 
         
            -
             
     | 
| 534 | 
         
            -
                    for i, seed in enumerate(seeds):
         
     | 
| 535 | 
         
            -
                        result_images[i].info["image_seed"] = seed
         
     | 
| 536 | 
         
            -
             
     | 
| 537 | 
         
            -
                    return result_images
         
     | 
| 538 | 
         
            -
             
     | 
| 539 | 
         
            -
                def _init_gguf_diffusion(
         
     | 
| 540 | 
         
            -
                    self,
         
     | 
| 541 | 
         
            -
                    lcm_diffusion_setting: LCMDiffusionSetting,
         
     | 
| 542 | 
         
            -
                ):
         
     | 
| 543 | 
         
            -
                    config = ModelConfig()
         
     | 
| 544 | 
         
            -
                    config.model_path = lcm_diffusion_setting.gguf_model.diffusion_path
         
     | 
| 545 | 
         
            -
                    config.diffusion_model_path = lcm_diffusion_setting.gguf_model.diffusion_path
         
     | 
| 546 | 
         
            -
                    config.clip_l_path = lcm_diffusion_setting.gguf_model.clip_path
         
     | 
| 547 | 
         
            -
                    config.t5xxl_path = lcm_diffusion_setting.gguf_model.t5xxl_path
         
     | 
| 548 | 
         
            -
                    config.vae_path = lcm_diffusion_setting.gguf_model.vae_path
         
     | 
| 549 | 
         
            -
                    config.n_threads = GGUF_THREADS
         
     | 
| 550 | 
         
            -
                    print(f"GGUF Threads : {GGUF_THREADS} ")
         
     | 
| 551 | 
         
            -
                    print("GGUF - Model config")
         
     | 
| 552 | 
         
            -
                    pprint(lcm_diffusion_setting.gguf_model.model_dump())
         
     | 
| 553 | 
         
            -
                    self.pipeline = GGUFDiffusion(
         
     | 
| 554 | 
         
            -
                        get_app_path(),  # Place DLL in fastsdcpu folder
         
     | 
| 555 | 
         
            -
                        config,
         
     | 
| 556 | 
         
            -
                        True,
         
     | 
| 557 | 
         
            -
                    )
         
     | 
| 558 | 
         
            -
             
     | 
| 559 | 
         
            -
                def _generate_images_gguf(
         
     | 
| 560 | 
         
            -
                    self,
         
     | 
| 561 | 
         
            -
                    lcm_diffusion_setting: LCMDiffusionSetting,
         
     | 
| 562 | 
         
            -
                ):
         
     | 
| 563 | 
         
            -
                    if lcm_diffusion_setting.diffusion_task == DiffusionTask.text_to_image.value:
         
     | 
| 564 | 
         
            -
                        t2iconfig = Txt2ImgConfig()
         
     | 
| 565 | 
         
            -
                        t2iconfig.prompt = lcm_diffusion_setting.prompt
         
     | 
| 566 | 
         
            -
                        t2iconfig.batch_count = lcm_diffusion_setting.number_of_images
         
     | 
| 567 | 
         
            -
                        t2iconfig.cfg_scale = lcm_diffusion_setting.guidance_scale
         
     | 
| 568 | 
         
            -
                        t2iconfig.height = lcm_diffusion_setting.image_height
         
     | 
| 569 | 
         
            -
                        t2iconfig.width = lcm_diffusion_setting.image_width
         
     | 
| 570 | 
         
            -
                        t2iconfig.sample_steps = lcm_diffusion_setting.inference_steps
         
     | 
| 571 | 
         
            -
                        t2iconfig.sample_method = SampleMethod.EULER
         
     | 
| 572 | 
         
            -
                        if lcm_diffusion_setting.use_seed:
         
     | 
| 573 | 
         
            -
                            t2iconfig.seed = lcm_diffusion_setting.seed
         
     | 
| 574 | 
         
            -
                        else:
         
     | 
| 575 | 
         
            -
                            t2iconfig.seed = -1
         
     | 
| 576 | 
         
            -
             
     | 
| 577 | 
         
            -
                        return self.pipeline.generate_text2mg(t2iconfig)
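Editor's aside on the seed handling in the removed generate() method: on the non-OpenVINO path it builds one torch.Generator per requested image from sequential seeds, so a fixed base seed reproduces a whole batch while every image in it still differs. A minimal standalone sketch of that idea, assuming a CPU device and a diffusers-style pipeline named pipe (not part of the diff):

    import torch

    base_seed = 123123
    number_of_images = 4
    # sequential seeds: the batch is reproducible, images still differ
    seeds = [base_seed + i for i in range(number_of_images)]
    generators = [torch.Generator(device="cpu").manual_seed(s) for s in seeds]
    # images = pipe(prompt="...", generator=generators,
    #               num_images_per_prompt=number_of_images).images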
         
src/backend/lora.py
DELETED
@@ -1,136 +0,0 @@
-import glob
-from os import path
-from paths import get_file_name, FastStableDiffusionPaths
-from pathlib import Path
-
-
-# A basic class to keep track of the currently loaded LoRAs and
-# their weights; the diffusers function \c get_active_adapters()
-# returns a list of adapter names but not their weights so we need
-# a way to keep track of the current LoRA weights to set whenever
-# a new LoRA is loaded
-class _lora_info:
-    def __init__(
-        self,
-        path: str,
-        weight: float,
-    ):
-        self.path = path
-        self.adapter_name = get_file_name(path)
-        self.weight = weight
-
-    def __del__(self):
-        self.path = None
-        self.adapter_name = None
-
-
-_loaded_loras = []
-_current_pipeline = None
-
-
-# This function loads a LoRA from the LoRA path setting, so it's
-# possible to load multiple LoRAs by calling this function more than
-# once with a different LoRA path setting; note that if you plan to
-# load multiple LoRAs and dynamically change their weights, you
-# might want to set the LoRA fuse option to False
-def load_lora_weight(
-    pipeline,
-    lcm_diffusion_setting,
-):
-    if not lcm_diffusion_setting.lora.path:
-        raise Exception("Empty lora model path")
-
-    if not path.exists(lcm_diffusion_setting.lora.path):
-        raise Exception("Lora model path is invalid")
-
-    # If the pipeline has been rebuilt since the last call, remove all
-    # references to previously loaded LoRAs and store the new pipeline
-    global _loaded_loras
-    global _current_pipeline
-    if pipeline != _current_pipeline:
-        for lora in _loaded_loras:
-            del lora
-        del _loaded_loras
-        _loaded_loras = []
-        _current_pipeline = pipeline
-
-    current_lora = _lora_info(
-        lcm_diffusion_setting.lora.path,
-        lcm_diffusion_setting.lora.weight,
-    )
-    _loaded_loras.append(current_lora)
-
-    if lcm_diffusion_setting.lora.enabled:
-        print(f"LoRA adapter name : {current_lora.adapter_name}")
-        pipeline.load_lora_weights(
-            FastStableDiffusionPaths.get_lora_models_path(),
-            weight_name=Path(lcm_diffusion_setting.lora.path).name,
-            local_files_only=True,
-            adapter_name=current_lora.adapter_name,
-        )
-        update_lora_weights(
-            pipeline,
-            lcm_diffusion_setting,
-        )
-
-        if lcm_diffusion_setting.lora.fuse:
-            pipeline.fuse_lora()
-
-
-def get_lora_models(root_dir: str):
-    lora_models = glob.glob(f"{root_dir}/**/*.safetensors", recursive=True)
-    lora_models_map = {}
-    for file_path in lora_models:
-        lora_name = get_file_name(file_path)
-        if lora_name is not None:
-            lora_models_map[lora_name] = file_path
-    return lora_models_map
-
-
-# This function returns a list of (adapter_name, weight) tuples for the
-# currently loaded LoRAs
-def get_active_lora_weights():
-    active_loras = []
-    for lora_info in _loaded_loras:
-        active_loras.append(
-            (
-                lora_info.adapter_name,
-                lora_info.weight,
-            )
-        )
-    return active_loras
-
-
-# This function receives a pipeline, an lcm_diffusion_setting object and
-# an optional list of updated (adapter_name, weight) tuples
-def update_lora_weights(
-    pipeline,
-    lcm_diffusion_setting,
-    lora_weights=None,
-):
-    global _loaded_loras
-    global _current_pipeline
-    if pipeline != _current_pipeline:
-        print("Wrong pipeline when trying to update LoRA weights")
-        return
-    if lora_weights:
-        for idx, lora in enumerate(lora_weights):
-            if _loaded_loras[idx].adapter_name != lora[0]:
-                print("Wrong adapter name in LoRA enumeration!")
-                continue
-            _loaded_loras[idx].weight = lora[1]
-
-    adapter_names = []
-    adapter_weights = []
-    if lcm_diffusion_setting.use_lcm_lora:
-        adapter_names.append("lcm")
-        adapter_weights.append(1.0)
-    for lora in _loaded_loras:
-        adapter_names.append(lora.adapter_name)
-        adapter_weights.append(lora.weight)
-    pipeline.set_adapters(
-        adapter_names,
-        adapter_weights=adapter_weights,
-    )
-    adapter_weights = zip(adapter_names, adapter_weights)
-    print(f"Adapters: {list(adapter_weights)}")
         
src/backend/models/device.py
DELETED
@@ -1,9 +0,0 @@
-from pydantic import BaseModel
-
-
-class DeviceInfo(BaseModel):
-    device_type: str
-    device_name: str
-    os: str
-    platform: str
-    processor: str
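For illustration, the removed model could be populated with values like those returned by Python's standard platform module (field values here are hypothetical):

    import platform

    info = DeviceInfo(
        device_type="cpu",
        device_name="CPU",
        os=platform.system(),
        platform=platform.platform(),
        processor=platform.processor(),
    )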
         
src/backend/models/gen_images.py
DELETED
@@ -1,17 +0,0 @@
-from pydantic import BaseModel
-from enum import Enum
-from paths import FastStableDiffusionPaths
-
-
-class ImageFormat(str, Enum):
-    """Image format"""
-
-    JPEG = "jpeg"
-    PNG = "png"
-
-
-class GeneratedImages(BaseModel):
-    path: str = FastStableDiffusionPaths.get_results_path()
-    format: str = ImageFormat.PNG.value.upper()
-    save_image: bool = True
-    save_image_quality: int = 90
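Editor's aside: the format field stores the uppercased enum value, which lines up with the format strings PIL's Image.save() accepts ("PNG", "JPEG"). A short, hypothetical usage:

    from backend.models.gen_images import GeneratedImages, ImageFormat

    settings = GeneratedImages(format=ImageFormat.JPEG.value.upper())
    # settings.path defaults to the results folder; save_image_quality
    # (default 90) is presumably the JPEG quality setting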
         
src/backend/models/lcmdiffusion_setting.py
DELETED
@@ -1,76 +0,0 @@
-from enum import Enum
-from PIL import Image
-from typing import Any, Optional, Union
-
-from constants import LCM_DEFAULT_MODEL, LCM_DEFAULT_MODEL_OPENVINO
-from paths import FastStableDiffusionPaths
-from pydantic import BaseModel
-
-
-class LCMLora(BaseModel):
-    base_model_id: str = "Lykon/dreamshaper-8"
-    lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
-
-
-class DiffusionTask(str, Enum):
-    """Diffusion task types"""
-
-    text_to_image = "text_to_image"
-    image_to_image = "image_to_image"
-
-
-class Lora(BaseModel):
-    models_dir: str = FastStableDiffusionPaths.get_lora_models_path()
-    path: Optional[Any] = None
-    weight: Optional[float] = 0.5
-    fuse: bool = True
-    enabled: bool = False
-
-
-class ControlNetSetting(BaseModel):
-    adapter_path: Optional[str] = None  # ControlNet adapter path
-    conditioning_scale: float = 0.5
-    enabled: bool = False
-    _control_image: Image = None  # Control image, PIL image
-
-
-class GGUFModel(BaseModel):
-    gguf_models: str = FastStableDiffusionPaths.get_gguf_models_path()
-    diffusion_path: Optional[str] = None
-    clip_path: Optional[str] = None
-    t5xxl_path: Optional[str] = None
-    vae_path: Optional[str] = None
-
-
-class LCMDiffusionSetting(BaseModel):
-    lcm_model_id: str = LCM_DEFAULT_MODEL
-    openvino_lcm_model_id: str = LCM_DEFAULT_MODEL_OPENVINO
-    use_offline_model: bool = False
-    use_lcm_lora: bool = False
-    lcm_lora: Optional[LCMLora] = LCMLora()
-    use_tiny_auto_encoder: bool = False
-    use_openvino: bool = False
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: Any = None
-    strength: Optional[float] = 0.6
-    image_height: Optional[int] = 512
-    image_width: Optional[int] = 512
-    inference_steps: Optional[int] = 1
-    guidance_scale: Optional[float] = 1
-    clip_skip: Optional[int] = 1
-    token_merging: Optional[float] = 0
-    number_of_images: Optional[int] = 1
-    seed: Optional[int] = 123123
-    use_seed: bool = False
-    use_safety_checker: bool = False
-    diffusion_task: str = DiffusionTask.text_to_image.value
-    lora: Optional[Lora] = Lora()
-    controlnet: Optional[Union[ControlNetSetting, list[ControlNetSetting]]] = None
-    dirs: dict = {
-        "controlnet": FastStableDiffusionPaths.get_controlnet_models_path(),
-        "lora": FastStableDiffusionPaths.get_lora_models_path(),
-    }
-    rebuild_pipeline: bool = False
-    use_gguf_model: bool = False
-    gguf_model: Optional[GGUFModel] = GGUFModel()
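Editor's aside: as a pydantic model, a whole generation request is one object that the pipeline code above consumes. A hypothetical image-to-image configuration, for instance:

    from PIL import Image
    from backend.models.lcmdiffusion_setting import (
        DiffusionTask,
        LCMDiffusionSetting,
    )

    settings = LCMDiffusionSetting(
        prompt="a watercolor landscape",
        diffusion_task=DiffusionTask.image_to_image.value,
        init_image=Image.open("input.png"),  # illustrative path
        strength=0.4,
        inference_steps=4,
        number_of_images=2,
        use_seed=True,
        seed=123123,
    )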
         
src/backend/models/upscale.py
DELETED
@@ -1,9 +0,0 @@
-from enum import Enum
-
-
-class UpscaleMode(str, Enum):
-    """Upscale mode types"""
-
-    normal = "normal"
-    sd_upscale = "sd_upscale"
-    aura_sr = "aura_sr"
         
src/backend/openvino/custom_ov_model_vae_decoder.py
DELETED
@@ -1,21 +0,0 @@
-from backend.device import is_openvino_device
-
-if is_openvino_device():
-    from optimum.intel.openvino.modeling_diffusion import OVModelVaeDecoder
-
-
-class CustomOVModelVaeDecoder(OVModelVaeDecoder):
-    def __init__(
-        self,
-        model,
-        parent_model,
-        ov_config=None,
-        model_dir=None,
-    ):
-        super(OVModelVaeDecoder, self).__init__(
-            model,
-            parent_model,
-            ov_config,
-            "vae_decoder",
-            model_dir,
-        )
         
     | 
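The only non-trivial move in the removed class is super(OVModelVaeDecoder, self).__init__(...): passing the parent class explicitly (instead of a bare super()) starts method lookup above OVModelVaeDecoder, so the grandparent's initializer runs directly and "vae_decoder" is pinned as a positional argument. A self-contained sketch of that skip-a-level pattern, using toy classes rather than the optimum-intel API:

class ModelPart:
    def __init__(self, name: str):
        self.name = name


class VaeDecoder(ModelPart):
    def __init__(self):
        super().__init__("generic-part")


class CustomVaeDecoder(VaeDecoder):
    def __init__(self):
        # super(VaeDecoder, self) starts MRO lookup *after* VaeDecoder, so
        # VaeDecoder.__init__ is skipped and ModelPart.__init__ runs instead,
        # mirroring how the deleted class hard-coded the "vae_decoder" name.
        super(VaeDecoder, self).__init__("vae_decoder")


assert CustomVaeDecoder().name == "vae_decoder"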
src/backend/openvino/flux_pipeline.py
DELETED

@@ -1,36 +0,0 @@
-from pathlib import Path
-
-from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO, TAEF1_MODEL_OPENVINO
-from huggingface_hub import snapshot_download
-
-from backend.openvino.ovflux import (
-    TEXT_ENCODER_2_PATH,
-    TEXT_ENCODER_PATH,
-    TRANSFORMER_PATH,
-    VAE_DECODER_PATH,
-    init_pipeline,
-)
-
-
-def get_flux_pipeline(
-    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
-    use_taef1: bool = False,
-    taef1_path: str = TAEF1_MODEL_OPENVINO,
-):
-    model_dir = Path(snapshot_download(model_id))
-    vae_dir = Path(snapshot_download(taef1_path)) if use_taef1 else model_dir
-
-    model_dict = {
-        "transformer": model_dir / TRANSFORMER_PATH,
-        "text_encoder": model_dir / TEXT_ENCODER_PATH,
-        "text_encoder_2": model_dir / TEXT_ENCODER_2_PATH,
-        "vae": vae_dir / VAE_DECODER_PATH,
-    }
-    ov_pipe = init_pipeline(
-        model_dir,
-        model_dict,
-        device=DEVICE.upper(),
-        use_taef1=use_taef1,
-    )
-
-    return ov_pipe
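Before its removal, this helper was the one-call entry point for the OpenVINO FLUX path: it snapshots the model repo, maps the four converted .xml parts, and hands them to init_pipeline from ovflux.py (deleted below). A usage sketch as it would have looked prior to this commit; the output handling assumes the pipeline defaults (return_dict=True yielding a FluxPipelineOutput with an .images list), which is our reading of the deleted ovflux.py rather than a documented contract:

from backend.openvino.flux_pipeline import get_flux_pipeline

# Defaults resolve model_id and taef1_path from constants; use_taef1=True
# swaps in the tiny TAEF1 decoder for a faster, lower-fidelity VAE decode.
pipe = get_flux_pipeline(use_taef1=True)
result = pipe(prompt="a cup of coffee on a wooden desk", num_inference_steps=4)
result.images[0].save("flux_out.png")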
src/backend/openvino/ov_hc_stablediffusion_pipeline.py
DELETED

@@ -1,93 +0,0 @@
-"""This is an experimental pipeline used to test AI PC NPU and GPU"""
-
-from pathlib import Path
-
-from diffusers import EulerDiscreteScheduler, LCMScheduler
-from huggingface_hub import snapshot_download
-from PIL import Image
-from backend.openvino.stable_diffusion_engine import (
-    StableDiffusionEngineAdvanced,
-    LatentConsistencyEngineAdvanced
-)
-
-
-class OvHcStableDiffusion:
-    "OpenVINO heterogeneous compute Stable Diffusion"
-
-    def __init__(
-        self,
-        model_path,
-        device: list = ["GPU", "NPU", "GPU", "GPU"],
-    ):
-        model_dir = Path(snapshot_download(model_path))
-        self.scheduler = EulerDiscreteScheduler(
-            beta_start=0.00085,
-            beta_end=0.012,
-            beta_schedule="scaled_linear",
-        )
-        self.ov_sd_pipleline = StableDiffusionEngineAdvanced(
-            model=model_dir,
-            device=device,
-        )
-
-    def generate(
-        self,
-        prompt: str,
-        neg_prompt: str,
-        init_image: Image = None,
-        strength: float = 1.0,
-    ):
-        image = self.ov_sd_pipleline(
-            prompt=prompt,
-            negative_prompt=neg_prompt,
-            init_image=init_image,
-            strength=strength,
-            num_inference_steps=25,
-            scheduler=self.scheduler,
-        )
-        image_rgb = image[..., ::-1]
-        return Image.fromarray(image_rgb)
-
-
-class OvHcLatentConsistency:
-    """
-    OpenVINO heterogeneous compute latent consistency models
-    For the current Intel Core Ultra, the text encoder and UNet can run on NPU
-    Supports the following: text to image, image to image and image variations
-    """
-
-    def __init__(
-        self,
-        model_path,
-        device: list = ["NPU", "NPU", "GPU"],
-    ):
-
-        model_dir = Path(snapshot_download(model_path))
-
-        self.scheduler = LCMScheduler(
-            beta_start=0.001,
-            beta_end=0.01,
-        )
-        self.ov_sd_pipleline = LatentConsistencyEngineAdvanced(
-            model=model_dir,
-            device=device,
-        )
-
-    def generate(
-        self,
-        prompt: str,
-        neg_prompt: str,
-        init_image: Image = None,
-        num_inference_steps=4,
-        strength: float = 0.5,
-    ):
-        image = self.ov_sd_pipleline(
-            prompt=prompt,
-            init_image=init_image,
-            strength=strength,
-            num_inference_steps=num_inference_steps,
-            scheduler=self.scheduler,
-            seed=None,
-        )
-
-        return image
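The device lists are what made these pipelines "heterogeneous": each entry assigns one pipeline stage to a device, so ["NPU", "NPU", "GPU"] keeps the text encoder and UNet on the NPU and the decode stage on the GPU, consistent with the class docstring (the exact stage order is our reading of the deleted engine code, not documented). A usage sketch from before the removal; the model id is illustrative of the LCM OpenVINO models FastSD CPU publishes, and image.save() assumes the engine returns a PIL image, as the sibling class's post-processing suggests:

from backend.openvino.ov_hc_stablediffusion_pipeline import OvHcLatentConsistency

pipeline = OvHcLatentConsistency(
    model_path="rupeshs/sd15-lcm-square-openvino-int8",  # illustrative repo id
    device=["NPU", "NPU", "GPU"],  # assumed [text encoder, UNet, decode] order
)
image = pipeline.generate(
    prompt="a lighthouse at dusk, oil painting",
    neg_prompt="blurry, low quality",
    num_inference_steps=4,  # LCM needs very few steps
)
image.save("lcm_npu_out.png")  # assumes a PIL return type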
        src/backend/openvino/ovflux.py
    DELETED
    
    | 
         @@ -1,675 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            """Based on  https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/notebooks/flux.1-image-generation/flux_helper.py"""
         
     | 
| 2 | 
         
            -
             
     | 
| 3 | 
         
            -
            import inspect
         
     | 
| 4 | 
         
            -
            import json
         
     | 
| 5 | 
         
            -
            from pathlib import Path
         
     | 
| 6 | 
         
            -
            from typing import Any, Dict, List, Optional, Union
         
     | 
| 7 | 
         
            -
             
     | 
| 8 | 
         
            -
            import numpy as np
         
     | 
| 9 | 
         
            -
            import openvino as ov
         
     | 
| 10 | 
         
            -
            import torch
         
     | 
| 11 | 
         
            -
            from diffusers.image_processor import VaeImageProcessor
         
     | 
| 12 | 
         
            -
            from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
         
     | 
| 13 | 
         
            -
            from diffusers.pipelines.pipeline_utils import DiffusionPipeline
         
     | 
| 14 | 
         
            -
            from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
         
     | 
| 15 | 
         
            -
            from diffusers.utils.torch_utils import randn_tensor
         
     | 
| 16 | 
         
            -
            from transformers import AutoTokenizer
         
     | 
| 17 | 
         
            -
             
     | 
| 18 | 
         
            -
            TRANSFORMER_PATH = Path("transformer/transformer.xml")
         
     | 
| 19 | 
         
            -
            VAE_DECODER_PATH = Path("vae/vae_decoder.xml")
         
     | 
| 20 | 
         
            -
            TEXT_ENCODER_PATH = Path("text_encoder/text_encoder.xml")
         
     | 
| 21 | 
         
            -
            TEXT_ENCODER_2_PATH = Path("text_encoder_2/text_encoder_2.xml")
         
     | 
| 22 | 
         
            -
             
     | 
| 23 | 
         
            -
             
     | 
| 24 | 
         
            -
            def cleanup_torchscript_cache():
         
     | 
| 25 | 
         
            -
                """
         
     | 
| 26 | 
         
            -
                Helper for removing cached model representation
         
     | 
| 27 | 
         
            -
                """
         
     | 
| 28 | 
         
            -
                torch._C._jit_clear_class_registry()
         
     | 
| 29 | 
         
            -
                torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
         
     | 
| 30 | 
         
            -
                torch.jit._state._clear_class_state()
         
     | 
| 31 | 
         
            -
             
     | 
| 32 | 
         
            -
             
     | 
| 33 | 
         
            -
            def _prepare_latent_image_ids(
         
     | 
| 34 | 
         
            -
                batch_size, height, width, device=torch.device("cpu"), dtype=torch.float32
         
     | 
| 35 | 
         
            -
            ):
         
     | 
| 36 | 
         
            -
                latent_image_ids = torch.zeros(height // 2, width // 2, 3)
         
     | 
| 37 | 
         
            -
                latent_image_ids[..., 1] = (
         
     | 
| 38 | 
         
            -
                    latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
         
     | 
| 39 | 
         
            -
                )
         
     | 
| 40 | 
         
            -
                latent_image_ids[..., 2] = (
         
     | 
| 41 | 
         
            -
                    latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
         
     | 
| 42 | 
         
            -
                )
         
     | 
| 43 | 
         
            -
             
     | 
| 44 | 
         
            -
                latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
         
     | 
| 45 | 
         
            -
                    latent_image_ids.shape
         
     | 
| 46 | 
         
            -
                )
         
     | 
| 47 | 
         
            -
             
     | 
| 48 | 
         
            -
                latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         
     | 
| 49 | 
         
            -
                latent_image_ids = latent_image_ids.reshape(
         
     | 
| 50 | 
         
            -
                    batch_size,
         
     | 
| 51 | 
         
            -
                    latent_image_id_height * latent_image_id_width,
         
     | 
| 52 | 
         
            -
                    latent_image_id_channels,
         
     | 
| 53 | 
         
            -
                )
         
     | 
| 54 | 
         
            -
             
     | 
| 55 | 
         
            -
                return latent_image_ids.to(device=device, dtype=dtype)
         
     | 
| 56 | 
         
            -
             
     | 
| 57 | 
         
            -
             
     | 
| 58 | 
         
            -
            def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
         
     | 
| 59 | 
         
            -
                assert dim % 2 == 0, "The dimension must be even."
         
     | 
| 60 | 
         
            -
             
     | 
| 61 | 
         
            -
                scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
         
     | 
| 62 | 
         
            -
                omega = 1.0 / (theta**scale)
         
     | 
| 63 | 
         
            -
             
     | 
| 64 | 
         
            -
                batch_size, seq_length = pos.shape
         
     | 
| 65 | 
         
            -
                out = pos.unsqueeze(-1) * omega.unsqueeze(0).unsqueeze(0)
         
     | 
| 66 | 
         
            -
                cos_out = torch.cos(out)
         
     | 
| 67 | 
         
            -
                sin_out = torch.sin(out)
         
     | 
| 68 | 
         
            -
             
     | 
| 69 | 
         
            -
                stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
         
     | 
| 70 | 
         
            -
                out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
         
     | 
| 71 | 
         
            -
                return out.float()
         
     | 
| 72 | 
         
            -
             
     | 
| 73 | 
         
            -
             
     | 
| 74 | 
         
            -
            def calculate_shift(
         
     | 
| 75 | 
         
            -
                image_seq_len,
         
     | 
| 76 | 
         
            -
                base_seq_len: int = 256,
         
     | 
| 77 | 
         
            -
                max_seq_len: int = 4096,
         
     | 
| 78 | 
         
            -
                base_shift: float = 0.5,
         
     | 
| 79 | 
         
            -
                max_shift: float = 1.16,
         
     | 
| 80 | 
         
            -
            ):
         
     | 
| 81 | 
         
            -
                m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
         
     | 
| 82 | 
         
            -
                b = base_shift - m * base_seq_len
         
     | 
| 83 | 
         
            -
                mu = image_seq_len * m + b
         
     | 
| 84 | 
         
            -
                return mu
         
     | 
| 85 | 
         
            -
             
     | 
| 86 | 
         
            -
             
     | 
| 87 | 
         
            -
            # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
         
     | 
| 88 | 
         
            -
            def retrieve_timesteps(
         
     | 
| 89 | 
         
            -
                scheduler,
         
     | 
| 90 | 
         
            -
                num_inference_steps: Optional[int] = None,
         
     | 
| 91 | 
         
            -
                timesteps: Optional[List[int]] = None,
         
     | 
| 92 | 
         
            -
                sigmas: Optional[List[float]] = None,
         
     | 
| 93 | 
         
            -
                **kwargs,
         
     | 
| 94 | 
         
            -
            ):
         
     | 
| 95 | 
         
            -
                """
         
     | 
| 96 | 
         
            -
                Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
         
     | 
| 97 | 
         
            -
                custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
         
     | 
| 98 | 
         
            -
             
     | 
| 99 | 
         
            -
                Args:
         
     | 
| 100 | 
         
            -
                    scheduler (`SchedulerMixin`):
         
     | 
| 101 | 
         
            -
                        The scheduler to get timesteps from.
         
     | 
| 102 | 
         
            -
                    num_inference_steps (`int`):
         
     | 
| 103 | 
         
            -
                        The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
         
     | 
| 104 | 
         
            -
                        must be `None`.
         
     | 
| 105 | 
         
            -
                    device (`str` or `torch.device`, *optional*):
         
     | 
| 106 | 
         
            -
                        The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         
     | 
| 107 | 
         
            -
                    timesteps (`List[int]`, *optional*):
         
     | 
| 108 | 
         
            -
                        Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
         
     | 
| 109 | 
         
            -
                        `num_inference_steps` and `sigmas` must be `None`.
         
     | 
| 110 | 
         
            -
                    sigmas (`List[float]`, *optional*):
         
     | 
| 111 | 
         
            -
                        Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
         
     | 
| 112 | 
         
            -
                        `num_inference_steps` and `timesteps` must be `None`.
         
     | 
| 113 | 
         
            -
             
     | 
| 114 | 
         
            -
                Returns:
         
     | 
| 115 | 
         
            -
                    `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         
     | 
| 116 | 
         
            -
                    second element is the number of inference steps.
         
     | 
| 117 | 
         
            -
                """
         
     | 
| 118 | 
         
            -
                if timesteps is not None and sigmas is not None:
         
     | 
| 119 | 
         
            -
                    raise ValueError(
         
     | 
| 120 | 
         
            -
                        "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
         
     | 
| 121 | 
         
            -
                    )
         
     | 
| 122 | 
         
            -
                if timesteps is not None:
         
     | 
| 123 | 
         
            -
                    accepts_timesteps = "timesteps" in set(
         
     | 
| 124 | 
         
            -
                        inspect.signature(scheduler.set_timesteps).parameters.keys()
         
     | 
| 125 | 
         
            -
                    )
         
     | 
| 126 | 
         
            -
                    if not accepts_timesteps:
         
     | 
| 127 | 
         
            -
                        raise ValueError(
         
     | 
| 128 | 
         
            -
                            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
         
     | 
| 129 | 
         
            -
                            f" timestep schedules. Please check whether you are using the correct scheduler."
         
     | 
| 130 | 
         
            -
                        )
         
     | 
| 131 | 
         
            -
                    scheduler.set_timesteps(timesteps=timesteps, **kwargs)
         
     | 
| 132 | 
         
            -
                    timesteps = scheduler.timesteps
         
     | 
| 133 | 
         
            -
                    num_inference_steps = len(timesteps)
         
     | 
| 134 | 
         
            -
                elif sigmas is not None:
         
     | 
| 135 | 
         
            -
                    accept_sigmas = "sigmas" in set(
         
     | 
| 136 | 
         
            -
                        inspect.signature(scheduler.set_timesteps).parameters.keys()
         
     | 
| 137 | 
         
            -
                    )
         
     | 
| 138 | 
         
            -
                    if not accept_sigmas:
         
     | 
| 139 | 
         
            -
                        raise ValueError(
         
     | 
| 140 | 
         
            -
                            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
         
     | 
| 141 | 
         
            -
                            f" sigmas schedules. Please check whether you are using the correct scheduler."
         
     | 
| 142 | 
         
            -
                        )
         
     | 
| 143 | 
         
            -
                    scheduler.set_timesteps(sigmas=sigmas, **kwargs)
         
     | 
| 144 | 
         
            -
                    timesteps = scheduler.timesteps
         
     | 
| 145 | 
         
            -
                    num_inference_steps = len(timesteps)
         
     | 
| 146 | 
         
            -
                else:
         
     | 
| 147 | 
         
            -
                    scheduler.set_timesteps(num_inference_steps, **kwargs)
         
     | 
| 148 | 
         
            -
                    timesteps = scheduler.timesteps
         
     | 
| 149 | 
         
            -
                return timesteps, num_inference_steps
         
     | 
| 150 | 
         
            -
             
     | 
| 151 | 
         
            -
             
     | 
| 152 | 
         
            -
            class OVFluxPipeline(DiffusionPipeline):
         
     | 
| 153 | 
         
            -
                def __init__(
         
     | 
| 154 | 
         
            -
                    self,
         
     | 
| 155 | 
         
            -
                    scheduler,
         
     | 
| 156 | 
         
            -
                    transformer,
         
     | 
| 157 | 
         
            -
                    vae,
         
     | 
| 158 | 
         
            -
                    text_encoder,
         
     | 
| 159 | 
         
            -
                    text_encoder_2,
         
     | 
| 160 | 
         
            -
                    tokenizer,
         
     | 
| 161 | 
         
            -
                    tokenizer_2,
         
     | 
| 162 | 
         
            -
                    transformer_config,
         
     | 
| 163 | 
         
            -
                    vae_config,
         
     | 
| 164 | 
         
            -
                ):
         
     | 
| 165 | 
         
            -
                    super().__init__()
         
     | 
| 166 | 
         
            -
             
     | 
| 167 | 
         
            -
                    self.register_modules(
         
     | 
| 168 | 
         
            -
                        vae=vae,
         
     | 
| 169 | 
         
            -
                        text_encoder=text_encoder,
         
     | 
| 170 | 
         
            -
                        text_encoder_2=text_encoder_2,
         
     | 
| 171 | 
         
            -
                        tokenizer=tokenizer,
         
     | 
| 172 | 
         
            -
                        tokenizer_2=tokenizer_2,
         
     | 
| 173 | 
         
            -
                        transformer=transformer,
         
     | 
| 174 | 
         
            -
                        scheduler=scheduler,
         
     | 
| 175 | 
         
            -
                    )
         
     | 
| 176 | 
         
            -
                    self.vae_config = vae_config
         
     | 
| 177 | 
         
            -
                    self.transformer_config = transformer_config
         
     | 
| 178 | 
         
            -
                    self.vae_scale_factor = 2 ** (
         
     | 
| 179 | 
         
            -
                        len(self.vae_config.get("block_out_channels", [0] * 16))
         
     | 
| 180 | 
         
            -
                        if hasattr(self, "vae") and self.vae is not None
         
     | 
| 181 | 
         
            -
                        else 16
         
     | 
| 182 | 
         
            -
                    )
         
     | 
| 183 | 
         
            -
                    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         
     | 
| 184 | 
         
            -
                    self.tokenizer_max_length = (
         
     | 
| 185 | 
         
            -
                        self.tokenizer.model_max_length
         
     | 
| 186 | 
         
            -
                        if hasattr(self, "tokenizer") and self.tokenizer is not None
         
     | 
| 187 | 
         
            -
                        else 77
         
     | 
| 188 | 
         
            -
                    )
         
     | 
| 189 | 
         
            -
                    self.default_sample_size = 64
         
     | 
| 190 | 
         
            -
             
     | 
| 191 | 
         
            -
                def _get_t5_prompt_embeds(
         
     | 
| 192 | 
         
            -
                    self,
         
     | 
| 193 | 
         
            -
                    prompt: Union[str, List[str]] = None,
         
     | 
| 194 | 
         
            -
                    num_images_per_prompt: int = 1,
         
     | 
| 195 | 
         
            -
                    max_sequence_length: int = 512,
         
     | 
| 196 | 
         
            -
                ):
         
     | 
| 197 | 
         
            -
                    prompt = [prompt] if isinstance(prompt, str) else prompt
         
     | 
| 198 | 
         
            -
                    batch_size = len(prompt)
         
     | 
| 199 | 
         
            -
             
     | 
| 200 | 
         
            -
                    text_inputs = self.tokenizer_2(
         
     | 
| 201 | 
         
            -
                        prompt,
         
     | 
| 202 | 
         
            -
                        padding="max_length",
         
     | 
| 203 | 
         
            -
                        max_length=max_sequence_length,
         
     | 
| 204 | 
         
            -
                        truncation=True,
         
     | 
| 205 | 
         
            -
                        return_length=False,
         
     | 
| 206 | 
         
            -
                        return_overflowing_tokens=False,
         
     | 
| 207 | 
         
            -
                        return_tensors="pt",
         
     | 
| 208 | 
         
            -
                    )
         
     | 
| 209 | 
         
            -
                    text_input_ids = text_inputs.input_ids
         
     | 
| 210 | 
         
            -
                    prompt_embeds = torch.from_numpy(self.text_encoder_2(text_input_ids)[0])
         
     | 
| 211 | 
         
            -
             
     | 
| 212 | 
         
            -
                    _, seq_len, _ = prompt_embeds.shape
         
     | 
| 213 | 
         
            -
             
     | 
| 214 | 
         
            -
                    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         
     | 
| 215 | 
         
            -
                    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         
     | 
| 216 | 
         
            -
                    prompt_embeds = prompt_embeds.view(
         
     | 
| 217 | 
         
            -
                        batch_size * num_images_per_prompt, seq_len, -1
         
     | 
| 218 | 
         
            -
                    )
         
     | 
| 219 | 
         
            -
             
     | 
| 220 | 
         
            -
                    return prompt_embeds
         
     | 
| 221 | 
         
            -
             
     | 
| 222 | 
         
            -
                def _get_clip_prompt_embeds(
         
     | 
| 223 | 
         
            -
                    self,
         
     | 
| 224 | 
         
            -
                    prompt: Union[str, List[str]],
         
     | 
| 225 | 
         
            -
                    num_images_per_prompt: int = 1,
         
     | 
| 226 | 
         
            -
                ):
         
     | 
| 227 | 
         
            -
             
     | 
| 228 | 
         
            -
                    prompt = [prompt] if isinstance(prompt, str) else prompt
         
     | 
| 229 | 
         
            -
                    batch_size = len(prompt)
         
     | 
| 230 | 
         
            -
             
     | 
| 231 | 
         
            -
                    text_inputs = self.tokenizer(
         
     | 
| 232 | 
         
            -
                        prompt,
         
     | 
| 233 | 
         
            -
                        padding="max_length",
         
     | 
| 234 | 
         
            -
                        max_length=self.tokenizer_max_length,
         
     | 
| 235 | 
         
            -
                        truncation=True,
         
     | 
| 236 | 
         
            -
                        return_overflowing_tokens=False,
         
     | 
| 237 | 
         
            -
                        return_length=False,
         
     | 
| 238 | 
         
            -
                        return_tensors="pt",
         
     | 
| 239 | 
         
            -
                    )
         
     | 
| 240 | 
         
            -
             
     | 
| 241 | 
         
            -
                    text_input_ids = text_inputs.input_ids
         
     | 
| 242 | 
         
            -
                    prompt_embeds = torch.from_numpy(self.text_encoder(text_input_ids)[1])
         
     | 
| 243 | 
         
            -
             
     | 
| 244 | 
         
            -
                    # duplicate text embeddings for each generation per prompt, using mps friendly method
         
     | 
| 245 | 
         
            -
                    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         
     | 
| 246 | 
         
            -
                    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
         
     | 
| 247 | 
         
            -
             
     | 
| 248 | 
         
            -
                    return prompt_embeds
         
     | 
| 249 | 
         
            -
             
     | 
| 250 | 
         
            -
                def encode_prompt(
         
     | 
| 251 | 
         
            -
                    self,
         
     | 
| 252 | 
         
            -
                    prompt: Union[str, List[str]],
         
     | 
| 253 | 
         
            -
                    prompt_2: Union[str, List[str]],
         
     | 
| 254 | 
         
            -
                    num_images_per_prompt: int = 1,
         
     | 
| 255 | 
         
            -
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         
     | 
| 256 | 
         
            -
                    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         
     | 
| 257 | 
         
            -
                    max_sequence_length: int = 512,
         
     | 
| 258 | 
         
            -
                ):
         
     | 
| 259 | 
         
            -
                    r"""
         
     | 
| 260 | 
         
            -
             
     | 
| 261 | 
         
            -
                    Args:
         
     | 
| 262 | 
         
            -
                        prompt (`str` or `List[str]`, *optional*):
         
     | 
| 263 | 
         
            -
                            prompt to be encoded
         
     | 
| 264 | 
         
            -
                        prompt_2 (`str` or `List[str]`, *optional*):
         
     | 
| 265 | 
         
            -
                            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
         
     | 
| 266 | 
         
            -
                            used in all text-encoders
         
     | 
| 267 | 
         
            -
                        num_images_per_prompt (`int`):
         
     | 
| 268 | 
         
            -
                            number of images that should be generated per prompt
         
     | 
| 269 | 
         
            -
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         
     | 
| 270 | 
         
            -
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         
     | 
| 271 | 
         
            -
                            provided, text embeddings will be generated from `prompt` input argument.
         
     | 
| 272 | 
         
            -
                        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
         
     | 
| 273 | 
         
            -
                            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
         
     | 
| 274 | 
         
            -
                            If not provided, pooled text embeddings will be generated from `prompt` input argument.
         
     | 
| 275 | 
         
            -
                        lora_scale (`float`, *optional*):
         
     | 
| 276 | 
         
            -
                            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         
     | 
| 277 | 
         
            -
                    """
         
     | 
| 278 | 
         
            -
             
     | 
| 279 | 
         
            -
                    prompt = [prompt] if isinstance(prompt, str) else prompt
         
     | 
| 280 | 
         
            -
                    if prompt is not None:
         
     | 
| 281 | 
         
            -
                        batch_size = len(prompt)
         
     | 
| 282 | 
         
            -
                    else:
         
     | 
| 283 | 
         
            -
                        batch_size = prompt_embeds.shape[0]
         
     | 
| 284 | 
         
            -
             
     | 
| 285 | 
         
            -
                    if prompt_embeds is None:
         
     | 
| 286 | 
         
            -
                        prompt_2 = prompt_2 or prompt
         
     | 
| 287 | 
         
            -
                        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
         
     | 
| 288 | 
         
            -
             
     | 
| 289 | 
         
            -
                        # We only use the pooled prompt output from the CLIPTextModel
         
     | 
| 290 | 
         
            -
                        pooled_prompt_embeds = self._get_clip_prompt_embeds(
         
     | 
| 291 | 
         
            -
                            prompt=prompt,
         
     | 
| 292 | 
         
            -
                            num_images_per_prompt=num_images_per_prompt,
         
     | 
| 293 | 
         
            -
                        )
         
     | 
| 294 | 
         
            -
                        prompt_embeds = self._get_t5_prompt_embeds(
         
     | 
| 295 | 
         
            -
                            prompt=prompt_2,
         
     | 
| 296 | 
         
            -
                            num_images_per_prompt=num_images_per_prompt,
         
     | 
| 297 | 
         
            -
                            max_sequence_length=max_sequence_length,
         
     | 
| 298 | 
         
            -
                        )
         
     | 
| 299 | 
         
            -
                    text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3)
         
     | 
| 300 | 
         
            -
                    text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
         
     | 
| 301 | 
         
            -
             
     | 
| 302 | 
         
            -
                    return prompt_embeds, pooled_prompt_embeds, text_ids
         
     | 
| 303 | 
         
            -
             
     | 
| 304 | 
         
            -
                def check_inputs(
         
     | 
| 305 | 
         
            -
                    self,
         
     | 
| 306 | 
         
            -
                    prompt,
         
     | 
| 307 | 
         
            -
                    prompt_2,
         
     | 
| 308 | 
         
            -
                    height,
         
     | 
| 309 | 
         
            -
                    width,
         
     | 
| 310 | 
         
            -
                    prompt_embeds=None,
         
     | 
| 311 | 
         
            -
                    pooled_prompt_embeds=None,
         
     | 
| 312 | 
         
            -
                    max_sequence_length=None,
         
     | 
| 313 | 
         
            -
                ):
         
     | 
| 314 | 
         
            -
                    if height % 8 != 0 or width % 8 != 0:
         
     | 
| 315 | 
         
            -
                        raise ValueError(
         
     | 
| 316 | 
         
            -
                            f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
         
     | 
| 317 | 
         
            -
                        )
         
     | 
| 318 | 
         
            -
             
     | 
| 319 | 
         
            -
                    if prompt is not None and prompt_embeds is not None:
         
     | 
| 320 | 
         
            -
                        raise ValueError(
         
     | 
| 321 | 
         
            -
                            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
         
     | 
| 322 | 
         
            -
                            " only forward one of the two."
         
     | 
| 323 | 
         
            -
                        )
         
     | 
| 324 | 
         
            -
                    elif prompt_2 is not None and prompt_embeds is not None:
         
     | 
| 325 | 
         
            -
                        raise ValueError(
         
     | 
| 326 | 
         
            -
                            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
         
     | 
| 327 | 
         
            -
                            " only forward one of the two."
         
     | 
| 328 | 
         
            -
                        )
         
     | 
| 329 | 
         
            -
                    elif prompt is None and prompt_embeds is None:
         
     | 
| 330 | 
         
            -
                        raise ValueError(
         
     | 
| 331 | 
         
            -
                            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
         
     | 
| 332 | 
         
            -
                        )
         
     | 
| 333 | 
         
            -
                    elif prompt is not None and (
         
     | 
| 334 | 
         
            -
                        not isinstance(prompt, str) and not isinstance(prompt, list)
         
     | 
| 335 | 
         
            -
                    ):
         
     | 
| 336 | 
         
            -
                        raise ValueError(
         
     | 
| 337 | 
         
            -
                            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
         
     | 
| 338 | 
         
            -
                        )
         
     | 
| 339 | 
         
            -
                    elif prompt_2 is not None and (
         
     | 
| 340 | 
         
            -
                        not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
         
     | 
| 341 | 
         
            -
                    ):
         
     | 
| 342 | 
         
            -
                        raise ValueError(
         
     | 
| 343 | 
         
            -
                            f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
         
     | 
| 344 | 
         
            -
                        )
         
     | 
| 345 | 
         
            -
             
     | 
| 346 | 
         
            -
                    if prompt_embeds is not None and pooled_prompt_embeds is None:
         
     | 
| 347 | 
         
            -
                        raise ValueError(
         
     | 
| 348 | 
         
            -
                            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
         
     | 
| 349 | 
         
            -
                        )
         
     | 
| 350 | 
         
            -
             
     | 
| 351 | 
         
            -
                    if max_sequence_length is not None and max_sequence_length > 512:
         
     | 
| 352 | 
         
            -
                        raise ValueError(
         
     | 
| 353 | 
         
            -
                            f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
         
     | 
| 354 | 
         
            -
                        )
         
     | 
| 355 | 
         
            -
             
     | 
| 356 | 
         
            -
                @staticmethod
         
     | 
| 357 | 
         
            -
                def _prepare_latent_image_ids(batch_size, height, width):
         
     | 
| 358 | 
         
            -
                    return _prepare_latent_image_ids(batch_size, height, width)
         
     | 
| 359 | 
         
            -
             
     | 
| 360 | 
         
            -
                @staticmethod
         
     | 
| 361 | 
         
            -
                def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         
     | 
| 362 | 
         
            -
                    latents = latents.view(
         
     | 
| 363 | 
         
            -
                        batch_size, num_channels_latents, height // 2, 2, width // 2, 2
         
     | 
| 364 | 
         
            -
                    )
         
     | 
| 365 | 
         
            -
                    latents = latents.permute(0, 2, 4, 1, 3, 5)
         
     | 
| 366 | 
         
            -
                    latents = latents.reshape(
         
     | 
| 367 | 
         
            -
                        batch_size, (height // 2) * (width // 2), num_channels_latents * 4
         
     | 
| 368 | 
         
            -
                    )
         
     | 
| 369 | 
         
            -
             
     | 
| 370 | 
         
            -
                    return latents
         
     | 
| 371 | 
         
            -
             
     | 
| 372 | 
         
            -
                @staticmethod
         
     | 
| 373 | 
         
            -
                def _unpack_latents(latents, height, width, vae_scale_factor):
         
     | 
| 374 | 
         
            -
                    batch_size, num_patches, channels = latents.shape
         
     | 
| 375 | 
         
            -
             
     | 
| 376 | 
         
            -
                    height = height // vae_scale_factor
         
     | 
| 377 | 
         
            -
                    width = width // vae_scale_factor
         
     | 
| 378 | 
         
            -
             
     | 
| 379 | 
         
            -
                    latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
         
     | 
| 380 | 
         
            -
                    latents = latents.permute(0, 3, 1, 4, 2, 5)
         
     | 
| 381 | 
         
            -
             
     | 
| 382 | 
         
            -
                    latents = latents.reshape(
         
     | 
| 383 | 
         
            -
                        batch_size, channels // (2 * 2), height * 2, width * 2
         
     | 
| 384 | 
         
            -
                    )
         
     | 
| 385 | 
         
            -
             
     | 
| 386 | 
         
            -
                    return latents
         
     | 
| 387 | 
         
            -
             
     | 
| 388 | 
         
            -
                def prepare_latents(
         
     | 
| 389 | 
         
            -
                    self,
         
     | 
| 390 | 
         
            -
                    batch_size,
         
     | 
| 391 | 
         
            -
                    num_channels_latents,
         
     | 
| 392 | 
         
            -
                    height,
         
     | 
| 393 | 
         
            -
                    width,
         
     | 
| 394 | 
         
            -
                    generator,
         
     | 
| 395 | 
         
            -
                    latents=None,
         
     | 
| 396 | 
         
            -
                ):
         
     | 
| 397 | 
         
            -
                    height = 2 * (int(height) // self.vae_scale_factor)
         
     | 
| 398 | 
         
            -
                    width = 2 * (int(width) // self.vae_scale_factor)
         
     | 
| 399 | 
         
            -
             
     | 
| 400 | 
         
            -
                    shape = (batch_size, num_channels_latents, height, width)
         
     | 
| 401 | 
         
            -
             
     | 
| 402 | 
         
            -
                    if latents is not None:
         
     | 
| 403 | 
         
            -
                        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
         
     | 
| 404 | 
         
            -
                        return latents, latent_image_ids
         
     | 
| 405 | 
         
            -
             
     | 
| 406 | 
         
            -
                    if isinstance(generator, list) and len(generator) != batch_size:
         
     | 
| 407 | 
         
            -
                        raise ValueError(
         
     | 
| 408 | 
         
            -
                            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
         
     | 
| 409 | 
         
            -
                            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         
     | 
| 410 | 
         
            -
                        )
         
     | 
| 411 | 
         
            -
             
     | 
| 412 | 
         
            -
                    latents = randn_tensor(shape, generator=generator)
         
     | 
| 413 | 
         
            -
                    latents = self._pack_latents(
         
     | 
| 414 | 
         
            -
                        latents, batch_size, num_channels_latents, height, width
         
     | 
| 415 | 
         
            -
                    )
         
     | 
| 416 | 
         
            -
             
     | 
| 417 | 
         
            -
                    latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width)
         
     | 
| 418 | 
         
            -
             
     | 
| 419 | 
         
            -
                    return latents, latent_image_ids
         
     | 
| 420 | 
         
            -
             
     | 
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        negative_prompt: str = None,
        num_inference_steps: int = 28,
        timesteps: List[int] = None,
        guidance_scale: float = 7.0,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        max_sequence_length: int = 512,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                will be used instead.
            height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 28):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        """

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode the prompt
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
        )

        # 4. Prepare latent variables
        num_channels_latents = self.transformer_config.get("in_channels", 64) // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            generator,
            latents,
        )

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.base_image_seq_len,
            self.scheduler.config.max_image_seq_len,
            self.scheduler.config.base_shift,
            self.scheduler.config.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler=self.scheduler,
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                # handle guidance
                if self.transformer_config.get("guidance_embeds"):
                    guidance = torch.tensor([guidance_scale])
                    guidance = guidance.expand(latents.shape[0])
                else:
                    guidance = None

                transformer_input = {
                    "hidden_states": latents,
                    "timestep": timestep / 1000,
                    "pooled_projections": pooled_prompt_embeds,
                    "encoder_hidden_states": prompt_embeds,
                    "txt_ids": text_ids,
                    "img_ids": latent_image_ids,
                }
                if guidance is not None:
                    transformer_input["guidance"] = guidance

                noise_pred = torch.from_numpy(self.transformer(transformer_input)[0])

                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                # update the progress bar
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

        if output_type == "latent":
            image = latents

        else:
            latents = self._unpack_latents(
                latents, height, width, self.vae_scale_factor
            )
            latents = latents / self.vae_config.get(
                "scaling_factor"
            ) + self.vae_config.get("shift_factor")
            image = self.vae(latents)[0]
            image = self.image_processor.postprocess(
                torch.from_numpy(image), output_type=output_type
            )

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)

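# Note: `calculate_shift` and `retrieve_timesteps` used above are defined
# earlier in this file and appear to mirror the standard diffusers FLUX
# helpers. The shift `mu` is a linear interpolation in the packed-latent
# sequence length:
#
#     m  = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
#     b  = base_shift - m * base_image_seq_len
#     mu = m * image_seq_len + b
#
# so larger images (longer packed sequences) get a stronger timestep shift,
# and the sigma grid np.linspace(1.0, 1/num_inference_steps, num_inference_steps)
# is then warped by the FlowMatch scheduler using this mu.
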
def init_pipeline(
    model_dir,
    models_dict: Dict[str, Any],
    device: str,
    use_taef1: bool = False,
):
    pipeline_args = {}

    print("OpenVINO FLUX Model compilation")
    core = ov.Core()
    for model_name, model_path in models_dict.items():
        pipeline_args[model_name] = core.compile_model(model_path, device)
        if model_name == "vae" and use_taef1:
            print("✅ VAE(TAEF1) - Done!")
        else:
            print(f"✅ {model_name} - Done!")

    transformer_path = models_dict["transformer"]
    transformer_config_path = transformer_path.parent / "config.json"
    with transformer_config_path.open("r") as f:
        transformer_config = json.load(f)
    vae_path = models_dict["vae"]
    vae_config_path = vae_path.parent / "config.json"
    with vae_config_path.open("r") as f:
        vae_config = json.load(f)

    pipeline_args["vae_config"] = vae_config
    pipeline_args["transformer_config"] = transformer_config

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_dir / "scheduler")

    tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer")
    tokenizer_2 = AutoTokenizer.from_pretrained(model_dir / "tokenizer_2")

    pipeline_args["scheduler"] = scheduler
    pipeline_args["tokenizer"] = tokenizer
    pipeline_args["tokenizer_2"] = tokenizer_2
    ov_pipe = OVFluxPipeline(**pipeline_args)
    return ov_pipe
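For orientation, a minimal sketch of how this deleted pipeline was likely driven. The directory layout (scheduler/, tokenizer/, tokenizer_2/ subfolders plus per-model OpenVINO IR .xml files next to their config.json) is an assumption inferred from the loading code above; the concrete paths, the exact set of models_dict entries, and the prompt are illustrative only:

    from pathlib import Path

    model_dir = Path("models/flux-openvino")  # hypothetical export location
    models_dict = {
        "transformer": model_dir / "transformer" / "openvino_model.xml",
        "vae": model_dir / "vae" / "openvino_model.xml",
        "text_encoder": model_dir / "text_encoder" / "openvino_model.xml",
        "text_encoder_2": model_dir / "text_encoder_2" / "openvino_model.xml",
    }

    pipe = init_pipeline(model_dir, models_dict, device="CPU")
    output = pipe(
        prompt="a lighthouse at dusk",
        height=512,
        width=512,
        num_inference_steps=4,   # schnell-style low step count
        guidance_scale=0.0,
    )
    output.images[0].save("out.png")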
src/backend/openvino/pipelines.py
DELETED
@@ -1,75 +0,0 @@
from constants import DEVICE, LCM_DEFAULT_MODEL_OPENVINO
from backend.tiny_decoder import get_tiny_decoder_vae_model
from typing import Any
from backend.device import is_openvino_device
from paths import get_base_folder_name

if is_openvino_device():
    from huggingface_hub import snapshot_download
    from optimum.intel.openvino.modeling_diffusion import OVBaseModel

    from optimum.intel.openvino.modeling_diffusion import (
        OVStableDiffusionPipeline,
        OVStableDiffusionImg2ImgPipeline,
        OVStableDiffusionXLPipeline,
        OVStableDiffusionXLImg2ImgPipeline,
    )
    from backend.openvino.custom_ov_model_vae_decoder import CustomOVModelVaeDecoder


def ov_load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
):
    taesd_dir = snapshot_download(
        repo_id=get_tiny_decoder_vae_model(pipeline.__class__.__name__),
        local_files_only=use_local_model,
    )
    pipeline.vae_decoder = CustomOVModelVaeDecoder(
        model=OVBaseModel.load_model(f"{taesd_dir}/vae_decoder/openvino_model.xml"),
        parent_model=pipeline,
        model_dir=taesd_dir,
    )


def get_ov_text_to_image_pipeline(
    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
    use_local_model: bool = False,
) -> Any:
    if "xl" in get_base_folder_name(model_id).lower():
        pipeline = OVStableDiffusionXLPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    else:
        pipeline = OVStableDiffusionPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )

    return pipeline


def get_ov_image_to_image_pipeline(
    model_id: str = LCM_DEFAULT_MODEL_OPENVINO,
    use_local_model: bool = False,
) -> Any:
    if "xl" in get_base_folder_name(model_id).lower():
        pipeline = OVStableDiffusionXLImg2ImgPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    else:
        pipeline = OVStableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            ov_config={"CACHE_DIR": ""},
            device=DEVICE.upper(),
        )
    return pipeline
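A minimal usage sketch for these factory helpers (the model id falls back to LCM_DEFAULT_MODEL_OPENVINO from constants; swapping in the tiny TAESD decoder is optional, and the prompt is illustrative):

    pipeline = get_ov_text_to_image_pipeline()   # default LCM OpenVINO model
    ov_load_taesd(pipeline)                      # optional: faster tiny VAE decoder
    output = pipeline(
        prompt="a cat wearing sunglasses",
        num_inference_steps=4,
        guidance_scale=1.0,
    )
    output.images[0].save("out.png")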
src/backend/openvino/stable_diffusion_engine.py
DELETED
@@ -1,1817 +0,0 @@
"""
Copyright(C) 2022-2023 Intel Corporation
SPDX-License-Identifier: Apache-2.0
"""
import inspect
from typing import Union, Optional, Any, List, Dict
import numpy as np
# openvino
from openvino.runtime import Core
# tokenizer
from transformers import CLIPTokenizer
import torch
import random

from diffusers import DiffusionPipeline
from diffusers.schedulers import (DDIMScheduler,
                                  LMSDiscreteScheduler,
                                  PNDMScheduler,
                                  EulerDiscreteScheduler,
                                  EulerAncestralDiscreteScheduler)


from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import PIL_INTERPOLATION

import cv2
import os
import sys

# for multithreading
import concurrent.futures

# for GIF
import PIL
from PIL import Image
import glob
import json
import time


def scale_fit_to_window(dst_width: int, dst_height: int, image_width: int, image_height: int):
    """
    Preprocessing helper for calculating the image size for a resize that preserves
    the original aspect ratio while fitting the image into a specific window size.

    Parameters:
      dst_width (int): destination window width
      dst_height (int): destination window height
      image_width (int): source image width
      image_height (int): source image height
    Returns:
      result_width (int): calculated width for resize
      result_height (int): calculated height for resize
    """
    im_scale = min(dst_height / image_height, dst_width / image_width)
    return int(im_scale * image_width), int(im_scale * image_height)


def preprocess(image: PIL.Image.Image, ht=512, wt=512):
    """
    Image preprocessing function. Takes an image in PIL.Image format, resizes it to
    preserve the aspect ratio and fit the model input window (512x512 by default),
    converts it to np.ndarray, pads the right or bottom side with zeros (depending on
    the aspect ratio), converts the data to float32 and rescales values from [0, 255]
    to [-1, 1], and finally converts the data layout from NHWC to NCHW.
    The function returns the preprocessed input tensor and the padding size, which can
    be used in postprocessing.

    Parameters:
      image (PIL.Image.Image): input image
    Returns:
       image (np.ndarray): preprocessed image tensor
       meta (Dict): dictionary with preprocessing metadata info
    """

    src_width, src_height = image.size
    image = image.convert('RGB')
    dst_width, dst_height = scale_fit_to_window(
        wt, ht, src_width, src_height)
    image = np.array(image.resize((dst_width, dst_height),
                     resample=PIL.Image.Resampling.LANCZOS))[None, :]

    pad_width = wt - dst_width
    pad_height = ht - dst_height
    pad = ((0, 0), (0, pad_height), (0, pad_width), (0, 0))
    image = np.pad(image, pad, mode="constant")
    image = image.astype(np.float32) / 255.0
    image = 2.0 * image - 1.0
    image = image.transpose(0, 3, 1, 2)

    return image, {"padding": pad, "src_width": src_width, "src_height": src_height}

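# A worked example of the contract above, assuming a 600x400 RGB input and the
# default 512x512 window:
#
#     scale_fit_to_window(512, 512, 600, 400)  ->  (512, 341)
#     image, meta = preprocess(pil_image)
#     image.shape      ->  (1, 3, 512, 512)    # NCHW, float32 in [-1, 1]
#     meta["padding"]  ->  ((0, 0), (0, 171), (0, 0), (0, 0))   # bottom zero-pad
#     meta["src_width"], meta["src_height"]  ->  600, 400
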
def try_enable_npu_turbo(device, core):
    import platform
    if "windows" in platform.system().lower():
        if "NPU" in device and "3720" not in core.get_property('NPU', 'DEVICE_ARCHITECTURE'):
            try:
                core.set_property(properties={'NPU_TURBO': 'YES'}, device_name='NPU')
            except Exception:
                print(f"Failed loading NPU_TURBO for device {device}. Skipping...")
            else:
                print_npu_turbo_art()
        else:
            print(f"Skipping NPU_TURBO for device {device}")
    elif "linux" in platform.system().lower():
        if os.path.isfile('/sys/module/intel_vpu/parameters/test_mode'):
            with open('/sys/module/intel_vpu/version', 'r') as f:
                version = f.readline().split()[0]
                if tuple(map(int, version.split('.'))) < tuple(map(int, '1.9.0'.split('.'))):
                    print(f"The driver intel_vpu-1.9.0 (or later) needs to be loaded for NPU Turbo (currently {version}). Skipping...")
                else:
                    with open('/sys/module/intel_vpu/parameters/test_mode', 'r') as tm_file:
                        test_mode = int(tm_file.readline().split()[0])
                        if test_mode == 512:
                            print_npu_turbo_art()
                        else:
                            print("The driver >=intel_vpu-1.9.0 must be loaded with "
                                  "\"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO "
                                  f"(currently test_mode={test_mode}). Skipping...")
        else:
            print("The driver >=intel_vpu-1.9.0 must be loaded with \"modprobe intel_vpu test_mode=512\" to enable NPU_TURBO. Skipping...")
    else:
        print(f"This platform ({platform.system()}) does not support NPU Turbo")

| 122 | 
         
            -
            def result(var):
         
     | 
| 123 | 
         
            -
                return next(iter(var.values()))
         
     | 
| 124 | 
         
            -
             
     | 
| 125 | 
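result() simply pulls the first output out of an OpenVINO results dictionary, whatever the output tensor happens to be named. A tiny illustration (the key name is made up):

import numpy as np

outputs = {"out0": np.zeros((1, 4, 64, 64), dtype=np.float32)}  # hypothetical output map
first = next(iter(outputs.values()))                            # what result(outputs) returns
print(first.shape)                                              # (1, 4, 64, 64)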
         
            -
            class StableDiffusionEngineAdvanced(DiffusionPipeline):
         
     | 
| 126 | 
         
            -
                def __init__(self, model="runwayml/stable-diffusion-v1-5", 
         
     | 
| 127 | 
         
            -
                              tokenizer="openai/clip-vit-large-patch14", 
         
     | 
| 128 | 
         
            -
                              device=["CPU", "CPU", "CPU", "CPU"]):
         
     | 
| 129 | 
         
            -
                    try:
         
     | 
| 130 | 
         
            -
                        self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
         
     | 
| 131 | 
         
            -
                    except:
         
     | 
| 132 | 
         
            -
                        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
         
     | 
| 133 | 
         
            -
                        self.tokenizer.save_pretrained(model)
         
     | 
| 134 | 
         
            -
             
     | 
| 135 | 
         
            -
                    self.core = Core()
         
     | 
| 136 | 
         
            -
                    self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})
         
     | 
| 137 | 
         
            -
                    try_enable_npu_turbo(device, self.core)
         
     | 
| 138 | 
         
            -
                        
         
     | 
| 139 | 
         
            -
                    print("Loading models... ")
         
     | 
| 140 | 
         
            -
                    
         
     | 
| 141 | 
         
            -
             
     | 
| 142 | 
         
            -
             
     | 
| 143 | 
         
            -
                    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
         
     | 
| 144 | 
         
            -
                        futures = {
         
     | 
| 145 | 
         
            -
                            "unet_time_proj": executor.submit(self.core.compile_model, os.path.join(model, "unet_time_proj.xml"), device[0]),
         
     | 
| 146 | 
         
            -
                            "text": executor.submit(self.load_model, model, "text_encoder", device[0]),
         
     | 
| 147 | 
         
            -
                            "unet": executor.submit(self.load_model, model, "unet_int8", device[1]),
         
     | 
| 148 | 
         
            -
                            "unet_neg": executor.submit(self.load_model, model, "unet_int8", device[2]) if device[1] != device[2] else None,
         
     | 
| 149 | 
         
            -
                            "vae_decoder": executor.submit(self.load_model, model, "vae_decoder", device[3]),
         
     | 
| 150 | 
         
            -
                            "vae_encoder": executor.submit(self.load_model, model, "vae_encoder", device[3])
         
     | 
| 151 | 
         
            -
                        }
         
     | 
| 152 | 
         
            -
             
     | 
| 153 | 
         
            -
                    self.unet_time_proj = futures["unet_time_proj"].result()
         
     | 
| 154 | 
         
            -
                    self.text_encoder = futures["text"].result()
         
     | 
| 155 | 
         
            -
                    self.unet = futures["unet"].result()
         
     | 
| 156 | 
         
            -
                    self.unet_neg = futures["unet_neg"].result() if futures["unet_neg"] else self.unet
         
     | 
| 157 | 
         
            -
                    self.vae_decoder = futures["vae_decoder"].result()
         
     | 
| 158 | 
         
            -
                    self.vae_encoder = futures["vae_encoder"].result()
         
     | 
| 159 | 
         
            -
                    print("Text Device:", device[0])
         
     | 
| 160 | 
         
            -
                    print("unet Device:", device[1])
         
     | 
| 161 | 
         
            -
                    print("unet-neg Device:", device[2])
         
     | 
| 162 | 
         
            -
                    print("VAE Device:", device[3])
         
     | 
| 163 | 
         
            -
             
     | 
| 164 | 
         
            -
                    self._text_encoder_output = self.text_encoder.output(0)
         
     | 
| 165 | 
         
            -
                    self._vae_d_output = self.vae_decoder.output(0)
         
     | 
| 166 | 
         
            -
                    self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
         
     | 
| 167 | 
         
            -
             
     | 
| 168 | 
         
            -
                    self.set_dimensions()
         
     | 
| 169 | 
         
            -
                    self.infer_request_neg = self.unet_neg.create_infer_request()
         
     | 
| 170 | 
         
            -
                    self.infer_request = self.unet.create_infer_request()
         
     | 
| 171 | 
         
            -
                    self.infer_request_time_proj = self.unet_time_proj.create_infer_request()
         
     | 
| 172 | 
         
            -
                    self.time_proj_constants = np.load(os.path.join(model, "time_proj_constants.npy"))
         
     | 
| 173 | 
         
            -
                    
         
     | 
| 174 | 
         
            -
                def load_model(self, model, model_name, device):
         
     | 
| 175 | 
         
            -
                    if "NPU" in device:
         
     | 
| 176 | 
         
            -
                        with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
         
     | 
| 177 | 
         
            -
                            return self.core.import_model(f.read(), device)
         
     | 
| 178 | 
         
            -
                    return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
         
     | 
| 179 | 
         
            -
                
         
     | 
| 180 | 
         
            -
                def set_dimensions(self):
         
     | 
| 181 | 
         
            -
                    latent_shape = self.unet.input("latent_model_input").shape
         
     | 
| 182 | 
         
            -
                    if latent_shape[1] == 4:
         
     | 
| 183 | 
         
            -
                        self.height = latent_shape[2] * 8
         
     | 
| 184 | 
         
            -
                        self.width = latent_shape[3] * 8
         
     | 
| 185 | 
         
            -
                    else:
         
     | 
| 186 | 
         
            -
                        self.height = latent_shape[1] * 8
         
     | 
| 187 | 
         
            -
                        self.width = latent_shape[2] * 8
         
     | 
| 188 | 
         
            -
             
     | 
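The branch distinguishes NCHW from NHWC latent layouts; either way the pixel resolution is 8x the latent resolution, the VAE's spatial downscale factor. For instance (a made-up shape):

latent_shape = (1, 4, 64, 64)   # NCHW: channel dim == 4 sits at index 1
height, width = latent_shape[2] * 8, latent_shape[3] * 8
print(height, width)            # 512 512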
    def __call__(
            self,
            prompt,
            init_image=None,
            negative_prompt=None,
            scheduler=None,
            strength=0.5,
            num_inference_steps=32,
            guidance_scale=7.5,
            eta=0.0,
            create_gif=False,
            model=None,
            callback=None,
            callback_userdata=None
    ):

        # extract condition
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]

        # do classifier free guidance
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:

            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt

            tokens_uncond = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,  # truncation=True,
                return_tensors="np"
            )
            uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
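Both the conditional and the unconditional (empty or negative) prompt are embedded once, then stacked so the two UNet passes can index one tensor: row 0 is the unconditional embedding, row 1 the conditional one. A rough standalone sketch of the tokenization side, using the real transformers tokenizer (requires a model download; the encoder call itself is omitted since it is an OpenVINO model here):

from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ids_cond = tok(["a photo of a cat"], padding="max_length",
               max_length=tok.model_max_length, truncation=True,
               return_tensors="np").input_ids        # shape (1, 77)
ids_uncond = tok([""], padding="max_length",
                 max_length=tok.model_max_length,
                 return_tensors="np").input_ids      # shape (1, 77)
# Each batch goes through the text encoder; the two (1, 77, 768) embeddings are
# then concatenated to (2, 77, 768): row 0 unconditional, row 1 conditional.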
        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}

        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
        latent_timestep = timesteps[:1]

        # get the initial random noise unless the user supplied it
        latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler)
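The signature introspection keeps the call compatible across diffusers scheduler versions: optional keywords are passed only when the scheduler actually declares them. A minimal illustration of the pattern, with a stand-in function instead of a real scheduler:

import inspect

def set_timesteps(num_steps, offset=0):    # stand-in for a scheduler method
    pass

extra = {}
if "offset" in inspect.signature(set_timesteps).parameters:
    extra["offset"] = 1                    # only passed when supported
set_timesteps(32, **extra)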

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        if create_gif:
            frames = []

        for i, t in enumerate(self.progress_bar(timesteps)):
            if callback:
                callback(i, callback_userdata)

            # expand the latents if we are doing classifier free guidance
            noise_pred = []
            latent_model_input = latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            latent_model_input_neg = latent_model_input
            # NHWC models need the channel axis moved last; permute works on torch
            # tensors, transpose on numpy arrays.
            if self.unet.input("latent_model_input").shape[1] != 4:
                try:
                    latent_model_input = latent_model_input.permute(0, 2, 3, 1)
                except AttributeError:
                    latent_model_input = latent_model_input.transpose(0, 2, 3, 1)

            if self.unet_neg.input("latent_model_input").shape[1] != 4:
                try:
                    latent_model_input_neg = latent_model_input_neg.permute(0, 2, 3, 1)
                except AttributeError:
                    latent_model_input_neg = latent_model_input_neg.transpose(0, 2, 3, 1)
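The permute/transpose fallback covers both tensor types the loop may see at this point. A standalone sketch:

import numpy as np

x = np.zeros((1, 4, 64, 64), dtype=np.float32)  # NCHW latent
try:
    x = x.permute(0, 2, 3, 1)       # torch.Tensor path
except AttributeError:
    x = x.transpose(0, 2, 3, 1)     # numpy path (ndarray has no .permute)
print(x.shape)                      # (1, 64, 64, 4) -> NHWC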

            # Timestep embedding: the sinusoidal projection is computed outside the
            # UNet from precomputed frequency constants. (The fp16 variants below are
            # computed but unused; the fp32 values feed the time_proj model.)
            time_proj_constants_fp16 = np.float16(self.time_proj_constants)
            t_scaled_fp16 = time_proj_constants_fp16 * np.float16(t)
            cosine_t_fp16 = np.cos(t_scaled_fp16)
            sine_t_fp16 = np.sin(t_scaled_fp16)

            t_scaled = self.time_proj_constants * np.float32(t)

            cosine_t = np.cos(t_scaled)
            sine_t = np.sin(t_scaled)

            time_proj_dict = {"sine_t": np.float32(sine_t), "cosine_t": np.float32(cosine_t)}
            self.infer_request_time_proj.start_async(time_proj_dict)
            self.infer_request_time_proj.wait()
            time_proj = self.infer_request_time_proj.get_output_tensor(0).data.astype(np.float32)

            input_tens_neg_dict = {"time_proj": np.float32(time_proj), "latent_model_input": latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0)}
            input_tens_dict = {"time_proj": np.float32(time_proj), "latent_model_input": latent_model_input, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0)}

            # Run the negative- and positive-prompt UNet passes concurrently on their
            # respective devices, then collect both noise predictions.
            self.infer_request_neg.start_async(input_tens_neg_dict)
            self.infer_request.start_async(input_tens_dict)
            self.infer_request_neg.wait()
            self.infer_request.wait()

            noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
            noise_pred_pos = self.infer_request.get_output_tensor(0)

            noise_pred.append(noise_pred_neg.data.astype(np.float32))
            noise_pred.append(noise_pred_pos.data.astype(np.float32))
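The cos/sin pair is the standard sinusoidal timestep embedding, with the per-frequency constants loaded from time_proj_constants.npy. A rough numpy sketch of what such constants typically look like (the half_dim value and the frequency formula are assumptions, not read from the deleted file):

import numpy as np

half_dim = 160                                   # assumed embedding half-width
freqs = np.exp(-np.log(10000) * np.arange(half_dim) / half_dim).astype(np.float32)
t = 981                                          # one scheduler timestep
t_scaled = freqs * np.float32(t)
embedding = np.concatenate([np.cos(t_scaled), np.sin(t_scaled)])  # shape (320,)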

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
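Classifier-free guidance extrapolates from the unconditional prediction toward the conditional one: eps = eps_uncond + s * (eps_text - eps_uncond). A toy numeric check (values invented):

import numpy as np

eps_uncond = np.array([0.10, 0.20], dtype=np.float32)
eps_text = np.array([0.30, 0.10], dtype=np.float32)
s = 7.5
eps = eps_uncond + s * (eps_text - eps_uncond)
print(eps)  # approximately [ 1.6  -0.55]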

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()

            if create_gif:
                frames.append(latents)

        if callback:
            callback(num_inference_steps, callback_userdata)

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents

        start = time.time()
        image = self.vae_decoder(latents)[self._vae_d_output]
        print("Decoder ended:", time.time() - start)

        image = self.postprocess_image(image, meta)

        if create_gif:
            gif_folder = os.path.join(model, "../../../gif")
            print("gif_folder:", gif_folder)
            if not os.path.exists(gif_folder):
                os.makedirs(gif_folder)
            for i in range(0, len(frames)):
                image = self.vae_decoder(frames[i] * (1 / 0.18215))[self._vae_d_output]
                image = self.postprocess_image(image, meta)
                output = gif_folder + "/" + str(i).zfill(3) + ".png"
                cv2.imwrite(output, image)
            with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
                json.dump({"prompt": prompt}, file)
            # sorted() so frames are assembled in generation order
            frames_image = [Image.open(image) for image in sorted(glob.glob(f"{gif_folder}/*.png"))]
            frame_one = frames_image[0]
            gif_file = os.path.join(gif_folder, "stable_diffusion.gif")
            frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)

        return image

    def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None, scheduler=LMSDiscreteScheduler):
        """
        Function for getting initial latents for starting generation

        Parameters:
            image (PIL.Image.Image, *optional*, None):
                Input image for generation; if not provided, random noise will be used as the starting point
            latent_timestep (torch.Tensor, *optional*, None):
                Initial step predicted by the scheduler for image generation, required for mixing the latent image with noise
        Returns:
            latents (np.ndarray):
                Image encoded in latent space
        """
        latents_shape = (1, 4, self.height // 8, self.width // 8)

        noise = np.random.randn(*latents_shape).astype(np.float32)
        if image is None:
            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
            if isinstance(scheduler, LMSDiscreteScheduler):
                noise = noise * scheduler.sigmas[0].numpy()
                return noise, {}
            elif isinstance(scheduler, EulerDiscreteScheduler) or isinstance(scheduler, EulerAncestralDiscreteScheduler):
                noise = noise * scheduler.sigmas.max().numpy()
                return noise, {}
            else:
                return noise, {}
        input_image, meta = preprocess(image, self.height, self.width)

        moments = self.vae_encoder(input_image)[self._vae_e_output]

        mean, logvar = np.split(moments, 2, axis=1)

        std = np.exp(logvar * 0.5)
        latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215

        latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
        return latents, meta
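The encoder returns the moments of the latent distribution; the code splits them into mean and log-variance, draws a sample via the reparameterization trick, and applies the Stable Diffusion scaling factor 0.18215. A minimal numpy sketch with a fake moments tensor standing in for the VAE encoder output:

import numpy as np

moments = np.random.randn(1, 8, 64, 64).astype(np.float32)  # fake encoder output
mean, logvar = np.split(moments, 2, axis=1)                  # two (1, 4, 64, 64) halves
std = np.exp(0.5 * logvar)
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
print(latents.shape)                                         # (1, 4, 64, 64)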

    def postprocess_image(self, image: np.ndarray, meta: Dict):
        """
        Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it to the
        initial image size (if required), and normalizes and converts it to the [0, 255] pixel range

        Parameters:
            image (np.ndarray):
                Generated image
            meta (Dict):
                Metadata obtained during the latent preparation step; can be empty
        Returns:
            image (np.ndarray):
                Postprocessed image
        """
        if "padding" in meta:
            pad = meta["padding"]
            (_, end_h), (_, end_w) = pad[1:3]
            h, w = image.shape[2:]
            unpad_h = h - end_h
            unpad_w = w - end_w
            image = image[:, :, :unpad_h, :unpad_w]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)

        if "src_height" in meta:
            orig_height, orig_width = meta["src_height"], meta["src_width"]
            image = cv2.resize(image, (orig_width, orig_height))

        return image
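Postprocessing reverses the preprocessing: crop away the recorded padding, map [-1, 1] back to [0, 255], move channels last, and flip RGB to BGR for cv2. A small sketch of the value mapping alone (the tanh-squashed random tensor is a stand-in for real decoder output):

import numpy as np

decoded = np.tanh(np.random.randn(1, 3, 512, 512)).astype(np.float32)  # fake VAE output in [-1, 1]
img = np.clip(decoded / 2 + 0.5, 0, 1)                                 # -> [0, 1]
img = (img[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)   # HWC, BGR, uint8
print(img.shape, img.dtype)                                            # (512, 512, 3) uint8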

    def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
        """
        Helper function for getting scheduler timesteps for generation.
        In the case of image-to-image generation, it shortens the schedule according to strength.

        Parameters:
           num_inference_steps (int):
              number of inference steps for generation
           strength (float):
               value between 0.0 and 1.0 that controls the amount of noise added to the input image.
               Values approaching 1.0 allow many variations but will also produce images that are not semantically consistent with the input.
        """
        # get the original timestep using init_timestep

        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start
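For example, with num_inference_steps=32 and strength=0.5 the loop runs only the last 16 timesteps of the schedule, since the first half is replaced by the noised input image. A quick check of the arithmetic:

num_inference_steps, strength = 32, 0.5
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 16
t_start = max(num_inference_steps - init_timestep, 0)                          # 16
print(init_timestep, t_start, num_inference_steps - t_start)                   # 16 16 16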

class StableDiffusionEngine(DiffusionPipeline):
    def __init__(
            self,
            model="bes-dev/stable-diffusion-v1-4-openvino",
            tokenizer="openai/clip-vit-large-patch14",
            device=["CPU", "CPU", "CPU", "CPU"]):

        self.core = Core()
        self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})

        # A single batch-2 UNet pass is used only when both UNet stages share one GPU;
        # otherwise two batch-1 passes run on separate devices.
        self.batch_size = 2 if device[1] == device[2] and device[1] == "GPU" else 1
        try_enable_npu_turbo(device, self.core)

        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
        except Exception:
            print("Local tokenizer not found. Attempting to download...")
            self.tokenizer = self.download_tokenizer(tokenizer, model)

        print("Loading models... ")

        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
            vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[3])
            vae_en_future = executor.submit(self.load_model, model, "vae_encoder", device[3])

            if self.batch_size == 1:
                if "int8" not in model:
                    unet_future = executor.submit(self.load_model, model, "unet_bs1", device[1])
                    unet_neg_future = executor.submit(self.load_model, model, "unet_bs1", device[2]) if device[1] != device[2] else None
                else:
                    unet_future = executor.submit(self.load_model, model, "unet_int8a16", device[1])
                    unet_neg_future = executor.submit(self.load_model, model, "unet_int8a16", device[2]) if device[1] != device[2] else None
            else:
                unet_future = executor.submit(self.load_model, model, "unet", device[1])
                unet_neg_future = None

            self.unet = unet_future.result()
            self.unet_neg = unet_neg_future.result() if unet_neg_future else self.unet
            self.text_encoder = text_future.result()
            self.vae_decoder = vae_de_future.result()
            self.vae_encoder = vae_en_future.result()
            print("Text Device:", device[0])
            print("unet Device:", device[1])
            print("unet-neg Device:", device[2])
            print("VAE Device:", device[3])

            self._text_encoder_output = self.text_encoder.output(0)
            self._unet_output = self.unet.output(0)
            self._vae_d_output = self.vae_decoder.output(0)
            self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None

            self.unet_input_tensor_name = "sample" if 'sample' in self.unet.input(0).names else "latent_model_input"

            if self.batch_size == 1:
                self.infer_request = self.unet.create_infer_request()
                self.infer_request_neg = self.unet_neg.create_infer_request()
                self._unet_neg_output = self.unet_neg.output(0)
            else:
                self.infer_request = None
                self.infer_request_neg = None
                self._unet_neg_output = None

        self.set_dimensions()
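The four-element device list maps sub-models to devices in the order [text encoder, unet, unet-neg, vae]. A hypothetical construction (the model path and device split are invented for illustration):

# Hypothetical: text encoder and VAE on CPU, the two UNet passes split
# across GPU and NPU. Device order: [text, unet, unet-neg, vae].
engine = StableDiffusionEngine(
    model="models/stable-diffusion-1.5-int8",
    device=["CPU", "GPU", "NPU", "CPU"],
)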

    def load_model(self, model, model_name, device):
        if "NPU" in device:
            with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
                return self.core.import_model(f.read(), device)
        return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
            -
                def set_dimensions(self):
         
     | 
| 540 | 
         
            -
                    latent_shape = self.unet.input(self.unet_input_tensor_name).shape
         
     | 
| 541 | 
         
            -
                    if latent_shape[1] == 4:
         
     | 
| 542 | 
         
            -
                        self.height = latent_shape[2] * 8
         
     | 
| 543 | 
         
            -
                        self.width = latent_shape[3] * 8
         
     | 
| 544 | 
         
            -
                    else:
         
     | 
| 545 | 
         
            -
                        self.height = latent_shape[1] * 8
         
     | 
| 546 | 
         
            -
                        self.width = latent_shape[2] * 8
         
     | 
| 547 | 
         
            -
             
     | 

    def __call__(
            self,
            prompt,
            init_image=None,
            negative_prompt=None,
            scheduler=None,
            strength=0.5,
            num_inference_steps=32,
            guidance_scale=7.5,
            eta=0.0,
            create_gif=False,
            model=None,
            callback=None,
            callback_userdata=None
    ):
        # extract condition
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]

        # do classifier free guidance
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt

            tokens_uncond = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,  # truncation=True,
                return_tensors="np"
            )
            uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}

        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
        latent_timestep = timesteps[:1]

        # get the initial random noise unless the user supplied it
        latents, meta = self.prepare_latents(init_image, latent_timestep, scheduler, model)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        if create_gif:
            frames = []

        for i, t in enumerate(self.progress_bar(timesteps)):
            if callback:
                callback(i, callback_userdata)

            if self.batch_size == 1:
                # expand the latents if we are doing classifier free guidance
                noise_pred = []
                latent_model_input = latents

                # Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                latent_model_input_pos = latent_model_input
                latent_model_input_neg = latent_model_input

                if self.unet.input(self.unet_input_tensor_name).shape[1] != 4:
                    try:
                        latent_model_input_pos = latent_model_input_pos.permute(0, 2, 3, 1)
                    except AttributeError:
                        latent_model_input_pos = latent_model_input_pos.transpose(0, 2, 3, 1)

                if self.unet_neg.input(self.unet_input_tensor_name).shape[1] != 4:
                    try:
                        latent_model_input_neg = latent_model_input_neg.permute(0, 2, 3, 1)
                    except AttributeError:
                        latent_model_input_neg = latent_model_input_neg.transpose(0, 2, 3, 1)

                # Input names differ between exported UNets: diffusers-style models use
                # "sample"/"timestep", the custom export uses "latent_model_input"/"t".
                if "sample" in self.unet_input_tensor_name:
                    input_tens_neg_dict = {"sample": latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
                    input_tens_pos_dict = {"sample": latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "timestep": np.expand_dims(np.float32(t), axis=0)}
                else:
                    input_tens_neg_dict = {"latent_model_input": latent_model_input_neg, "encoder_hidden_states": np.expand_dims(text_embeddings[0], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}
                    input_tens_pos_dict = {"latent_model_input": latent_model_input_pos, "encoder_hidden_states": np.expand_dims(text_embeddings[1], axis=0), "t": np.expand_dims(np.float32(t), axis=0)}

                self.infer_request_neg.start_async(input_tens_neg_dict)
                self.infer_request.start_async(input_tens_pos_dict)

                self.infer_request_neg.wait()
                self.infer_request.wait()

                noise_pred_neg = self.infer_request_neg.get_output_tensor(0)
                noise_pred_pos = self.infer_request.get_output_tensor(0)

                noise_pred.append(noise_pred_neg.data.astype(np.float32))
                noise_pred.append(noise_pred_pos.data.astype(np.float32))
            else:
                latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
         
     | 
| 666 | 
         
            -
                            noise_pred = self.unet([latent_model_input, np.array(t, dtype=np.float32), text_embeddings])[self._unet_output]
         
     | 
| 667 | 
         
            -
                            
         
     | 
| 668 | 
         
            -
                        if do_classifier_free_guidance:
         
     | 
| 669 | 
         
            -
                            noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
         
     | 
| 670 | 
         
            -
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         
     | 
| 671 | 
         
            -
             
     | 
| 672 | 
         
            -
                        # compute the previous noisy sample x_t -> x_t-1
         
     | 
| 673 | 
         
            -
                        latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()
         
     | 
| 674 | 
         
            -
             
     | 
| 675 | 
         
            -
                        if create_gif:
         
     | 
| 676 | 
         
            -
                            frames.append(latents)
         
     | 
| 677 | 
         
            -
             
     | 
| 678 | 
         
            -
                    if callback:
         
     | 
| 679 | 
         
            -
                        callback(num_inference_steps, callback_userdata)
         
     | 
| 680 | 
         
            -
             
     | 
| 681 | 
         
            -
                    # scale and decode the image latents with vae
         
     | 
| 682 | 
         
            -
                    #if self.height == 512 and self.width == 512:
         
     | 
| 683 | 
         
            -
                    latents = 1 / 0.18215 * latents
         
     | 
| 684 | 
         
            -
                    image = self.vae_decoder(latents)[self._vae_d_output]
         
     | 
| 685 | 
         
            -
                    image = self.postprocess_image(image, meta)
         
     | 
| 686 | 
         
            -
             
     | 
| 687 | 
         
            -
                    return image
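    # Editor's note: a worked example of the guidance combination above, using assumed
    # values rather than anything from this file. With guidance_scale = 7.5, a latent
    # element with noise_pred_uncond = 0.10 and noise_pred_text = 0.30 becomes
    # 0.10 + 7.5 * (0.30 - 0.10) = 1.60, amplifying the prompt-conditioned direction
    # before scheduler.step() integrates it into the next latents.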
         
    def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None,
                        scheduler=LMSDiscreteScheduler, model=None):
        """
        Function for getting initial latents for starting generation

        Parameters:
            image (PIL.Image.Image, *optional*, None):
                Input image for generation; if not provided, random noise will be used as the starting point
            latent_timestep (torch.Tensor, *optional*, None):
                Initial step predicted by the scheduler, required for mixing the latent image with noise
        Returns:
            latents (np.ndarray):
                Image encoded in latent space
        """
        latents_shape = (1, 4, self.height // 8, self.width // 8)

        noise = np.random.randn(*latents_shape).astype(np.float32)
        if image is None:
            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
            if isinstance(scheduler, LMSDiscreteScheduler):
                noise = noise * scheduler.sigmas[0].numpy()
                return noise, {}
            elif isinstance(scheduler, EulerDiscreteScheduler):
                noise = noise * scheduler.sigmas.max().numpy()
                return noise, {}
            else:
                return noise, {}
        input_image, meta = preprocess(image, self.height, self.width)

        moments = self.vae_encoder(input_image)[self._vae_e_output]

        if "sd_2.1" in model:
            latents = moments * 0.18215
        else:
            mean, logvar = np.split(moments, 2, axis=1)
            std = np.exp(logvar * 0.5)
            latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215

        latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
        return latents, meta
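    # Editor's note: the image-to-image branch above samples the VAE posterior via the
    # reparameterization trick: `moments` stacks (mean, logvar) on the channel axis, a
    # latent is drawn as mean + exp(0.5 * logvar) * eps with eps ~ N(0, 1), and the
    # result is scaled by the Stable Diffusion latent factor 0.18215. The "sd_2.1"
    # branch assumes that export already returns sampled latents, so it only rescales.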
         
    def postprocess_image(self, image: np.ndarray, meta: Dict):
        """
        Postprocessing for the decoded image. Takes the image generated by the VAE decoder, unpads it
        to the initial image size (if required), normalizes it, and converts it to the [0, 255] pixel range.

        Parameters:
            image (np.ndarray):
                Generated image
            meta (Dict):
                Metadata obtained on the latents preparation step, can be empty
        Returns:
            image (np.ndarray):
                Postprocessed image
        """
        if "padding" in meta:
            pad = meta["padding"]
            (_, end_h), (_, end_w) = pad[1:3]
            h, w = image.shape[2:]
            unpad_h = h - end_h
            unpad_w = w - end_w
            image = image[:, :, :unpad_h, :unpad_w]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)

        if "src_height" in meta:
            orig_height, orig_width = meta["src_height"], meta["src_width"]
            image = cv2.resize(image, (orig_width, orig_height))

        return image

    def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
        """
        Helper function for getting scheduler timesteps for generation.
        In the case of image-to-image generation, it updates the number of steps according to strength.

        Parameters:
           num_inference_steps (int):
              number of inference steps for generation
           strength (float):
               value between 0.0 and 1.0 that controls the amount of noise added to the input image.
               Values approaching 1.0 allow for lots of variation but will also produce images that are
               not semantically consistent with the input.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start
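
# Editor's note: a minimal, self-contained sketch (not from the original file) of the
# strength-to-timesteps arithmetic in get_timesteps() above; the numbers are assumed,
# illustrative values. With 20 scheduler steps and strength 0.25, only the final
# 5 denoising steps run, so lower strength stays closer to the input image.
def _demo_img2img_step_window(num_inference_steps: int = 20, strength: float = 0.25):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 5
    t_start = max(num_inference_steps - init_timestep, 0)  # 15
    return t_start, num_inference_steps - t_start  # (15, 5)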
         

class LatentConsistencyEngine(DiffusionPipeline):
    def __init__(
        self,
        model="SimianLuo/LCM_Dreamshaper_v7",
        tokenizer="openai/clip-vit-large-patch14",
        device=["CPU", "CPU", "CPU"],
    ):
        super().__init__()
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
        except Exception:
            self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
            self.tokenizer.save_pretrained(model)

        self.core = Core()
        self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})  # model caching to reduce init time
        try_enable_npu_turbo(device, self.core)

        # compile the three sub-models in parallel
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
            unet_future = executor.submit(self.load_model, model, "unet", device[1])
            vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])

        print("Text Device:", device[0])
        self.text_encoder = text_future.result()
        self._text_encoder_output = self.text_encoder.output(0)

        print("Unet Device:", device[1])
        self.unet = unet_future.result()
        self._unet_output = self.unet.output(0)
        self.infer_request = self.unet.create_infer_request()

        print(f"VAE Device: {device[2]}")
        self.vae_decoder = vae_de_future.result()
        self.infer_request_vae = self.vae_decoder.create_infer_request()
        self.safety_checker = None  # pipe.safety_checker
        self.feature_extractor = None  # pipe.feature_extractor
        self.vae_scale_factor = 2 ** 3
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def load_model(self, model, model_name, device):
        if "NPU" in device:
            with open(os.path.join(model, f"{model_name}.blob"), "rb") as f:
                return self.core.import_model(f.read(), device)
        return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
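    # Editor's note: on NPU targets the model is restored from a precompiled .blob via
    # core.import_model(), skipping compilation entirely; other devices compile the
    # OpenVINO IR (.xml) on demand, with results cached in the CACHE_DIR set in __init__.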
         
    def _encode_prompt(
        self,
        prompt,
        num_images_per_prompt,
        prompt_embeds=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from the `prompt` input argument.
        """

        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
            prompt_embeds = torch.from_numpy(prompt_embeds[0])

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # No unconditional prompt embedding is needed because of LCM guided distillation
        return prompt_embeds

    def run_safety_checker(self, image, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    def prepare_latents(
        self, batch_size, num_channels_latents, height, width, dtype, latents=None
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = torch.randn(shape, dtype=dtype)
        # the initial noise is scaled by scheduler.init_noise_sigma in __call__
        return latents

    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
        Args:
            w: torch.Tensor: guidance scales at which to generate embedding vectors
            embedding_dim: int: dimension of the embeddings to generate
            dtype: data type of the generated embeddings
        Returns:
            embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
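    # Editor's note: a worked shape example for get_w_embedding(), with assumed values.
    # For a batch of one guidance scale w = 8.0 and embedding_dim = 256 (the value used
    # in __call__ below), half_dim = 128 log-spaced frequencies yield 128 sine plus
    # 128 cosine features, i.e. a (1, 256) conditioning vector for the LCM UNet.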
         
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: float = 7.5,
        scheduler=None,
        num_images_per_prompt: Optional[int] = 1,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 4,
        lcm_origin_steps: int = 50,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        model: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = 1234567,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback=None,
        callback_userdata=None,
    ):
        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if seed is not None:
            torch.manual_seed(seed)

        # In the LCM implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond), (cfg_scale > 0.0 using CFG)

        # 2. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            prompt_embeds=prompt_embeds,
        )

        # 3. Prepare timesteps
        scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
        timesteps = scheduler.timesteps

        # 4. Prepare latent variables
        num_channels_latents = 4
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            latents,
        )
        latents = latents * scheduler.init_noise_sigma

        bs = batch_size * num_images_per_prompt

        # 5. Get guidance scale embedding
        w = torch.tensor(guidance_scale).repeat(bs)
        w_embedding = self.get_w_embedding(w, embedding_dim=256)

        # 6. LCM multi-step sampling loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if callback:
                    callback(i + 1, callback_userdata)

                ts = torch.full((bs,), t, dtype=torch.long)

                # model prediction (v-prediction, eps, x)
                model_pred = self.unet([latents, ts, prompt_embeds, w_embedding], share_inputs=True, share_outputs=True)[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = scheduler.step(
                    torch.from_numpy(model_pred), t, latents, return_dict=False
                )
                progress_bar.update()

        vae_start = time.time()

        if output_type != "latent":
            image = torch.from_numpy(self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0])
        else:
            image = denoised

        print("Decoder Ended: ", time.time() - vae_start)

        # the safety checker is disabled, so every image is denormalized
        do_denormalize = [True] * image.shape[0]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        return image[0]
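
# Editor's note: a minimal usage sketch, not part of the original file. The model
# directory "./lcm_dreamshaper_v7_ov" is a hypothetical local OpenVINO export expected
# to contain text_encoder.xml, unet.xml and vae_decoder.xml; LCMScheduler comes from
# diffusers, which this module already depends on.
def _example_lcm_generation():
    from diffusers import LCMScheduler

    engine = LatentConsistencyEngine(
        model="./lcm_dreamshaper_v7_ov",  # hypothetical path, adjust to your export
        device=["CPU", "CPU", "CPU"],
    )
    scheduler = LCMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    return engine(
        prompt="a sandcastle on a sunny beach, photorealistic",
        num_inference_steps=4,
        scheduler=scheduler,
        seed=42,
    )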
         
     | 
| 1070 | 
         
            -
             
     | 
| 1071 | 
         
            -
            class LatentConsistencyEngineAdvanced(DiffusionPipeline):
         
     | 
| 1072 | 
         
            -
                def __init__(
         
     | 
| 1073 | 
         
            -
                    self,
         
     | 
| 1074 | 
         
            -
                        model="SimianLuo/LCM_Dreamshaper_v7",
         
     | 
| 1075 | 
         
            -
                        tokenizer="openai/clip-vit-large-patch14",
         
     | 
| 1076 | 
         
            -
                        device=["CPU", "CPU", "CPU"],
         
     | 
| 1077 | 
         
            -
                ):
         
     | 
| 1078 | 
         
            -
                    super().__init__()
         
     | 
| 1079 | 
         
            -
                    try:
         
     | 
| 1080 | 
         
            -
                        self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
         
     | 
| 1081 | 
         
            -
                    except:
         
     | 
| 1082 | 
         
            -
                        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
         
     | 
| 1083 | 
         
            -
                        self.tokenizer.save_pretrained(model)
         
     | 
| 1084 | 
         
            -
             
     | 
| 1085 | 
         
            -
                    self.core = Core()
         
     | 
| 1086 | 
         
            -
                    self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})  # adding caching to reduce init time
         
     | 
| 1087 | 
         
            -
                    #try_enable_npu_turbo(device, self.core)
         
     | 
| 1088 | 
         
            -
                           
         
     | 
| 1089 | 
         
            -
             
     | 
| 1090 | 
         
            -
                    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
         
     | 
| 1091 | 
         
            -
                        text_future = executor.submit(self.load_model, model, "text_encoder", device[0])
         
     | 
| 1092 | 
         
            -
                        unet_future = executor.submit(self.load_model, model, "unet", device[1])    
         
     | 
| 1093 | 
         
            -
                        vae_de_future = executor.submit(self.load_model, model, "vae_decoder", device[2])
         
     | 
| 1094 | 
         
            -
                        vae_encoder_future = executor.submit(self.load_model, model, "vae_encoder", device[2])
         
     | 
| 1095 | 
         
            -
                            
         
     | 
| 1096 | 
         
            -
                   
         
     | 
| 1097 | 
         
            -
                    print("Text Device:", device[0])
         
     | 
| 1098 | 
         
            -
                    self.text_encoder = text_future.result()
         
     | 
| 1099 | 
         
            -
                    self._text_encoder_output = self.text_encoder.output(0)
         
     | 
| 1100 | 
         
            -
             
     | 
| 1101 | 
         
            -
                    print("Unet Device:", device[1])
         
     | 
| 1102 | 
         
            -
                    self.unet = unet_future.result()
         
     | 
| 1103 | 
         
            -
                    self._unet_output = self.unet.output(0)
         
     | 
| 1104 | 
         
            -
                    self.infer_request = self.unet.create_infer_request()
         
     | 
| 1105 | 
         
            -
             
     | 
| 1106 | 
         
            -
                    print(f"VAE Device: {device[2]}")
         
     | 
| 1107 | 
         
            -
                    self.vae_decoder = vae_de_future.result()
         
     | 
| 1108 | 
         
            -
                    self.vae_encoder = vae_encoder_future.result()
         
     | 
| 1109 | 
         
            -
                    self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder else None
         
     | 
| 1110 | 
         
            -
             
     | 
| 1111 | 
         
            -
                    self.infer_request_vae = self.vae_decoder.create_infer_request()
         
     | 
| 1112 | 
         
            -
                    self.safety_checker = None #pipe.safety_checker
         
     | 
| 1113 | 
         
            -
                    self.feature_extractor = None #pipe.feature_extractor
         
     | 
| 1114 | 
         
            -
                    self.vae_scale_factor = 2 ** 3
         
     | 
| 1115 | 
         
            -
                    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         
     | 
| 1116 | 
         
            -
             
     | 
| 1117 | 
         
            -
                def load_model(self, model, model_name, device):
         
     | 
| 1118 | 
         
            -
                    print(f"Compiling the {model_name} to {device} ...")
         
     | 
| 1119 | 
         
            -
                    return self.core.compile_model(os.path.join(model, f"{model_name}.xml"), device)
         
     | 
| 1120 | 
         
            -
                
         
     | 
| 1121 | 
         
            -
                def get_timesteps(self, num_inference_steps:int, strength:float, scheduler):
         
     | 
| 1122 | 
         
            -
                    """
         
     | 
| 1123 | 
         
            -
                    Helper function for getting scheduler timesteps for generation
         
     | 
| 1124 | 
         
            -
                    In case of image-to-image generation, it updates number of steps according to strength
         
     | 
| 1125 | 
         
            -
                    
         
     | 
| 1126 | 
         
            -
                    Parameters:
         
     | 
| 1127 | 
         
            -
                       num_inference_steps (int):
         
     | 
| 1128 | 
         
            -
                          number of inference steps for generation
         
     | 
| 1129 | 
         
            -
                       strength (float):
         
     | 
| 1130 | 
         
            -
                           value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. 
         
     | 
| 1131 | 
         
            -
                           Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
         
     | 
| 1132 | 
         
            -
                    """
         
     | 
| 1133 | 
         
            -
                    # get the original timestep using init_timestep
         
     | 
| 1134 | 
         
            -
               
         
     | 
| 1135 | 
         
            -
                    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
         
     | 
| 1136 | 
         
            -
                
         
     | 
| 1137 | 
         
            -
                    t_start = max(num_inference_steps - init_timestep, 0)
         
     | 
| 1138 | 
         
            -
                    timesteps = scheduler.timesteps[t_start:]
         
     | 
| 1139 | 
         
            -
             
     | 
| 1140 | 
         
            -
                    return timesteps, num_inference_steps - t_start     
         
     | 
| 1141 | 
         
            -
             
     | 
| 1142 | 
         
            -
                def _encode_prompt(
         
     | 
| 1143 | 
         
            -
                    self,
         
     | 
| 1144 | 
         
            -
                    prompt,
         
     | 
| 1145 | 
         
            -
                    num_images_per_prompt,
         
     | 
| 1146 | 
         
            -
                    prompt_embeds: None,
         
    ):
        r"""
        Encodes the prompt into text encoder hidden states.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
        """
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(
                prompt, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            prompt_embeds = self.text_encoder(text_input_ids, share_inputs=True, share_outputs=True)
            prompt_embeds = torch.from_numpy(prompt_embeds[0])

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # Don't need to get uncond prompt embedding because of LCM Guided Distillation
        return prompt_embeds
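
The truncation warning in `_encode_prompt` can be reproduced in isolation. A minimal sketch, assuming `transformers` is installed and the CLIP tokenizer can be fetched (network access on first run); the prompt is illustrative:

# Sketch of the truncation check above; the long prompt is synthetic.
from transformers import CLIPTokenizer

tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
prompt = "a photo of a cat " * 30                      # far beyond CLIP's 77-token window
padded = tok(prompt, padding="max_length", max_length=tok.model_max_length,
             truncation=True, return_tensors="np").input_ids
longest = tok(prompt, padding="longest", return_tensors="np").input_ids
if longest.shape[-1] >= padded.shape[-1]:
    dropped = tok.batch_decode(longest[:, tok.model_max_length - 1 : -1])
    print("truncated:", dropped)
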
    def run_safety_checker(self, image, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept
    def prepare_latents(
        self, image, timestep, batch_size, num_channels_latents, height, width, dtype, scheduler, latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if image:
            # img2img: encode the input image with the VAE encoder, sample a
            # latent via the reparameterization trick, then noise it to `timestep`
            input_image, meta = preprocess(image, 512, 512)
            moments = self.vae_encoder(input_image)[self._vae_e_output]
            mean, logvar = np.split(moments, 2, axis=1)
            std = np.exp(logvar * 0.5)
            latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
            noise = torch.randn(shape, dtype=dtype)
            latents = scheduler.add_noise(torch.from_numpy(latents), noise, timestep)
        else:
            # txt2img: start from pure Gaussian noise
            latents = torch.randn(shape, dtype=dtype)
        return latents
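
The img2img branch samples the VAE latent from the encoder's "moments" output with the reparameterization trick. A numpy-only sketch with synthetic shapes; the (1, 8, 64, 64) tensor stands in for the encoder output:

# Sketch of the moments -> latent sampling above; all data is synthetic.
import numpy as np

moments = np.random.randn(1, 8, 64, 64).astype(np.float32)
mean, logvar = np.split(moments, 2, axis=1)          # two (1, 4, 64, 64) halves
std = np.exp(logvar * 0.5)
latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215  # SD v1 VAE scaling
print(latents.shape)  # (1, 4, 64, 64)
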
    def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
        Args:
            w: torch.Tensor: guidance scale values to embed, one per batch element
            embedding_dim: int: dimension of the embeddings to generate
            dtype: data type of the generated embeddings
        Returns:
            embedding vectors with shape `(len(w), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
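
Because the distilled LCM conditions on the guidance scale instead of running two UNet passes, the scale is turned into a sinusoidal Fourier feature. A standalone sketch of the same computation (torch only; the scale values are illustrative):

# Standalone version of the guidance embedding above.
import torch

def w_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freq = torch.exp(torch.arange(half_dim, dtype=torch.float32)
                     * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
    args = w.float()[:, None] * freq[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

emb = w_embedding(torch.tensor([7.5, 7.5]))  # batch of two guidance scales
print(emb.shape)  # torch.Size([2, 256])
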
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        init_image: Optional[PIL.Image.Image] = None,
        strength: Optional[float] = 0.8,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: float = 7.5,
        scheduler=None,
        num_images_per_prompt: Optional[int] = 1,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 4,
        lcm_origin_steps: int = 50,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        model: Optional[Dict[str, any]] = None,
        seed: Optional[int] = 1234567,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        callback=None,
        callback_userdata=None,
    ):
        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if seed is not None:
            torch.manual_seed(seed)

        # In the LCM implementation cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond),
        # so no classifier-free-guidance batch doubling is needed.

        # 2. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            prompt_embeds=prompt_embeds,
        )

        # 3. Prepare timesteps
        latent_timestep = None
        scheduler.set_timesteps(num_inference_steps, original_inference_steps=lcm_origin_steps)
        if init_image:
            # img2img: trim the schedule according to strength and remember the
            # first timestep, which is where the encoded image gets noised to
            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
            latent_timestep = timesteps[:1]
        else:
            timesteps = scheduler.timesteps

        # 4. Prepare latent variables
        num_channels_latents = 4
        latents = self.prepare_latents(
            init_image,
            latent_timestep,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            scheduler,
            latents,
        )
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma

        bs = batch_size * num_images_per_prompt

        # 5. Get guidance scale embedding
        w = torch.tensor(guidance_scale).repeat(bs)
        w_embedding = self.get_w_embedding(w, embedding_dim=256)

        # 6. LCM multistep sampling loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if callback:
                    callback(i + 1, callback_userdata)

                ts = torch.full((bs,), t, dtype=torch.long)

                # model prediction (v-prediction, eps, x)
                model_pred = self.unet(
                    [latents, ts, prompt_embeds, w_embedding],
                    share_inputs=True, share_outputs=True,
                )[0]

                # compute the previous noisy sample x_t -> x_t-1
                latents, denoised = scheduler.step(
                    torch.from_numpy(model_pred), t, latents, return_dict=False
                )
                progress_bar.update()

        # 7. Decode the denoised latents with the VAE (unless latents were requested)
        vae_start = time.time()
        if not output_type == "latent":
            image = torch.from_numpy(
                self.vae_decoder(denoised / 0.18215, share_inputs=True, share_outputs=True)[0]
            )
        else:
            image = denoised
        print("Decoder Ended: ", time.time() - vae_start)

        do_denormalize = [True] * image.shape[0]
        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        return image[0]
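
The sampling loop reports progress through the optional callback, invoked as callback(i + 1, callback_userdata) once per step. A minimal sketch of a compatible callback; only the (step, userdata) signature comes from the loop above, the dict layout is an illustrative assumption:

# Hypothetical progress callback matching the call shape used in the loop.
def on_step(step: int, userdata: dict) -> None:
    print(f"step {step}/{userdata['total_steps']}")

on_step(1, {"total_steps": 4})  # prints: step 1/4
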
class StableDiffusionEngineReferenceOnly(DiffusionPipeline):
    def __init__(
        self,
        model="bes-dev/stable-diffusion-v1-4-openvino",
        tokenizer="openai/clip-vit-large-patch14",
        device=["CPU", "CPU", "CPU"],
    ):
        # Prefer a locally cached tokenizer; fall back to downloading it once
        # and saving it next to the model for subsequent runs.
        try:
            self.tokenizer = CLIPTokenizer.from_pretrained(model, local_files_only=True)
        except Exception:
            self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer)
            self.tokenizer.save_pretrained(model)

        # models
        self.core = Core()
        self.core.set_property({'CACHE_DIR': os.path.join(model, 'cache')})  # caching reduces init time

        # text features
        print("Text Device:", device[0])
        self.text_encoder = self.core.compile_model(os.path.join(model, "text_encoder.xml"), device[0])
        self._text_encoder_output = self.text_encoder.output(0)

        # diffusion: two UNet variants, one that writes reference attention
        # states and one that reads them back during denoising
        print("unet_w Device:", device[1])
        self.unet_w = self.core.compile_model(os.path.join(model, "unet_reference_write.xml"), device[1])
        self._unet_w_output = self.unet_w.output(0)
        self.latent_shape = tuple(self.unet_w.inputs[0].shape)[1:]

        print("unet_r Device:", device[1])
        self.unet_r = self.core.compile_model(os.path.join(model, "unet_reference_read.xml"), device[1])
        self._unet_r_output = self.unet_r.output(0)

        # decoder
        print("Vae Device:", device[2])
        self.vae_decoder = self.core.compile_model(os.path.join(model, "vae_decoder.xml"), device[2])

        # encoder
        self.vae_encoder = self.core.compile_model(os.path.join(model, "vae_encoder.xml"), device[2])
        self.init_image_shape = tuple(self.vae_encoder.inputs[0].shape)[2:]

        self._vae_d_output = self.vae_decoder.output(0)
        self._vae_e_output = self.vae_encoder.output(0) if self.vae_encoder is not None else None

        self.height = self.unet_w.input(0).shape[2] * 8
        self.width = self.unet_w.input(0).shape[3] * 8
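
Setting CACHE_DIR before compile_model is what keeps repeated start-ups cheap: OpenVINO stores the compiled blob and reloads it on later runs. A minimal sketch of the same pattern in isolation; "model_dir" is a placeholder and the compile call is left commented because no IR file ships with this sketch (the openvino.runtime import path is also an assumption about the installed OpenVINO version):

# Sketch of the compile-with-cache pattern used in __init__ above.
import os
from openvino.runtime import Core

core = Core()
core.set_property({'CACHE_DIR': os.path.join("model_dir", "cache")})
# First call compiles and caches; later calls with the same IR load the cached blob:
# compiled = core.compile_model(os.path.join("model_dir", "text_encoder.xml"), "CPU")
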
    def __call__(
        self,
        prompt,
        image=None,
        negative_prompt=None,
        scheduler=None,
        strength=1.0,
        num_inference_steps=32,
        guidance_scale=7.5,
        eta=0.0,
        create_gif=False,
        model=None,
        callback=None,
        callback_userdata=None,
    ):
        # extract condition
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        text_embeddings = self.text_encoder(text_input.input_ids)[self._text_encoder_output]

        # do classifier free guidance
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""]
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            else:
                uncond_tokens = negative_prompt

            tokens_uncond = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="np",
            )
            uncond_embeddings = self.text_encoder(tokens_uncond.input_ids)[self._text_encoder_output]
            text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, scheduler)
        latent_timestep = timesteps[:1]

        ref_image = self.prepare_image(
            image=image,
            width=512,
            height=512,
        )
        # get the initial random noise unless the user supplied it
        latents, meta = self.prepare_latents(None, latent_timestep, scheduler)
        ref_image_latents = self.ov_prepare_ref_latents(ref_image)

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler and corresponds to η in the DDIM paper
        # (https://arxiv.org/abs/2010.02502); it should be in [0, 1] and is ignored otherwise
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta
        if create_gif:
            frames = []

        for i, t in enumerate(self.progress_bar(timesteps)):
            if callback:
                callback(i, callback_userdata)

            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # reference-only part: noise the reference latents to the current timestep
            noise = randn_tensor(ref_image_latents.shape)
            ref_xt = scheduler.add_noise(
                torch.from_numpy(ref_image_latents),
                noise,
                t.reshape(1),
            ).numpy()
            ref_xt = np.concatenate([ref_xt] * 2) if do_classifier_free_guidance else ref_xt
            ref_xt = scheduler.scale_model_input(ref_xt, t)

            # MODE = "write": run the reference UNet and capture its self-attention activations
            result_w_dict = self.unet_w([
                ref_xt,
                t,
                text_embeddings
            ])
            down_0_attn0 = result_w_dict["/unet/down_blocks.0/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_0_attn1 = result_w_dict["/unet/down_blocks.0/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_1_attn0 = result_w_dict["/unet/down_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_1_attn1 = result_w_dict["/unet/down_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_2_attn0 = result_w_dict["/unet/down_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            down_2_attn1 = result_w_dict["/unet/down_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            mid_attn0    = result_w_dict["/unet/mid_block/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn0   = result_w_dict["/unet/up_blocks.1/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn1   = result_w_dict["/unet/up_blocks.1/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_1_attn2   = result_w_dict["/unet/up_blocks.1/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn0   = result_w_dict["/unet/up_blocks.2/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn1   = result_w_dict["/unet/up_blocks.2/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_2_attn2   = result_w_dict["/unet/up_blocks.2/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn0   = result_w_dict["/unet/up_blocks.3/attentions.0/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn1   = result_w_dict["/unet/up_blocks.3/attentions.1/transformer_blocks.0/norm1/LayerNormalization_output_0"]
            up_3_attn2   = result_w_dict["/unet/up_blocks.3/attentions.2/transformer_blocks.0/norm1/LayerNormalization_output_0"]

            # MODE = "read": denoise the actual latents with the captured reference activations injected
            noise_pred = self.unet_r([
                latent_model_input, t, text_embeddings, down_0_attn0, down_0_attn1, down_1_attn0,
                down_1_attn1, down_2_attn0, down_2_attn1, mid_attn0, up_1_attn0, up_1_attn1, up_1_attn2,
                up_2_attn0, up_2_attn1, up_2_attn2, up_3_attn0, up_3_attn1, up_3_attn2
            ])[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred[0], noise_pred[1]
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs)["prev_sample"].numpy()

            if create_gif:
                frames.append(latents)

        if callback:
            callback(num_inference_steps, callback_userdata)

        # scale and decode the image latents with vae
        image = self.vae_decoder(latents)[self._vae_d_output]
        image = self.postprocess_image(image, meta)

        if create_gif:
            gif_folder = os.path.join(model, "../../../gif")
            if not os.path.exists(gif_folder):
                os.makedirs(gif_folder)
            for i in range(0, len(frames)):
                frame_img = self.vae_decoder(frames[i])[self._vae_d_output]
                frame_img = self.postprocess_image(frame_img, meta)
                output = gif_folder + "/" + str(i).zfill(3) + ".png"
                cv2.imwrite(output, frame_img)
            with open(os.path.join(gif_folder, "prompt.json"), "w") as file:
                json.dump({"prompt": prompt}, file)
            frames_image = [Image.open(path) for path in sorted(glob.glob(f"{gif_folder}/*.png"))]
            frame_one = frames_image[0]
            gif_file = os.path.join(gif_folder, "stable_diffusion.gif")
            frame_one.save(gif_file, format="GIF", append_images=frames_image, save_all=True, duration=100, loop=0)

        return image
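
When create_gif is set, each intermediate latent is decoded, written as a numbered PNG, and the frames are stitched with PIL via the save call above. A self-contained sketch of just the stitching step, with synthetic solid-color frames:

# PIL-only sketch of the GIF assembly; frame contents are synthetic.
from PIL import Image

frames = [Image.new("RGB", (64, 64), (32 * i, 0, 128)) for i in range(8)]
frames[0].save("stable_diffusion.gif", format="GIF",
               append_images=frames[1:], save_all=True, duration=100, loop=0)
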
    def ov_prepare_ref_latents(self, refimage, vae_scaling_factor=0.18215):
        # encode the reference image into latent space so it can be noised and
        # fed to the "write" UNet alongside the latent model input
        moments = self.vae_encoder(refimage)[0]
        mean, logvar = np.split(moments, 2, axis=1)
        std = np.exp(logvar * 0.5)
        ref_image_latents = mean + std * np.random.randn(*mean.shape)
        ref_image_latents = vae_scaling_factor * ref_image_latents
        return ref_image_latents
    def prepare_latents(self, image: PIL.Image.Image = None, latent_timestep: torch.Tensor = None, scheduler=LMSDiscreteScheduler):
        """
        Function for getting initial latents for starting generation

        Parameters:
            image (PIL.Image.Image, *optional*, None):
                Input image for generation; if not provided, random noise will be used as the starting point
            latent_timestep (torch.Tensor, *optional*, None):
                Initial step predicted by the scheduler, required for mixing the latent image with noise
        Returns:
            latents (np.ndarray):
                Image encoded in latent space
        """
        latents_shape = (1, 4, self.height // 8, self.width // 8)
        noise = np.random.randn(*latents_shape).astype(np.float32)
        if image is None:
            # sigma-based schedulers expect the initial noise to be scaled by their sigmas
            if isinstance(scheduler, LMSDiscreteScheduler):
                noise = noise * scheduler.sigmas[0].numpy()
            elif isinstance(scheduler, EulerDiscreteScheduler):
                noise = noise * scheduler.sigmas.max().numpy()
            return noise, {}

        input_image, meta = preprocess(image, self.height, self.width)
        moments = self.vae_encoder(input_image)[self._vae_e_output]
        mean, logvar = np.split(moments, 2, axis=1)
        std = np.exp(logvar * 0.5)
        latents = (mean + std * np.random.randn(*mean.shape)) * 0.18215
        latents = scheduler.add_noise(torch.from_numpy(latents), torch.from_numpy(noise), latent_timestep).numpy()
        return latents, meta
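
The add_noise call above is ordinary forward diffusion of the encoded latents to latent_timestep. A sketch using diffusers' DDPMScheduler as a stand-in (an assumption; the engine works with whatever scheduler the caller passes, and the shapes are illustrative):

# Sketch of the latent/noise mixing with a stand-in scheduler.
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(1, 4, 64, 64)      # "clean" encoded image latents
noise = torch.randn_like(latents)
timestep = torch.tensor([750])           # the latent_timestep from get_timesteps
noisy = scheduler.add_noise(latents, noise, timestep)
print(noisy.shape)  # torch.Size([1, 4, 64, 64])
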
| 1677 | 
         
            -
                def postprocess_image(self, image:np.ndarray, meta:Dict):
         
     | 
| 1678 | 
         
            -
                    """
         
     | 
| 1679 | 
         
            -
                    Postprocessing for decoded image. Takes generated image decoded by VAE decoder, unpad it to initila image size (if required), 
         
     | 
| 1680 | 
         
            -
                    normalize and convert to [0, 255] pixels range. Optionally, convertes it from np.ndarray to PIL.Image format
         
     | 
| 1681 | 
         
            -
                    
         
     | 
| 1682 | 
         
            -
                    Parameters:
         
     | 
| 1683 | 
         
            -
                        image (np.ndarray):
         
     | 
| 1684 | 
         
            -
                            Generated image
         
     | 
| 1685 | 
         
            -
                        meta (Dict):
         
     | 
| 1686 | 
         
            -
                            Metadata obtained on latents preparing step, can be empty
         
     | 
| 1687 | 
         
            -
                        output_type (str, *optional*, pil):
         
     | 
| 1688 | 
         
            -
                            Output format for result, can be pil or numpy
         
     | 
| 1689 | 
         
            -
                    Returns:
         
     | 
| 1690 | 
         
            -
                        image (List of np.ndarray or PIL.Image.Image):
         
     | 
| 1691 | 
         
            -
                            Postprocessed images
         
     | 
| 1692 | 
         
            -
             
     | 
| 1693 | 
         
            -
                                    if "src_height" in meta:
         
     | 
| 1694 | 
         
            -
                        orig_height, orig_width = meta["src_height"], meta["src_width"]
         
     | 
| 1695 | 
         
            -
                        image = [cv2.resize(img, (orig_width, orig_height))
         
     | 
| 1696 | 
         
            -
                                    for img in image]
         
     | 
| 1697 | 
         
            -
                
         
     | 
| 1698 | 
         
            -
                    return image
         
     | 
| 1699 | 
         
            -
                    """
         
     | 
| 1700 | 
         
            -
                    if "padding" in meta:
         
     | 
| 1701 | 
         
            -
                        pad = meta["padding"]
         
     | 
| 1702 | 
         
            -
                        (_, end_h), (_, end_w) = pad[1:3]
         
     | 
| 1703 | 
         
            -
                        h, w = image.shape[2:]
         
     | 
| 1704 | 
         
            -
                        #print("image shape",image.shape[2:])
         
     | 
| 1705 | 
         
            -
                        unpad_h = h - end_h
         
     | 
| 1706 | 
         
            -
                        unpad_w = w - end_w
         
     | 
| 1707 | 
         
            -
                        image = image[:, :, :unpad_h, :unpad_w]
         
     | 
| 1708 | 
         
            -
                    image = np.clip(image / 2 + 0.5, 0, 1)
         
     | 
| 1709 | 
         
            -
                    image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)
         
     | 
| 1710 | 
         
            -
             
     | 
| 1711 | 
         
            -
                       
         
     | 
| 1712 | 
         
            -
             
     | 
| 1713 | 
         
            -
                    if "src_height" in meta:
         
     | 
| 1714 | 
         
            -
                        orig_height, orig_width = meta["src_height"], meta["src_width"]
         
     | 
| 1715 | 
         
            -
                        image = cv2.resize(image, (orig_width, orig_height))
         
     | 
| 1716 | 
         
            -
                                    
         
     | 
| 1717 | 
         
            -
                    return image
         
     | 
| 1718 | 
         
            -
             
     | 
| 1719 | 
         
            -
                    
         
     | 
| 1720 | 
         
            -
                                  #image = (image / 2 + 0.5).clip(0, 1)
         
     | 
| 1721 | 
         
            -
                    #image = (image[0].transpose(1, 2, 0)[:, :, ::-1] * 255).astype(np.uint8)   
         
     | 
| 1722 | 
         
            -
             
     | 
| 1723 | 
         
            -
             
     | 
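The clip-and-scale step above maps the VAE output from [-1, 1] to [0, 255] and flips RGB to BGR for OpenCV. A minimal worked check of that arithmetic, using a hypothetical one-channel 2x2 tensor (so the channel flip is a no-op and is omitted):

    import numpy as np

    # NCHW tensor with values in [-1, 1], as the VAE decoder emits them.
    decoded = np.array([[[[-1.0, 0.0], [0.5, 1.0]]]], dtype=np.float32)
    out = np.clip(decoded / 2 + 0.5, 0, 1)                    # -> [0, 1]
    out = (out[0].transpose(1, 2, 0) * 255).astype(np.uint8)  # -> HWC uint8
    print(out.ravel())  # [  0 127 191 255]
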
    def get_timesteps(self, num_inference_steps: int, strength: float, scheduler):
        """
        Helper function for getting scheduler timesteps for generation.
        In case of image-to-image generation, it updates the number of steps according to strength.

        Parameters:
            num_inference_steps (int):
                number of inference steps for generation
            strength (float):
                value between 0.0 and 1.0 that controls the amount of noise added to the
                input image. Values approaching 1.0 allow for lots of variation but will
                also produce images that are not semantically consistent with the input.
        """
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

    def prepare_image(
        self,
        image,
        width,
        height,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, np.ndarray):
            if isinstance(image, PIL.Image.Image):
                image = [image]

            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
                    image_ = np.array(image_)[None, :]
                    images.append(image_)

                image = np.concatenate(images, axis=0)
                image = image.astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
            elif isinstance(image[0], np.ndarray):
                image = np.concatenate(image, axis=0)

        if do_classifier_free_guidance and not guess_mode:
            image = np.concatenate([image] * 2)

        return image


def print_npu_turbo_art():
    random_number = random.randint(1, 3)

    if random_number == 1:
        print("                                                                                                                      ")
        print("      ___           ___         ___                                ___           ___                         ___      ")
        print("     /\  \         /\  \       /\  \                              /\  \         /\  \         _____         /\  \     ")
        print("     \:\  \       /::\  \      \:\  \                ___          \:\  \       /::\  \       /::\  \       /::\  \    ")
        print("      \:\  \     /:/\:\__\      \:\  \              /\__\          \:\  \     /:/\:\__\     /:/\:\  \     /:/\:\  \   ")
        print("  _____\:\  \   /:/ /:/  /  ___  \:\  \            /:/  /      ___  \:\  \   /:/ /:/  /    /:/ /::\__\   /:/  \:\  \  ")
        print(" /::::::::\__\ /:/_/:/  /  /\  \  \:\__\          /:/__/      /\  \  \:\__\ /:/_/:/__/___ /:/_/:/\:|__| /:/__/ \:\__\ ")
        print(" \:\~~\~~\/__/ \:\/:/  /   \:\  \ /:/  /         /::\  \      \:\  \ /:/  / \:\/:::::/  / \:\/:/ /:/  / \:\  \ /:/  / ")
        print("  \:\  \        \::/__/     \:\  /:/  /         /:/\:\  \      \:\  /:/  /   \::/~~/~~~~   \::/_/:/  /   \:\  /:/  /  ")
        print("   \:\  \        \:\  \      \:\/:/  /          \/__\:\  \      \:\/:/  /     \:\~~\        \:\/:/  /     \:\/:/  /   ")
        print("    \:\__\        \:\__\      \::/  /                \:\__\      \::/  /       \:\__\        \::/  /       \::/  /    ")
        print("     \/__/         \/__/       \/__/                  \/__/       \/__/         \/__/         \/__/         \/__/     ")
        print("                                                                                                                      ")
    elif random_number == 2:
        print(" _   _   ____    _   _     _____   _   _   ____    ____     ___  ")
        print("| \ | | |  _ \  | | | |   |_   _| | | | | |  _ \  | __ )   / _ \ ")
        print("|  \| | | |_) | | | | |     | |   | | | | | |_) | |  _ \  | | | |")
        print("| |\  | |  __/  | |_| |     | |   | |_| | |  _ <  | |_) | | |_| |")
        print("|_| \_| |_|      \___/      |_|    \___/  |_| \_\ |____/   \___/ ")
        print("                                                                 ")
    else:
        print("")
        print("    )   (                                 (                )   ")
        print(" ( /(   )\ )              *   )           )\ )     (    ( /(   ")
        print(" )\()) (()/(      (     ` )  /(      (   (()/(   ( )\   )\())  ")
        print("((_)\   /(_))     )\     ( )(_))     )\   /(_))  )((_) ((_)\   ")
        print(" _((_) (_))    _ ((_)   (_(_())   _ ((_) (_))   ((_)_    ((_)  ")
        print("| \| | | _ \  | | | |   |_   _|  | | | | | _ \   | _ )  / _ \  ")
        print("| .` | |  _/  | |_| |     | |    | |_| | |   /   | _ \ | (_) | ")
        print("|_|\_| |_|     \___/      |_|     \___/  |_|_\   |___/  \___/  ")
        print("                                                               ")
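The strength logic in get_timesteps above is the standard img2img truncation: higher strength keeps more denoising steps. A standalone sketch of the same arithmetic with illustrative values (a real scheduler supplies the timesteps array):

    # strength = 0.5 with 4 steps keeps only the last 2 scheduler timesteps.
    num_inference_steps = 4
    strength = 0.5
    scheduler_timesteps = [999, 666, 333, 0]  # stand-in for scheduler.timesteps

    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 2
    t_start = max(num_inference_steps - init_timestep, 0)                          # 2
    print(scheduler_timesteps[t_start:], num_inference_steps - t_start)  # [333, 0] 2
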
        src/backend/pipelines/lcm.py
    DELETED
    
@@ -1,122 +0,0 @@
from constants import LCM_DEFAULT_MODEL
from diffusers import (
    AutoencoderTiny,
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    DiffusionPipeline,
    LCMScheduler,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLImg2ImgPipeline,
    UNet2DConditionModel,
)
import torch
from backend.tiny_decoder import get_tiny_decoder_vae_model
from typing import Any
import pathlib


def _get_lcm_pipeline_from_base_model(
    lcm_model_id: str,
    base_model_id: str,
    use_local_model: bool,
):
    unet = UNet2DConditionModel.from_pretrained(
        lcm_model_id,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline = DiffusionPipeline.from_pretrained(
        base_model_id,
        unet=unet,
        torch_dtype=torch.float32,
        local_files_only=use_local_model,
        resume_download=True,
    )
    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
    return pipeline


def load_taesd(
    pipeline: Any,
    use_local_model: bool = False,
    torch_data_type: torch.dtype = torch.float32,
):
    vae_model = get_tiny_decoder_vae_model(pipeline.__class__.__name__)
    pipeline.vae = AutoencoderTiny.from_pretrained(
        vae_model,
        torch_dtype=torch_data_type,
        local_files_only=use_local_model,
    )


def get_lcm_model_pipeline(
    model_id: str = LCM_DEFAULT_MODEL,
    use_local_model: bool = False,
    pipeline_args={},
):
    pipeline = None
    if model_id == "latent-consistency/lcm-sdxl":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "stabilityai/stable-diffusion-xl-base-1.0",
            use_local_model,
        )
    elif model_id == "latent-consistency/lcm-ssd-1b":
        pipeline = _get_lcm_pipeline_from_base_model(
            model_id,
            "segmind/SSD-1B",
            use_local_model,
        )
    elif pathlib.Path(model_id).suffix == ".safetensors":
        # When loading a .safetensors model, the pipeline has to be created
        # with StableDiffusionPipeline() since it's the only class that
        # defines the method from_single_file()
        dummy_pipeline = StableDiffusionPipeline.from_single_file(
            model_id,
            safety_checker=None,
            run_safety_checker=False,
            load_safety_checker=False,
            local_files_only=use_local_model,
            use_safetensors=True,
        )
        if "lcm" in model_id.lower():
            dummy_pipeline.scheduler = LCMScheduler.from_config(
                dummy_pipeline.scheduler.config
            )

        pipeline = AutoPipelineForText2Image.from_pipe(
            dummy_pipeline,
            **pipeline_args,
        )
        del dummy_pipeline
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_id,
            local_files_only=use_local_model,
            **pipeline_args,
        )

    return pipeline


def get_image_to_image_pipeline(pipeline: Any) -> Any:
    components = pipeline.components
    pipeline_class = pipeline.__class__.__name__
    if pipeline_class in ("LatentConsistencyModelPipeline", "StableDiffusionPipeline"):
        return StableDiffusionImg2ImgPipeline(**components)
    elif pipeline_class == "StableDiffusionControlNetPipeline":
        return AutoPipelineForImage2Image.from_pipe(pipeline)
    elif pipeline_class == "StableDiffusionXLPipeline":
        return StableDiffusionXLImg2ImgPipeline(**components)
    else:
        raise Exception(f"Unknown pipeline {pipeline_class}")
     | 
src/backend/pipelines/lcm_lora.py
DELETED
@@ -1,81 +0,0 @@
-import pathlib
-from os import path
-
-import torch
-from diffusers import (
-    AutoPipelineForText2Image,
-    LCMScheduler,
-    StableDiffusionPipeline,
-)
-
-
-def load_lcm_weights(
-    pipeline,
-    use_local_model,
-    lcm_lora_id,
-):
-    kwargs = {
-        "local_files_only": use_local_model,
-        "weight_name": "pytorch_lora_weights.safetensors",
-    }
-    pipeline.load_lora_weights(
-        lcm_lora_id,
-        **kwargs,
-        adapter_name="lcm",
-    )
-
-
-def get_lcm_lora_pipeline(
-    base_model_id: str,
-    lcm_lora_id: str,
-    use_local_model: bool,
-    torch_data_type: torch.dtype,
-    pipeline_args={},
-):
-    if pathlib.Path(base_model_id).suffix == ".safetensors":
-        # SD 1.5 models only
-        # When loading a .safetensors model, the pipeline has to be created
-        # with StableDiffusionPipeline() since it's the only class that
-        # defines the method from_single_file(); afterwards a new pipeline
-        # is created using AutoPipelineForText2Image() for ControlNet
-        # support, in case ControlNet is enabled
-        if not path.exists(base_model_id):
-            raise FileNotFoundError(
-                f"Model file not found, please check your model path: {base_model_id}"
-            )
-        print("Using single file Safetensors model (Supported models - SD 1.5 models)")
-
-        dummy_pipeline = StableDiffusionPipeline.from_single_file(
-            base_model_id,
-            torch_dtype=torch_data_type,
-            safety_checker=None,
-            local_files_only=use_local_model,
-            use_safetensors=True,
-        )
-        pipeline = AutoPipelineForText2Image.from_pipe(
-            dummy_pipeline,
-            **pipeline_args,
-        )
-        del dummy_pipeline
-    else:
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            base_model_id,
-            torch_dtype=torch_data_type,
-            local_files_only=use_local_model,
-            **pipeline_args,
-        )
-
-    load_lcm_weights(
-        pipeline,
-        use_local_model,
-        lcm_lora_id,
-    )
-    # Always fuse LCM-LoRA
-    # pipeline.fuse_lora()
-
-    if "lcm" in lcm_lora_id.lower() or "hypersd" in lcm_lora_id.lower():
-        print("LCM LoRA model detected so using recommended LCMScheduler")
-        pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
-
-    # pipeline.unet.to(memory_format=torch.channels_last)
-    return pipeline
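A hedged usage sketch of the factory above, assuming the module is importable (both Hub ids are illustrative examples, not pinned by this repo):

import torch

pipe = get_lcm_lora_pipeline(
    base_model_id="Lykon/dreamshaper-8",  # or a local SD 1.5 .safetensors path
    lcm_lora_id="latent-consistency/lcm-lora-sdv1-5",
    use_local_model=False,
    torch_data_type=torch.float32,
)
# LCM-LoRA enables few-step, low-guidance sampling.
image = pipe(
    "a lighthouse at dawn", num_inference_steps=4, guidance_scale=1.0
).images[0]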
         
src/backend/tiny_decoder.py
DELETED
@@ -1,32 +0,0 @@
-from constants import (
-    TAESD_MODEL,
-    TAESDXL_MODEL,
-    TAESD_MODEL_OPENVINO,
-    TAESDXL_MODEL_OPENVINO,
-)
-
-
-def get_tiny_decoder_vae_model(pipeline_class) -> str:
-    print(f"Pipeline class : {pipeline_class}")
-    if (
-        pipeline_class == "LatentConsistencyModelPipeline"
-        or pipeline_class == "StableDiffusionPipeline"
-        or pipeline_class == "StableDiffusionImg2ImgPipeline"
-        or pipeline_class == "StableDiffusionControlNetPipeline"
-        or pipeline_class == "StableDiffusionControlNetImg2ImgPipeline"
-    ):
-        return TAESD_MODEL
-    elif (
-        pipeline_class == "StableDiffusionXLPipeline"
-        or pipeline_class == "StableDiffusionXLImg2ImgPipeline"
-    ):
-        return TAESDXL_MODEL
-    elif (
-        pipeline_class == "OVStableDiffusionPipeline"
-        or pipeline_class == "OVStableDiffusionImg2ImgPipeline"
-    ):
-        return TAESD_MODEL_OPENVINO
-    elif pipeline_class == "OVStableDiffusionXLPipeline":
-        return TAESDXL_MODEL_OPENVINO
-    else:
-        raise Exception("No valid pipeline class found!")
         
src/backend/upscale/aura_sr.py
DELETED
@@ -1,1004 +0,0 @@
-# AuraSR: GAN-based Super-Resolution for real-world, a reproduction of the GigaGAN* paper. Implementation is
-# based on the unofficial lucidrains/gigagan-pytorch repository. Heavily modified from there.
-#
-# https://mingukkang.github.io/GigaGAN/
-from math import log2, ceil
-from functools import partial
-from typing import Any, Optional, List, Iterable
-
-import torch
-from torchvision import transforms
-from PIL import Image
-from torch import nn, einsum, Tensor
-import torch.nn.functional as F
-
-from einops import rearrange, repeat, reduce
-from einops.layers.torch import Rearrange
-from torchvision.utils import save_image
-import math
-
-
-def get_same_padding(size, kernel, dilation, stride):
-    return ((size - 1) * (stride - 1) + dilation * (kernel - 1)) // 2
-
-
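A quick sanity check of the padding helper above for stride-1, dilation-1 convolutions (values illustrative):

assert get_same_padding(size=64, kernel=3, dilation=1, stride=1) == 1  # 3x3 conv -> pad 1
assert get_same_padding(size=64, kernel=7, dilation=1, stride=1) == 3  # 7x7 conv -> pad 3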
-class AdaptiveConv2DMod(nn.Module):
-    def __init__(
-        self,
-        dim,
-        dim_out,
-        kernel,
-        *,
-        demod=True,
-        stride=1,
-        dilation=1,
-        eps=1e-8,
-        num_conv_kernels=1,  # set this to be greater than 1 for adaptive
-    ):
-        super().__init__()
-        self.eps = eps
-
-        self.dim_out = dim_out
-
-        self.kernel = kernel
-        self.stride = stride
-        self.dilation = dilation
-        self.adaptive = num_conv_kernels > 1
-
-        self.weights = nn.Parameter(
-            torch.randn((num_conv_kernels, dim_out, dim, kernel, kernel))
-        )
-
-        self.demod = demod
-
-        nn.init.kaiming_normal_(
-            self.weights, a=0, mode="fan_in", nonlinearity="leaky_relu"
-        )
-
-    def forward(
-        self, fmap, mod: Optional[Tensor] = None, kernel_mod: Optional[Tensor] = None
-    ):
-        """
-        notation
-
-        b - batch
-        n - convs
-        o - output
-        i - input
-        k - kernel
-        """
-
-        b, h = fmap.shape[0], fmap.shape[-2]
-
-        # account for feature map that has been expanded by the scale in the first dimension
-        # due to multiscale inputs and outputs
-
-        if mod.shape[0] != b:
-            mod = repeat(mod, "b ... -> (s b) ...", s=b // mod.shape[0])
-
-        if exists(kernel_mod):
-            kernel_mod_has_el = kernel_mod.numel() > 0
-
-            assert self.adaptive or not kernel_mod_has_el
-
-            if kernel_mod_has_el and kernel_mod.shape[0] != b:
-                kernel_mod = repeat(
-                    kernel_mod, "b ... -> (s b) ...", s=b // kernel_mod.shape[0]
-                )
-
-        # prepare weights for modulation
-
-        weights = self.weights
-
-        if self.adaptive:
-            weights = repeat(weights, "... -> b ...", b=b)
-
-            # determine an adaptive weight and 'select' the kernel to use with softmax
-
-            assert exists(kernel_mod) and kernel_mod.numel() > 0
-
-            kernel_attn = kernel_mod.softmax(dim=-1)
-            kernel_attn = rearrange(kernel_attn, "b n -> b n 1 1 1 1")
-
-            weights = reduce(weights * kernel_attn, "b n ... -> b ...", "sum")
-
-        # do the modulation, demodulation, as done in stylegan2
-
-        mod = rearrange(mod, "b i -> b 1 i 1 1")
-
-        weights = weights * (mod + 1)
-
-        if self.demod:
-            inv_norm = (
-                reduce(weights**2, "b o i k1 k2 -> b o 1 1 1", "sum")
-                .clamp(min=self.eps)
-                .rsqrt()
-            )
-            weights = weights * inv_norm
-
-        fmap = rearrange(fmap, "b c h w -> 1 (b c) h w")
-
-        weights = rearrange(weights, "b o ... -> (b o) ...")
-
-        padding = get_same_padding(h, self.kernel, self.dilation, self.stride)
-        fmap = F.conv2d(fmap, weights, padding=padding, groups=b)
-
-        return rearrange(fmap, "1 (b o) ... -> b o ...", b=b)
-
-
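The forward pass above is StyleGAN2-style weight (de)modulation: the style vector scales the conv weights per sample, demodulation renormalizes each output filter, and the batch is folded into a grouped conv. A standalone sketch of just the modulate/demodulate step (shapes illustrative):

import torch

w = torch.randn(8, 4, 3, 3)           # (out, in, k, k)
s = torch.randn(4)                    # one style scale per input channel
w_mod = w * (s.view(1, 4, 1, 1) + 1)  # modulate
inv_norm = w_mod.pow(2).sum(dim=(1, 2, 3), keepdim=True).clamp(min=1e-8).rsqrt()
w_demod = w_mod * inv_norm            # demodulate: unit norm per output filter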
         
            -
            class Attend(nn.Module):
         
     | 
| 130 | 
         
            -
                def __init__(self, dropout=0.0, flash=False):
         
     | 
| 131 | 
         
            -
                    super().__init__()
         
     | 
| 132 | 
         
            -
                    self.dropout = dropout
         
     | 
| 133 | 
         
            -
                    self.attn_dropout = nn.Dropout(dropout)
         
     | 
| 134 | 
         
            -
                    self.scale = nn.Parameter(torch.randn(1))
         
     | 
| 135 | 
         
            -
                    self.flash = flash
         
     | 
| 136 | 
         
            -
             
     | 
| 137 | 
         
            -
                def flash_attn(self, q, k, v):
         
     | 
| 138 | 
         
            -
                    q, k, v = map(lambda t: t.contiguous(), (q, k, v))
         
     | 
| 139 | 
         
            -
                    out = F.scaled_dot_product_attention(
         
     | 
| 140 | 
         
            -
                        q, k, v, dropout_p=self.dropout if self.training else 0.0
         
     | 
| 141 | 
         
            -
                    )
         
     | 
| 142 | 
         
            -
                    return out
         
     | 
| 143 | 
         
            -
             
     | 
| 144 | 
         
            -
                def forward(self, q, k, v):
         
     | 
| 145 | 
         
            -
                    if self.flash:
         
     | 
| 146 | 
         
            -
                        return self.flash_attn(q, k, v)
         
     | 
| 147 | 
         
            -
             
     | 
| 148 | 
         
            -
                    scale = q.shape[-1] ** -0.5
         
     | 
| 149 | 
         
            -
             
     | 
| 150 | 
         
            -
                    # similarity
         
     | 
| 151 | 
         
            -
                    sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale
         
     | 
| 152 | 
         
            -
             
     | 
| 153 | 
         
            -
                    # attention
         
     | 
| 154 | 
         
            -
                    attn = sim.softmax(dim=-1)
         
     | 
| 155 | 
         
            -
                    attn = self.attn_dropout(attn)
         
     | 
| 156 | 
         
            -
             
     | 
| 157 | 
         
            -
                    # aggregate values
         
     | 
| 158 | 
         
            -
                    out = einsum("b h i j, b h j d -> b h i d", attn, v)
         
     | 
| 159 | 
         
            -
             
     | 
| 160 | 
         
            -
                    return out
         
     | 
| 161 | 
         
            -
             
     | 
| 162 | 
         
            -
             
     | 
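Attend routes through torch.nn.functional.scaled_dot_product_attention when flash=True and otherwise falls back to the explicit softmax path; both compute the same attention. A quick equivalence check (shapes illustrative):

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))  # (batch, heads, seq, dim)
fast = F.scaled_dot_product_attention(q, k, v)
slow = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim=-1) @ v
assert torch.allclose(fast, slow, atol=1e-5)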
-def exists(x):
-    return x is not None
-
-
-def default(val, d):
-    if exists(val):
-        return val
-    return d() if callable(d) else d
-
-
-def cast_tuple(t, length=1):
-    if isinstance(t, tuple):
-        return t
-    return (t,) * length
-
-
-def identity(t, *args, **kwargs):
-    return t
-
-
-def is_power_of_two(n):
-    return log2(n).is_integer()
-
-
-def null_iterator():
-    while True:
-        yield None
-
-
-def Downsample(dim, dim_out=None):
-    return nn.Sequential(
-        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
-        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
-    )
-
-
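Downsample above is a pixel-unshuffle: each 2x2 spatial patch moves into the channel axis, then a 1x1 conv mixes the 4x channels back down, halving resolution without pooling. Shape sketch (sizes illustrative):

import torch

down = Downsample(dim=64)
x = torch.randn(1, 64, 32, 32)
assert down(x).shape == (1, 64, 16, 16)  # spatial halved, channel count preserved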
-class RMSNorm(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-        self.eps = 1e-4
-
-    def forward(self, x):
-        return F.normalize(x, dim=1) * self.g * (x.shape[1] ** 0.5)
-
-
-# building block modules
-
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8, num_conv_kernels=0):
-        super().__init__()
-        self.proj = AdaptiveConv2DMod(
-            dim, dim_out, kernel=3, num_conv_kernels=num_conv_kernels
-        )
-        self.kernel = 3
-        self.dilation = 1
-        self.stride = 1
-
-        self.act = nn.SiLU()
-
-    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
-        conv_mods_iter = default(conv_mods_iter, null_iterator())
-
-        x = self.proj(x, mod=next(conv_mods_iter), kernel_mod=next(conv_mods_iter))
-
-        x = self.act(x)
-        return x
-
-
-class ResnetBlock(nn.Module):
-    def __init__(
-        self, dim, dim_out, *, groups=8, num_conv_kernels=0, style_dims: List = []
-    ):
-        super().__init__()
-        style_dims.extend([dim, num_conv_kernels, dim_out, num_conv_kernels])
-
-        self.block1 = Block(
-            dim, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
-        )
-        self.block2 = Block(
-            dim_out, dim_out, groups=groups, num_conv_kernels=num_conv_kernels
-        )
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, conv_mods_iter: Optional[Iterable] = None):
-        h = self.block1(x, conv_mods_iter=conv_mods_iter)
-        h = self.block2(h, conv_mods_iter=conv_mods_iter)
-
-        return h + self.res_conv(x)
-
-
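The RMSNorm above normalizes over the channel axis: F.normalize divides by the L2 norm, and the sqrt(C) factor turns that into division by the root-mean-square, with g a learned per-channel gain. Equivalence sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 8, 8)
rms = x.pow(2).mean(dim=1, keepdim=True).sqrt()
assert torch.allclose(x / rms, F.normalize(x, dim=1) * 64**0.5, atol=1e-5)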
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32):
-        super().__init__()
-        self.scale = dim_head**-0.5
-        self.heads = heads
-        hidden_dim = dim_head * heads
-
-        self.norm = RMSNorm(dim)
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-
-        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), RMSNorm(dim))
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-
-        x = self.norm(x)
-
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
-        )
-
-        q = q.softmax(dim=-2)
-        k = k.softmax(dim=-1)
-
-        q = q * self.scale
-
-        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
-
-        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
-        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
-        return self.to_out(out)
-
-
-class Attention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32, flash=False):
-        super().__init__()
-        self.heads = heads
-        hidden_dim = dim_head * heads
-
-        self.norm = RMSNorm(dim)
-
-        self.attend = Attend(flash=flash)
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        x = self.norm(x)
-        qkv = self.to_qkv(x).chunk(3, dim=1)
-
-        q, k, v = map(
-            lambda t: rearrange(t, "b (h c) x y -> b h (x y) c", h=self.heads), qkv
-        )
-
-        out = self.attend(q, k, v)
-        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
-
-        return self.to_out(out)
-
-
-# feedforward
-def FeedForward(dim, mult=4):
-    return nn.Sequential(
-        RMSNorm(dim),
-        nn.Conv2d(dim, dim * mult, 1),
-        nn.GELU(),
-        nn.Conv2d(dim * mult, dim, 1),
-    )
-
-
-# transformers
-class Transformer(nn.Module):
-    def __init__(self, dim, dim_head=64, heads=8, depth=1, flash_attn=True, ff_mult=4):
-        super().__init__()
-        self.layers = nn.ModuleList([])
-
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        Attention(
-                            dim=dim, dim_head=dim_head, heads=heads, flash=flash_attn
-                        ),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-    def forward(self, x):
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
-
-        return x
-
-
-class LinearTransformer(nn.Module):
-    def __init__(self, dim, dim_head=64, heads=8, depth=1, ff_mult=4):
-        super().__init__()
-        self.layers = nn.ModuleList([])
-
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        LinearAttention(dim=dim, dim_head=dim_head, heads=heads),
-                        FeedForward(dim=dim, mult=ff_mult),
-                    ]
-                )
-            )
-
-    def forward(self, x):
-        for attn, ff in self.layers:
-            x = attn(x) + x
-            x = ff(x) + x
-
-        return x
-
-
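LinearAttention never forms the pixels-by-pixels attention matrix: softmaxed keys aggregate values into a small d x d context, which the queries then read out, so cost grows linearly with pixel count. That is why the network uses it at fine scales and reserves full Attention for coarse ones. Sketch of the factorization, matching the einsums above (shapes illustrative):

import torch

q = torch.randn(1, 4, 32, 4096).softmax(dim=-2)  # (b, heads, dim, pixels)
k = torch.randn(1, 4, 32, 4096).softmax(dim=-1)
v = torch.randn(1, 4, 32, 4096)
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)   # (b, h, 32, 32)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)  # linear in n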
-class NearestNeighborhoodUpsample(nn.Module):
-    def __init__(self, dim, dim_out=None):
-        super().__init__()
-        dim_out = default(dim_out, dim)
-        self.conv = nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
-
-    def forward(self, x):
-
-        if x.shape[0] >= 64:
-            x = x.contiguous()
-
-        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
-        x = self.conv(x)
-
-        return x
-
-
-class EqualLinear(nn.Module):
-    def __init__(self, dim, dim_out, lr_mul=1, bias=True):
-        super().__init__()
-        self.weight = nn.Parameter(torch.randn(dim_out, dim))
-        if bias:
-            self.bias = nn.Parameter(torch.zeros(dim_out))
-
-        self.lr_mul = lr_mul
-
-    def forward(self, input):
-        return F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
-
-
-class StyleGanNetwork(nn.Module):
-    def __init__(self, dim_in=128, dim_out=512, depth=8, lr_mul=0.1, dim_text_latent=0):
-        super().__init__()
-        self.dim_in = dim_in
-        self.dim_out = dim_out
-        self.dim_text_latent = dim_text_latent
-
-        layers = []
-        for i in range(depth):
-            is_first = i == 0
-
-            if is_first:
-                dim_in_layer = dim_in + dim_text_latent
-            else:
-                dim_in_layer = dim_out
-
-            dim_out_layer = dim_out
-
-            layers.extend(
-                [EqualLinear(dim_in_layer, dim_out_layer, lr_mul), nn.LeakyReLU(0.2)]
-            )
-
-        self.net = nn.Sequential(*layers)
-
-    def forward(self, x, text_latent=None):
-        x = F.normalize(x, dim=1)
-        if self.dim_text_latent > 0:
-            assert exists(text_latent)
-            x = torch.cat((x, text_latent), dim=-1)
-        return self.net(x)
-
-
            class UnetUpsampler(torch.nn.Module):
         
     | 
| 438 | 
         
            -
             
     | 
| 439 | 
         
            -
                def __init__(
         
     | 
| 440 | 
         
            -
                    self,
         
     | 
| 441 | 
         
            -
                    dim: int,
         
     | 
| 442 | 
         
            -
                    *,
         
     | 
| 443 | 
         
            -
                    image_size: int,
         
     | 
| 444 | 
         
            -
                    input_image_size: int,
         
     | 
| 445 | 
         
            -
                    init_dim: Optional[int] = None,
         
     | 
| 446 | 
         
            -
                    out_dim: Optional[int] = None,
         
     | 
| 447 | 
         
            -
                    style_network: Optional[dict] = None,
         
     | 
| 448 | 
         
            -
                    up_dim_mults: tuple = (1, 2, 4, 8, 16),
         
     | 
| 449 | 
         
            -
                    down_dim_mults: tuple = (4, 8, 16),
         
     | 
| 450 | 
         
            -
                    channels: int = 3,
         
     | 
| 451 | 
         
            -
                    resnet_block_groups: int = 8,
         
     | 
| 452 | 
         
            -
                    full_attn: tuple = (False, False, False, True, True),
         
     | 
| 453 | 
         
            -
                    flash_attn: bool = True,
         
     | 
| 454 | 
         
            -
                    self_attn_dim_head: int = 64,
         
     | 
| 455 | 
         
            -
                    self_attn_heads: int = 8,
         
     | 
| 456 | 
         
            -
                    attn_depths: tuple = (2, 2, 2, 2, 4),
         
     | 
| 457 | 
         
            -
                    mid_attn_depth: int = 4,
         
     | 
| 458 | 
         
            -
                    num_conv_kernels: int = 4,
         
     | 
| 459 | 
         
            -
                    resize_mode: str = "bilinear",
         
     | 
| 460 | 
         
            -
                    unconditional: bool = True,
         
     | 
| 461 | 
         
            -
                    skip_connect_scale: Optional[float] = None,
         
     | 
| 462 | 
         
            -
                ):
         
     | 
| 463 | 
         
            -
                    super().__init__()
         
     | 
| 464 | 
         
            -
                    self.style_network = style_network = StyleGanNetwork(**style_network)
         
     | 
| 465 | 
         
            -
                    self.unconditional = unconditional
         
     | 
| 466 | 
         
            -
                    assert not (
         
     | 
| 467 | 
         
            -
                        unconditional
         
     | 
| 468 | 
         
            -
                        and exists(style_network)
         
     | 
| 469 | 
         
            -
                        and style_network.dim_text_latent > 0
         
     | 
| 470 | 
         
            -
                    )
         
     | 
| 471 | 
         
            -
             
     | 
| 472 | 
         
            -
                    assert is_power_of_two(image_size) and is_power_of_two(
         
     | 
| 473 | 
         
            -
                        input_image_size
         
     | 
| 474 | 
         
            -
                    ), "both output image size and input image size must be power of 2"
         
     | 
| 475 | 
         
            -
                    assert (
         
     | 
| 476 | 
         
            -
                        input_image_size < image_size
         
     | 
| 477 | 
         
            -
                    ), "input image size must be smaller than the output image size, thus upsampling"
         
     | 
| 478 | 
         
            -
             
     | 
| 479 | 
         
            -
                    self.image_size = image_size
         
     | 
| 480 | 
         
            -
                    self.input_image_size = input_image_size
         
     | 
| 481 | 
         
            -
             
     | 
| 482 | 
         
            -
                    style_embed_split_dims = []
         
     | 
| 483 | 
         
            -
             
     | 
| 484 | 
         
            -
                    self.channels = channels
         
     | 
| 485 | 
         
            -
                    input_channels = channels
         
     | 
| 486 | 
         
            -
             
     | 
| 487 | 
         
            -
                    init_dim = default(init_dim, dim)
         
     | 
| 488 | 
         
            -
             
     | 
| 489 | 
         
            -
                    up_dims = [init_dim, *map(lambda m: dim * m, up_dim_mults)]
         
     | 
| 490 | 
         
            -
                    init_down_dim = up_dims[len(up_dim_mults) - len(down_dim_mults)]
         
     | 
| 491 | 
         
            -
                    down_dims = [init_down_dim, *map(lambda m: dim * m, down_dim_mults)]
         
     | 
| 492 | 
         
            -
                    self.init_conv = nn.Conv2d(input_channels, init_down_dim, 7, padding=3)
         
     | 
| 493 | 
         
            -
             
     | 
| 494 | 
         
            -
                    up_in_out = list(zip(up_dims[:-1], up_dims[1:]))
         
     | 
| 495 | 
         
            -
                    down_in_out = list(zip(down_dims[:-1], down_dims[1:]))
         
     | 
| 496 | 
         
            -
             
     | 
| 497 | 
         
            -
                    block_klass = partial(
         
     | 
| 498 | 
         
            -
                        ResnetBlock,
         
     | 
| 499 | 
         
            -
                        groups=resnet_block_groups,
         
     | 
| 500 | 
         
            -
                        num_conv_kernels=num_conv_kernels,
         
     | 
| 501 | 
         
            -
                        style_dims=style_embed_split_dims,
         
     | 
| 502 | 
         
            -
                    )
         
     | 
| 503 | 
         
            -
             
     | 
| 504 | 
         
            -
                    FullAttention = partial(Transformer, flash_attn=flash_attn)
         
     | 
| 505 | 
         
            -
                    *_, mid_dim = up_dims
         
     | 
| 506 | 
         
            -
             
     | 
| 507 | 
         
            -
                    self.skip_connect_scale = default(skip_connect_scale, 2**-0.5)
         
     | 
| 508 | 
         
            -
             
     | 
| 509 | 
         
            -
                    self.downs = nn.ModuleList([])
         
     | 
| 510 | 
         
            -
                    self.ups = nn.ModuleList([])
         
     | 
| 511 | 
         
            -
             
     | 
| 512 | 
         
            -
                    block_count = 6
         
     | 
| 513 | 
         
            -
             
     | 
| 514 | 
         
            -
                    for ind, (
         
     | 
| 515 | 
         
            -
                        (dim_in, dim_out),
         
     | 
| 516 | 
         
            -
                        layer_full_attn,
         
     | 
| 517 | 
         
            -
                        layer_attn_depth,
         
     | 
| 518 | 
         
            -
                    ) in enumerate(zip(down_in_out, full_attn, attn_depths)):
         
     | 
| 519 | 
         
            -
                        attn_klass = FullAttention if layer_full_attn else LinearTransformer
         
     | 
| 520 | 
         
            -
             
     | 
| 521 | 
         
            -
                        blocks = []
         
     | 
| 522 | 
         
            -
                        for i in range(block_count):
         
     | 
| 523 | 
         
            -
                            blocks.append(block_klass(dim_in, dim_in))
         
     | 
| 524 | 
         
            -
             
     | 
| 525 | 
         
            -
                        self.downs.append(
         
     | 
| 526 | 
         
            -
                            nn.ModuleList(
         
     | 
| 527 | 
         
            -
                                [
         
     | 
| 528 | 
         
            -
                                    nn.ModuleList(blocks),
         
     | 
| 529 | 
         
            -
                                    nn.ModuleList(
         
     | 
| 530 | 
         
            -
                                        [
         
     | 
| 531 | 
         
            -
                                            (
         
     | 
| 532 | 
         
            -
                                                attn_klass(
         
     | 
| 533 | 
         
            -
                                                    dim_in,
         
     | 
| 534 | 
         
            -
                                                    dim_head=self_attn_dim_head,
         
     | 
| 535 | 
         
            -
                                                    heads=self_attn_heads,
         
     | 
| 536 | 
         
            -
                                                    depth=layer_attn_depth,
         
     | 
| 537 | 
         
            -
                                                )
         
     | 
| 538 | 
         
            -
                                                if layer_full_attn
         
     | 
| 539 | 
         
            -
                                                else None
         
     | 
| 540 | 
         
            -
                                            ),
         
     | 
| 541 | 
         
            -
                                            nn.Conv2d(
         
     | 
| 542 | 
         
            -
                                                dim_in, dim_out, kernel_size=3, stride=2, padding=1
         
     | 
| 543 | 
         
            -
                                            ),
         
     | 
| 544 | 
         
            -
                                        ]
         
     | 
| 545 | 
         
            -
                                    ),
         
     | 
| 546 | 
         
            -
                                ]
         
     | 
| 547 | 
         
            -
                            )
         
     | 
| 548 | 
         
            -
                        )
         
     | 
| 549 | 
         
            -
             
     | 
| 550 | 
         
            -
                    self.mid_block1 = block_klass(mid_dim, mid_dim)
         
     | 
| 551 | 
         
            -
                    self.mid_attn = FullAttention(
         
     | 
| 552 | 
         
            -
                        mid_dim,
         
     | 
| 553 | 
         
            -
                        dim_head=self_attn_dim_head,
         
     | 
| 554 | 
         
            -
                        heads=self_attn_heads,
         
     | 
| 555 | 
         
            -
                        depth=mid_attn_depth,
         
     | 
| 556 | 
         
            -
                    )
         
     | 
| 557 | 
         
            -
                    self.mid_block2 = block_klass(mid_dim, mid_dim)
         
     | 
| 558 | 
         
            -
             
     | 
| 559 | 
         
            -
                    *_, last_dim = up_dims
         
     | 
| 560 | 
         
            -
             
     | 
| 561 | 
         
            -
                    for ind, (
         
     | 
| 562 | 
         
            -
                        (dim_in, dim_out),
         
     | 
| 563 | 
         
            -
                        layer_full_attn,
         
     | 
| 564 | 
         
            -
                        layer_attn_depth,
         
     | 
| 565 | 
         
            -
                    ) in enumerate(
         
     | 
| 566 | 
         
            -
                        zip(
         
     | 
| 567 | 
         
            -
                            reversed(up_in_out),
         
     | 
| 568 | 
         
            -
                            reversed(full_attn),
         
     | 
| 569 | 
         
            -
                            reversed(attn_depths),
         
     | 
| 570 | 
         
            -
                        )
         
     | 
| 571 | 
         
            -
                    ):
         
     | 
| 572 | 
         
            -
                        attn_klass = FullAttention if layer_full_attn else LinearTransformer
         
     | 
| 573 | 
         
            -
             
     | 
| 574 | 
         
            -
                        blocks = []
         
     | 
| 575 | 
         
            -
                        input_dim = dim_in * 2 if ind < len(down_in_out) else dim_in
         
     | 
| 576 | 
         
            -
                        for i in range(block_count):
         
     | 
| 577 | 
         
            -
                            blocks.append(block_klass(input_dim, dim_in))
         
     | 
| 578 | 
         
            -
             
     | 
| 579 | 
         
            -
                        self.ups.append(
         
     | 
| 580 | 
         
            -
                            nn.ModuleList(
         
     | 
| 581 | 
         
            -
                                [
         
     | 
| 582 | 
         
            -
                                    nn.ModuleList(blocks),
         
     | 
| 583 | 
         
            -
                                    nn.ModuleList(
         
     | 
| 584 | 
         
            -
                                        [
         
     | 
| 585 | 
         
            -
                                            NearestNeighborhoodUpsample(
         
     | 
| 586 | 
         
            -
                                                last_dim if ind == 0 else dim_out,
         
     | 
| 587 | 
         
            -
                                                dim_in,
         
     | 
| 588 | 
         
            -
                                            ),
         
     | 
| 589 | 
         
            -
                                            (
         
     | 
| 590 | 
         
            -
                                                attn_klass(
         
     | 
| 591 | 
         
            -
                                                    dim_in,
         
     | 
| 592 | 
         
            -
                                                    dim_head=self_attn_dim_head,
         
     | 
| 593 | 
         
            -
                                                    heads=self_attn_heads,
         
     | 
| 594 | 
         
            -
                                                    depth=layer_attn_depth,
         
     | 
| 595 | 
         
            -
                                                )
         
     | 
| 596 | 
         
            -
                                                if layer_full_attn
         
     | 
| 597 | 
         
            -
                                                else None
         
     | 
| 598 | 
         
            -
                                            ),
         
     | 
| 599 | 
         
            -
                                        ]
         
     | 
| 600 | 
         
            -
                                    ),
         
     | 
| 601 | 
         
            -
                                ]
         
     | 
| 602 | 
         
            -
                            )
         
     | 
| 603 | 
         
            -
                        )
         
     | 
| 604 | 
         
            -
             
     | 
| 605 | 
         
            -
                    self.out_dim = default(out_dim, channels)
         
     | 
| 606 | 
         
            -
                    self.final_res_block = block_klass(dim, dim)
         
     | 
| 607 | 
         
            -
                    self.final_to_rgb = nn.Conv2d(dim, channels, 1)
         
     | 
| 608 | 
         
            -
                    self.resize_mode = resize_mode
         
     | 
| 609 | 
         
            -
                    self.style_to_conv_modulations = nn.Linear(
         
     | 
| 610 | 
         
            -
                        style_network.dim_out, sum(style_embed_split_dims)
         
     | 
| 611 | 
         
            -
                    )
         
     | 
| 612 | 
         
            -
                    self.style_embed_split_dims = style_embed_split_dims
         
     | 
| 613 | 
         
            -
             
     | 
| 614 | 
         
            -
                @property
         
     | 
| 615 | 
         
            -
                def allowable_rgb_resolutions(self):
         
     | 
| 616 | 
         
            -
                    input_res_base = int(log2(self.input_image_size))
         
     | 
| 617 | 
         
            -
                    output_res_base = int(log2(self.image_size))
         
     | 
| 618 | 
         
            -
                    allowed_rgb_res_base = list(range(input_res_base, output_res_base))
         
     | 
| 619 | 
         
            -
                    return [*map(lambda p: 2**p, allowed_rgb_res_base)]
         
     | 
| 620 | 
         
            -
             
     | 
| 621 | 
         
            -
                @property
         
     | 
| 622 | 
         
            -
                def device(self):
         
     | 
| 623 | 
         
            -
                    return next(self.parameters()).device
         
     | 
| 624 | 
         
            -
             
     | 
| 625 | 
         
            -
                @property
         
     | 
| 626 | 
         
            -
                def total_params(self):
         
     | 
| 627 | 
         
            -
                    return sum([p.numel() for p in self.parameters()])
         
     | 
| 628 | 
         
            -
             
     | 
| 629 | 
         
            -
                def resize_image_to(self, x, size):
         
     | 
| 630 | 
         
            -
                    return F.interpolate(x, (size, size), mode=self.resize_mode)
         
     | 
| 631 | 
         
            -
             
     | 
| 632 | 
         
            -
                def forward(
         
     | 
| 633 | 
         
            -
                    self,
         
     | 
| 634 | 
         
            -
                    lowres_image: torch.Tensor,
         
     | 
| 635 | 
         
            -
                    styles: Optional[torch.Tensor] = None,
         
     | 
| 636 | 
         
            -
                    noise: Optional[torch.Tensor] = None,
         
     | 
| 637 | 
         
            -
                    global_text_tokens: Optional[torch.Tensor] = None,
         
     | 
| 638 | 
         
            -
                    return_all_rgbs: bool = False,
         
     | 
| 639 | 
         
            -
                ):
         
     | 
| 640 | 
         
            -
                    x = lowres_image
         
     | 
| 641 | 
         
            -
             
     | 
| 642 | 
         
            -
                    noise_scale = 0.001  # Adjust the scale of the noise as needed
         
     | 
| 643 | 
         
            -
                    noise_aug = torch.randn_like(x) * noise_scale
         
     | 
| 644 | 
         
            -
                    x = x + noise_aug
         
     | 
| 645 | 
         
            -
                    x = x.clamp(0, 1)
         
     | 
| 646 | 
         
            -
             
     | 
| 647 | 
         
            -
                    shape = x.shape
         
     | 
| 648 | 
         
            -
                    batch_size = shape[0]
         
     | 
| 649 | 
         
            -
             
     | 
| 650 | 
         
            -
                    assert shape[-2:] == ((self.input_image_size,) * 2)
         
     | 
| 651 | 
         
            -
             
     | 
| 652 | 
         
            -
                    # styles
         
     | 
| 653 | 
         
            -
                    if not exists(styles):
         
     | 
| 654 | 
         
            -
                        assert exists(self.style_network)
         
     | 
| 655 | 
         
            -
             
     | 
| 656 | 
         
            -
                        noise = default(
         
     | 
| 657 | 
         
            -
                            noise,
         
     | 
| 658 | 
         
            -
                            torch.randn(
         
     | 
| 659 | 
         
            -
                                (batch_size, self.style_network.dim_in), device=self.device
         
     | 
| 660 | 
         
            -
                            ),
         
     | 
| 661 | 
         
            -
                        )
         
     | 
| 662 | 
         
            -
                        styles = self.style_network(noise, global_text_tokens)
         
     | 
| 663 | 
         
            -
             
     | 
| 664 | 
         
            -
                    # project styles to conv modulations
         
     | 
| 665 | 
         
            -
                    conv_mods = self.style_to_conv_modulations(styles)
         
     | 
| 666 | 
         
            -
                    conv_mods = conv_mods.split(self.style_embed_split_dims, dim=-1)
         
     | 
| 667 | 
         
            -
                    conv_mods = iter(conv_mods)
         
     | 
| 668 | 
         
            -
             
     | 
| 669 | 
         
            -
                    x = self.init_conv(x)
         
     | 
| 670 | 
         
            -
             
     | 
| 671 | 
         
            -
                    h = []
         
     | 
| 672 | 
         
            -
                    for blocks, (attn, downsample) in self.downs:
         
     | 
| 673 | 
         
            -
                        for block in blocks:
         
     | 
| 674 | 
         
            -
                            x = block(x, conv_mods_iter=conv_mods)
         
     | 
| 675 | 
         
            -
                            h.append(x)
         
     | 
| 676 | 
         
            -
             
     | 
| 677 | 
         
            -
                        if attn is not None:
         
     | 
| 678 | 
         
            -
                            x = attn(x)
         
     | 
| 679 | 
         
            -
             
     | 
| 680 | 
         
            -
                        x = downsample(x)
         
     | 
| 681 | 
         
            -
             
     | 
| 682 | 
         
            -
                    x = self.mid_block1(x, conv_mods_iter=conv_mods)
         
     | 
| 683 | 
         
            -
                    x = self.mid_attn(x)
         
     | 
| 684 | 
         
            -
                    x = self.mid_block2(x, conv_mods_iter=conv_mods)
         
     | 
| 685 | 
         
            -
             
     | 
| 686 | 
         
            -
                    for (
         
     | 
| 687 | 
         
            -
                        blocks,
         
     | 
| 688 | 
         
            -
                        (
         
     | 
| 689 | 
         
            -
                            upsample,
         
     | 
| 690 | 
         
            -
                            attn,
         
     | 
| 691 | 
         
            -
                        ),
         
     | 
| 692 | 
         
            -
                    ) in self.ups:
         
     | 
| 693 | 
         
            -
                        x = upsample(x)
         
     | 
| 694 | 
         
            -
                        for block in blocks:
         
     | 
| 695 | 
         
            -
                            if h != []:
         
     | 
| 696 | 
         
            -
                                res = h.pop()
         
     | 
| 697 | 
         
            -
                                res = res * self.skip_connect_scale
         
     | 
| 698 | 
         
            -
                                x = torch.cat((x, res), dim=1)
         
     | 
| 699 | 
         
            -
             
     | 
| 700 | 
         
            -
                            x = block(x, conv_mods_iter=conv_mods)
         
     | 
| 701 | 
         
            -
             
     | 
| 702 | 
         
            -
                        if attn is not None:
         
     | 
| 703 | 
         
            -
                            x = attn(x)
         
     | 
| 704 | 
         
            -
             
     | 
| 705 | 
         
            -
                    x = self.final_res_block(x, conv_mods_iter=conv_mods)
         
     | 
| 706 | 
         
            -
                    rgb = self.final_to_rgb(x)
         
     | 
| 707 | 
         
            -
             
     | 
| 708 | 
         
            -
                    if not return_all_rgbs:
         
     | 
| 709 | 
         
            -
                        return rgb
         
     | 
| 710 | 
         
            -
             
     | 
| 711 | 
         
            -
                    return rgb, []
         
     | 
| 712 | 
         
            -
             
     | 
| 713 | 
         
            -
             
     | 
| 714 | 
         
            -
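
# Note (added for illustration; not part of the original file): the default
# skip_connect_scale of 2**-0.5 above follows the residual-network convention
# that summing two unit-variance signals doubles the variance, so rescaling by
# 2**-0.5 restores it; here the scale is applied to the skip branch before
# concatenation. A minimal sketch of the underlying arithmetic:
def _demo_skip_scale():
    import torch

    a, b = torch.randn(100_000), torch.randn(100_000)
    print((a + b).var().item())              # ~2.0
    print(((a + b) * 2**-0.5).var().item())  # ~1.0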


def tile_image(image, chunk_size=64):
    c, h, w = image.shape
    h_chunks = ceil(h / chunk_size)
    w_chunks = ceil(w / chunk_size)
    tiles = []
    for i in range(h_chunks):
        for j in range(w_chunks):
            tile = image[
                :,
                i * chunk_size : (i + 1) * chunk_size,
                j * chunk_size : (j + 1) * chunk_size,
            ]
            tiles.append(tile)
    return tiles, h_chunks, w_chunks
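
# Note (added, not in the original file): when h or w is not a multiple of
# chunk_size, the edge tiles come out smaller than chunk_size. The upscaling
# methods below therefore reflect-pad the image first, so that torch.stack
# over a batch of tiles sees uniform shapes.
def _demo_ragged_tiles():
    import torch

    tiles, _, _ = tile_image(torch.rand(3, 70, 70), chunk_size=64)
    print([tuple(t.shape) for t in tiles])
    # [(3, 64, 64), (3, 64, 6), (3, 6, 64), (3, 6, 6)]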


# This helps create a checkerboard pattern with some edge blending
def create_checkerboard_weights(tile_size):
    x = torch.linspace(-1, 1, tile_size)
    y = torch.linspace(-1, 1, tile_size)

    x, y = torch.meshgrid(x, y, indexing="ij")
    d = torch.sqrt(x * x + y * y)
    sigma, mu = 0.5, 0.0
    weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))

    # saturate the values so that the center of the tile gets high weights
    weights = weights**8

    return weights / weights.max()  # Normalize to [0, 1]
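
# Hedged sketch (added, not in the original file): the weight tile is a radial
# Gaussian raised to the 8th power, so after normalization it is 1.0 at the
# tile center and falls off sharply toward the corners.
def _demo_checkerboard_weights(tile_size=8):
    w = create_checkerboard_weights(tile_size)
    print(w[tile_size // 2, tile_size // 2].item())  # 1.0 (center)
    print(w[0, 0].item())                            # ~1e-14 (corner)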


def repeat_weights(weights, image_size):
    tile_size = weights.shape[0]
    repeats = (
        math.ceil(image_size[0] / tile_size),
        math.ceil(image_size[1] / tile_size),
    )
    return weights.repeat(repeats)[: image_size[0], : image_size[1]]


def create_offset_weights(weights, image_size):
    tile_size = weights.shape[0]
    offset = tile_size // 2
    full_weights = repeat_weights(
        weights, (image_size[0] + offset, image_size[1] + offset)
    )
    return full_weights[offset:, offset:]


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
    # Determine the shape of the output tensor
    c = tiles[0].shape[0]
    h = h_chunks * chunk_size
    w = w_chunks * chunk_size

    # Create an empty tensor to hold the merged image
    merged = torch.zeros((c, h, w), dtype=tiles[0].dtype)

    # Iterate over the tiles and place them in the correct position
    for idx, tile in enumerate(tiles):
        i = idx // w_chunks
        j = idx % w_chunks

        h_start = i * chunk_size
        w_start = j * chunk_size

        tile_h, tile_w = tile.shape[1:]
        merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile

    return merged
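
# Hedged sketch (added, not in the original file): when height and width are
# exact multiples of chunk_size, tile_image followed by merge_tiles is a
# lossless round trip.
def _demo_tile_roundtrip():
    import torch

    img = torch.rand(3, 128, 96)
    tiles, h_chunks, w_chunks = tile_image(img, chunk_size=32)
    merged = merge_tiles(tiles, h_chunks, w_chunks, chunk_size=32)
    assert torch.equal(merged, img)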


class AuraSR:
    def __init__(self, config: dict[str, Any], device: str = "cuda"):
        self.upsampler = UnetUpsampler(**config).to(device)
        self.input_image_size = config["input_image_size"]

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "fal-ai/AuraSR",
        use_safetensors: bool = True,
        device: str = "cuda",
    ):
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
            if local_file.suffix == ".safetensors":
                use_safetensors = True
            elif local_file.suffix == ".ckpt":
                use_safetensors = False
            else:
                raise ValueError(
                    f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
                )

            # For local files, the config must be provided separately
            config_path = local_file.with_name("config.json")
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
            hf_model_path = Path(
                snapshot_download(model_id, ignore_patterns=["*.ckpt"])
            )
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config, device)

        if use_safetensors:
            try:
                from safetensors.torch import load_file

                checkpoint = load_file(
                    hf_model_path / "model.safetensors"
                    if not Path(model_id).is_file()
                    else model_id
                )
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
            checkpoint = torch.load(
                hf_model_path / "model.ckpt"
                if not Path(model_id).is_file()
                else model_id
            )

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
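
    # Usage sketch (added, not in the original file; the default model id
    # above is used, adjust paths as needed):
    #
    #     model = AuraSR.from_pretrained("fal-ai/AuraSR", device="cuda")
    #     # or, from a local checkpoint with a sibling config.json:
    #     model = AuraSR.from_pretrained("/path/to/model.safetensors")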

    @torch.no_grad()
    def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)

        # Batch processing of tiles
        num_tiles = len(tiles)
        batches = [
            tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
        ]
        reconstructed_tiles = []

        for batch in batches:
            model_input = torch.stack(batch).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
                noise=torch.randn(model_input.shape[0], 128, device=device),
            )
            reconstructed_tiles.extend(
                list(generator_output.clamp_(0, 1).detach().cpu())
            )

        merged_tensor = merge_tiles(
            reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
        )
        unpadded = merged_tensor[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
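
    # Usage sketch (added, not in the original file; "input.png" is a
    # placeholder path):
    #
    #     from PIL import Image
    #     image = Image.open("input.png").convert("RGB")
    #     upscaled = model.upscale_4x(image)  # PIL image, 4x larger
    #     upscaled.save("output.png")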

    # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
    # weight options are 'checkboard' and 'constant'
    @torch.no_grad()
    def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
        tensor_transform = transforms.ToTensor()
        device = self.upsampler.device

        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape

        # Calculate paddings
        pad_h = (
            self.input_image_size - h % self.input_image_size
        ) % self.input_image_size
        pad_w = (
            self.input_image_size - w % self.input_image_size
        ) % self.input_image_size

        # Pad the image
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
        ).squeeze(0)

        # Function to process tiles
        def process_tiles(tiles, h_chunks, w_chunks):
            num_tiles = len(tiles)
            batches = [
                tiles[i : i + max_batch_size]
                for i in range(0, num_tiles, max_batch_size)
            ]
            reconstructed_tiles = []

            for batch in batches:
                model_input = torch.stack(batch).to(device)
                generator_output = self.upsampler(
                    lowres_image=model_input,
                    noise=torch.randn(model_input.shape[0], 128, device=device),
                )
                reconstructed_tiles.extend(
                    list(generator_output.clamp_(0, 1).detach().cpu())
                )

            return merge_tiles(
                reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
            )

        # First pass
        tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
        result1 = process_tiles(tiles1, h_chunks1, w_chunks1)

        # Second pass with offset
        offset = self.input_image_size // 2
        image_tensor_offset = torch.nn.functional.pad(
            image_tensor, (offset, offset, offset, offset), mode="reflect"
        ).squeeze(0)

        tiles2, h_chunks2, w_chunks2 = tile_image(
            image_tensor_offset, self.input_image_size
        )
        result2 = process_tiles(tiles2, h_chunks2, w_chunks2)

        # Unpad the offset pass so it aligns with the first pass
        offset_4x = offset * 4
        result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]

        if weight_type == "checkboard":
            weight_tile = create_checkerboard_weights(self.input_image_size * 4)

            weight_shape = result2_interior.shape[1:]
            weights_1 = create_offset_weights(weight_tile, weight_shape)
            weights_2 = repeat_weights(weight_tile, weight_shape)

            normalizer = weights_1 + weights_2
            weights_1 = weights_1 / normalizer
            weights_2 = weights_2 / normalizer

            weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
            weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
        elif weight_type == "constant":
            weights_1 = torch.ones_like(result2_interior) * 0.5
            weights_2 = weights_1
        else:
            raise ValueError(
                f"weight_type should be either 'checkboard' or 'constant' but got {weight_type}"
            )

        result1 = result1 * weights_2
        result2 = result2_interior * weights_1

        # Blend the two passes (weighted average of the overlapping region)
        result1 = result1 + result2

        # Remove padding
        unpadded = result1[:, : h * 4, : w * 4]

        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
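

# Note (added, not in the original file): per pixel the two passes blend as
#     out = weights_2 * pass1 + weights_1 * pass2,  with  weights_1 + weights_2 == 1,
# so the tile seams of each pass fall where the other pass carries most of the
# weight. A quick check that the normalized checkerboard weights partition unity:
def _demo_weights_partition_unity(tile=64):
    import torch

    weight_tile = create_checkerboard_weights(tile * 4)
    shape = (tile * 8, tile * 8)
    w1 = create_offset_weights(weight_tile, shape)
    w2 = repeat_weights(weight_tile, shape)
    normalizer = w1 + w2
    assert torch.allclose(w1 / normalizer + w2 / normalizer, torch.ones(shape))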
         
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
    	
src/backend/upscale/aura_sr_upscale.py
DELETED
@@ -1,9 +0,0 @@
-from backend.upscale.aura_sr import AuraSR
-from PIL import Image
-
-
-def upscale_aura_sr(image_path: str):
-
-    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
-    image_in = Image.open(image_path)  # .resize((256, 256))
-    return aura_sr.upscale_4x(image_in)
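Note: the removed helper wrapped the AuraSR 4x upscaler defined in backend/upscale/aura_sr.py. A minimal usage sketch, assuming the module were still present; the file names "input.png" and "output.png" are illustrative, not from the diff:

    from backend.upscale.aura_sr_upscale import upscale_aura_sr

    upscaled = upscale_aura_sr("input.png")  # returns a PIL.Image at 4x resolution
    upscaled.save("output.png")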
src/backend/upscale/edsr_upscale_onnx.py
DELETED
@@ -1,37 +0,0 @@
-import numpy as np
-import onnxruntime
-from huggingface_hub import hf_hub_download
-from PIL import Image
-
-
-def upscale_edsr_2x(image_path: str):
-    input_image = Image.open(image_path).convert("RGB")
-    input_image = np.array(input_image).astype("float32")
-    input_image = np.transpose(input_image, (2, 0, 1))
-    img_arr = np.expand_dims(input_image, axis=0)
-
-    if np.max(img_arr) > 256:  # 16-bit image
-        max_range = 65535
-    else:
-        max_range = 255.0
-        img = img_arr / max_range
-
-    model_path = hf_hub_download(
-        repo_id="rupeshs/edsr-onnx",
-        filename="edsr_onnxsim_2x.onnx",
-    )
-    sess = onnxruntime.InferenceSession(model_path)
-
-    input_name = sess.get_inputs()[0].name
-    output_name = sess.get_outputs()[0].name
-    output = sess.run(
-        [output_name],
-        {input_name: img},
-    )[0]
-
-    result = output.squeeze()
-    result = result.clip(0, 1)
-    image_array = np.transpose(result, (1, 2, 0))
-    image_array = np.uint8(image_array * 255)
-    upscaled_image = Image.fromarray(image_array)
-    return upscaled_image
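Note: the removed upscale_edsr_2x carried a latent bug: img is assigned only in the 8-bit branch, so a 16-bit input (np.max(img_arr) > 256) would raise UnboundLocalError at the sess.run call. A hypothetical fix that normalizes on both paths:

    import numpy as np

    def normalize(img_arr: np.ndarray) -> np.ndarray:
        # Pick the range by bit depth, then normalize unconditionally; the
        # removed code left `img` unbound on the 16-bit path.
        max_range = 65535.0 if np.max(img_arr) > 256 else 255.0
        return img_arr / max_range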
src/backend/upscale/tiled_upscale.py
DELETED
@@ -1,237 +0,0 @@
-import time
-import math
-import logging
-from PIL import Image, ImageDraw, ImageFilter
-from backend.models.lcmdiffusion_setting import DiffusionTask
-from context import Context
-from constants import DEVICE
-
-
-def generate_upscaled_image(
-    config,
-    input_path=None,
-    strength=0.3,
-    scale_factor=2.0,
-    tile_overlap=16,
-    upscale_settings=None,
-    context: Context = None,
-    output_path=None,
-    image_format="PNG",
-):
-    if config == None or (
-        input_path == None or input_path == "" and upscale_settings == None
-    ):
-        logging.error("Wrong arguments in tiled upscale function call!")
-        return
-
-    # Use the upscale_settings dict if provided; otherwise, build the
-    # upscale_settings dict using the function arguments and default values
-    if upscale_settings == None:
-        upscale_settings = {
-            "source_file": input_path,
-            "target_file": None,
-            "output_format": image_format,
-            "strength": strength,
-            "scale_factor": scale_factor,
-            "prompt": config.lcm_diffusion_setting.prompt,
-            "tile_overlap": tile_overlap,
-            "tile_size": 256,
-            "tiles": [],
-        }
-        source_image = Image.open(input_path)  # PIL image
-    else:
-        source_image = Image.open(upscale_settings["source_file"])
-
-    upscale_settings["source_image"] = source_image
-
-    if upscale_settings["target_file"]:
-        result = Image.open(upscale_settings["target_file"])
-    else:
-        result = Image.new(
-            mode="RGBA",
-            size=(
-                source_image.size[0] * int(upscale_settings["scale_factor"]),
-                source_image.size[1] * int(upscale_settings["scale_factor"]),
-            ),
-            color=(0, 0, 0, 0),
-        )
-    upscale_settings["target_image"] = result
-
-    # If the custom tile definition array 'tiles' is empty, proceed with the
-    # default tiled upscale task by defining all the possible image tiles; note
-    # that the actual tile size is 'tile_size' + 'tile_overlap' and the target
-    # image width and height are no longer constrained to multiples of 256 but
-    # are instead multiples of the actual tile size
-    if len(upscale_settings["tiles"]) == 0:
-        tile_size = upscale_settings["tile_size"]
-        scale_factor = upscale_settings["scale_factor"]
-        tile_overlap = upscale_settings["tile_overlap"]
-        total_cols = math.ceil(
-            source_image.size[0] / tile_size
-        )  # Image width / tile size
-        total_rows = math.ceil(
-            source_image.size[1] / tile_size
-        )  # Image height / tile size
-        for y in range(0, total_rows):
-            y_offset = tile_overlap if y > 0 else 0  # Tile mask offset
-            for x in range(0, total_cols):
-                x_offset = tile_overlap if x > 0 else 0  # Tile mask offset
-                x1 = x * tile_size
-                y1 = y * tile_size
-                w = tile_size + (tile_overlap if x < total_cols - 1 else 0)
-                h = tile_size + (tile_overlap if y < total_rows - 1 else 0)
-                mask_box = (  # Default tile mask box definition
-                    x_offset,
-                    y_offset,
-                    int(w * scale_factor),
-                    int(h * scale_factor),
-                )
-                upscale_settings["tiles"].append(
-                    {
-                        "x": x1,
-                        "y": y1,
-                        "w": w,
-                        "h": h,
-                        "mask_box": mask_box,
-                        "prompt": upscale_settings["prompt"],  # Use top level prompt if available
-                        "scale_factor": scale_factor,
-                    }
-                )
-
-    # Generate the output image tiles
-    for i in range(0, len(upscale_settings["tiles"])):
-        generate_upscaled_tile(
-            config,
-            i,
-            upscale_settings,
-            context=context,
-        )
-
-    # Save completed upscaled image
-    if upscale_settings["output_format"].upper() == "JPEG":
-        result_rgb = result.convert("RGB")
-        result.close()
-        result = result_rgb
-    result.save(output_path)
-    result.close()
-    source_image.close()
-    return
-
-
-def get_current_tile(
-    config,
-    context,
-    strength,
-):
-    config.lcm_diffusion_setting.strength = strength
-    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.image_to_image.value
-    if (
-        config.lcm_diffusion_setting.use_tiny_auto_encoder
-        and config.lcm_diffusion_setting.use_openvino
-    ):
-        config.lcm_diffusion_setting.use_tiny_auto_encoder = False
-    current_tile = context.generate_text_to_image(
-        settings=config,
-        reshape=True,
-        device=DEVICE,
-        save_config=False,
-    )[0]
-    return current_tile
-
-
-# Generates a single tile from the source image as defined in the
-# upscale_settings["tiles"] array with the corresponding index and pastes the
-# generated tile into the target image using the corresponding mask and scale
-# factor; note that scale factor for the target image and the individual tiles
-# can be different, this function will adjust scale factors as needed
-def generate_upscaled_tile(
-    config,
-    index,
-    upscale_settings,
-    context: Context = None,
-):
-    if config == None or upscale_settings == None:
-        logging.error("Wrong arguments in tile creation function call!")
-        return
-
-    x = upscale_settings["tiles"][index]["x"]
-    y = upscale_settings["tiles"][index]["y"]
-    w = upscale_settings["tiles"][index]["w"]
-    h = upscale_settings["tiles"][index]["h"]
-    tile_prompt = upscale_settings["tiles"][index]["prompt"]
-    scale_factor = upscale_settings["scale_factor"]
-    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
-    target_width = int(w * tile_scale_factor)
-    target_height = int(h * tile_scale_factor)
-    strength = upscale_settings["strength"]
-    source_image = upscale_settings["source_image"]
-    target_image = upscale_settings["target_image"]
-    mask_image = generate_tile_mask(config, index, upscale_settings)
-
-    config.lcm_diffusion_setting.number_of_images = 1
-    config.lcm_diffusion_setting.prompt = tile_prompt
-    config.lcm_diffusion_setting.image_width = target_width
-    config.lcm_diffusion_setting.image_height = target_height
-    config.lcm_diffusion_setting.init_image = source_image.crop((x, y, x + w, y + h))
-
-    current_tile = None
-    print(f"[SD Upscale] Generating tile {index + 1}/{len(upscale_settings['tiles'])} ")
-    if tile_prompt == None or tile_prompt == "":
-        config.lcm_diffusion_setting.prompt = ""
-        config.lcm_diffusion_setting.negative_prompt = ""
-        current_tile = get_current_tile(config, context, strength)
-    else:
-        # Attempt to use img2img with low denoising strength to
-        # generate the tiles with the extra aid of a prompt
-        # context = get_context(InterfaceType.CLI)
-        current_tile = get_current_tile(config, context, strength)
-
-    if math.isclose(scale_factor, tile_scale_factor):
-        target_image.paste(
-            current_tile, (int(x * scale_factor), int(y * scale_factor)), mask_image
-        )
-    else:
-        target_image.paste(
-            current_tile.resize((int(w * scale_factor), int(h * scale_factor))),
-            (int(x * scale_factor), int(y * scale_factor)),
-            mask_image.resize((int(w * scale_factor), int(h * scale_factor))),
-        )
-    mask_image.close()
-    current_tile.close()
-    config.lcm_diffusion_setting.init_image.close()
-
-
-# Generate tile mask using the box definition in the upscale_settings["tiles"]
-# array with the corresponding index; note that tile masks for the default
-# tiled upscale task can be reused but that would complicate the code, so
-# new tile masks are instead created for each tile
-def generate_tile_mask(
-    config,
-    index,
-    upscale_settings,
-):
-    scale_factor = upscale_settings["scale_factor"]
-    tile_overlap = upscale_settings["tile_overlap"]
-    tile_scale_factor = upscale_settings["tiles"][index]["scale_factor"]
-    w = int(upscale_settings["tiles"][index]["w"] * tile_scale_factor)
-    h = int(upscale_settings["tiles"][index]["h"] * tile_scale_factor)
-    # The Stable Diffusion pipeline automatically adjusts the output size
-    # to multiples of 8 pixels; the mask must be created with the same
-    # size as the output tile
-    w = w - (w % 8)
-    h = h - (h % 8)
-    mask_box = upscale_settings["tiles"][index]["mask_box"]
-    if mask_box == None:
-        # Build a default solid mask with soft/transparent edges
-        mask_box = (
-            tile_overlap,
-            tile_overlap,
-            w - tile_overlap,
-            h - tile_overlap,
-        )
-    mask_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0, 0))
-    mask_draw = ImageDraw.Draw(mask_image)
-    mask_draw.rectangle(tuple(mask_box), fill=(0, 0, 0))
-    mask_blur = mask_image.filter(ImageFilter.BoxBlur(tile_overlap - 1))
-    mask_image.close()
-    return mask_blur
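Note: two observations on the removed tiled upscaler. First, its opening argument check binds as input_path == None or (input_path == "" and upscale_settings == None), because "and" has higher precedence than "or" in Python; the guard therefore fires whenever input_path is None, even when upscale_settings is supplied. Second, a self-contained sketch of the tile-grid arithmetic it performed, with illustrative numbers not taken from the diff:

    import math

    # A 512x768 source with the module's defaults tile_size=256, tile_overlap=16.
    width, height, tile_size, tile_overlap = 512, 768, 256, 16
    total_cols = math.ceil(width / tile_size)   # 2 columns
    total_rows = math.ceil(height / tile_size)  # 3 rows -> 6 tiles in total
    # All but the last column/row are widened by the overlap so that
    # neighbouring tiles can be alpha-blended through the BoxBlur mask:
    interior_w = tile_size + tile_overlap       # 272 px
    print(total_cols * total_rows, interior_w)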
src/backend/upscale/upscaler.py
DELETED
@@ -1,52 +0,0 @@
-from backend.models.lcmdiffusion_setting import DiffusionTask
-from backend.models.upscale import UpscaleMode
-from backend.upscale.edsr_upscale_onnx import upscale_edsr_2x
-from backend.upscale.aura_sr_upscale import upscale_aura_sr
-from backend.upscale.tiled_upscale import generate_upscaled_image
-from context import Context
-from PIL import Image
-from state import get_settings
-
-
-config = get_settings()
-
-
-def upscale_image(
-    context: Context,
-    src_image_path: str,
-    dst_image_path: str,
-    scale_factor: int = 2,
-    upscale_mode: UpscaleMode = UpscaleMode.normal.value,
-    strength: float = 0.1,
-):
-    if upscale_mode == UpscaleMode.normal.value:
-        upscaled_img = upscale_edsr_2x(src_image_path)
-        upscaled_img.save(dst_image_path)
-        print(f"Upscaled image saved {dst_image_path}")
-    elif upscale_mode == UpscaleMode.aura_sr.value:
-        upscaled_img = upscale_aura_sr(src_image_path)
-        upscaled_img.save(dst_image_path)
-        print(f"Upscaled image saved {dst_image_path}")
-    else:
-        config.settings.lcm_diffusion_setting.strength = (
-            0.3 if config.settings.lcm_diffusion_setting.use_openvino else strength
-        )
-        config.settings.lcm_diffusion_setting.diffusion_task = (
-            DiffusionTask.image_to_image.value
-        )
-
-        generate_upscaled_image(
-            config.settings,
-            src_image_path,
-            config.settings.lcm_diffusion_setting.strength,
-            upscale_settings=None,
-            context=context,
-            tile_overlap=(
-                32 if config.settings.lcm_diffusion_setting.use_openvino else 16
-            ),
-            output_path=dst_image_path,
-            image_format=config.settings.generated_images.format,
-        )
-        print(f"Upscaled image saved {dst_image_path}")
-
-    return [Image.open(dst_image_path)]
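Note: upscaler.py was the dispatch layer over the three removed backends (EDSR ONNX, AuraSR, tiled SD). A hypothetical call, assuming the modules were still present; note the EDSR branch ignores scale_factor and always upscales 2x, and only the tiled SD branch uses the context:

    from backend.models.upscale import UpscaleMode
    from backend.upscale.upscaler import upscale_image

    images = upscale_image(
        context=None,                 # hypothetical: unused by the EDSR branch
        src_image_path="in.png",      # illustrative paths, not from the diff
        dst_image_path="out.png",
        upscale_mode=UpscaleMode.normal.value,
    )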
src/constants.py
DELETED
@@ -1,25 +0,0 @@
-from os import environ, cpu_count
-
-cpu_cores = cpu_count()
-cpus = cpu_cores // 2 if cpu_cores else 0
-APP_VERSION = "v1.0.0 beta 200"
-LCM_DEFAULT_MODEL = "stabilityai/sd-turbo"
-LCM_DEFAULT_MODEL_OPENVINO = "rupeshs/sd-turbo-openvino"
-APP_NAME = "FastSD CPU"
-APP_SETTINGS_FILE = "settings.yaml"
-RESULTS_DIRECTORY = "results"
-CONFIG_DIRECTORY = "configs"
-DEVICE = environ.get("DEVICE", "cpu")
-SD_MODELS_FILE = "stable-diffusion-models.txt"
-LCM_LORA_MODELS_FILE = "lcm-lora-models.txt"
-OPENVINO_LCM_MODELS_FILE = "openvino-lcm-models.txt"
-TAESD_MODEL = "madebyollin/taesd"
-TAESDXL_MODEL = "madebyollin/taesdxl"
-TAESD_MODEL_OPENVINO = "deinferno/taesd-openvino"
-LCM_MODELS_FILE = "lcm-models.txt"
-TAESDXL_MODEL_OPENVINO = "rupeshs/taesdxl-openvino"
-LORA_DIRECTORY = "lora_models"
-CONTROLNET_DIRECTORY = "controlnet_models"
-MODELS_DIRECTORY = "models"
-GGUF_THREADS = environ.get("GGUF_THREADS", cpus)
-TAEF1_MODEL_OPENVINO = "rupeshs/taef1-openvino"
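Note: environ.get returns a string whenever the variable is set, so the removed GGUF_THREADS constant was an int when defaulted but a str when overridden via the environment. A hypothetical normalization that keeps the type stable:

    from os import environ, cpu_count

    cpu_cores = cpu_count()
    cpus = cpu_cores // 2 if cpu_cores else 0
    # Cast so GGUF_THREADS is an int whether defaulted or set via the env.
    GGUF_THREADS = int(environ.get("GGUF_THREADS", cpus))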
        src/context.py
    DELETED
    
    | 
         @@ -1,85 +0,0 @@ 
     | 
|
| 1 | 
         
            -
            from typing import Any
         
     | 
| 2 | 
         
            -
            from app_settings import Settings
         
     | 
| 3 | 
         
            -
            from models.interface_types import InterfaceType
         
     | 
| 4 | 
         
            -
            from backend.models.lcmdiffusion_setting import DiffusionTask
         
     | 
| 5 | 
         
            -
            from backend.lcm_text_to_image import LCMTextToImage
         
     | 
| 6 | 
         
            -
            from time import perf_counter
         
     | 
| 7 | 
         
            -
            from backend.image_saver import ImageSaver
         
     | 
| 8 | 
         
            -
            from pprint import pprint
         
     | 
| 9 | 
         
            -
             
     | 
| 10 | 
         
            -
             
     | 
| 11 | 
         
            -
            class Context:
         
     | 
| 12 | 
         
            -
                def __init__(
         
     | 
| 13 | 
         
            -
                    self,
         
     | 
| 14 | 
         
            -
                    interface_type: InterfaceType,
         
     | 
| 15 | 
         
            -
                    device="cpu",
         
     | 
| 16 | 
         
            -
                ):
         
     | 
| 17 | 
         
            -
                    self.interface_type = interface_type.value
         
     | 
| 18 | 
         
            -
                    self.lcm_text_to_image = LCMTextToImage(device)
         
     | 
| 19 | 
         
            -
                    self._latency = 0
         
     | 
| 20 | 
         
            -
             
     | 
| 21 | 
         
            -
                @property
         
     | 
| 22 | 
         
            -
                def latency(self):
         
     | 
| 23 | 
         
            -
                    return self._latency
         
     | 
| 24 | 
         
            -
             
     | 
| 25 | 
         
            -
                def generate_text_to_image(
         
     | 
| 26 | 
         
            -
                    self,
         
     | 
| 27 | 
         
            -
                    settings: Settings,
         
     | 
| 28 | 
         
            -
                    reshape: bool = False,
         
     | 
| 29 | 
         
            -
                    device: str = "cpu",
         
     | 
| 30 | 
         
            -
                    save_config=True,
         
     | 
| 31 | 
         
            -
                ) -> Any:
         
     | 
| 32 | 
         
            -
                    if (
         
     | 
| 33 | 
         
            -
-            settings.lcm_diffusion_setting.use_tiny_auto_encoder
-            and settings.lcm_diffusion_setting.use_openvino
-        ):
-            print(
-                "WARNING: Tiny AutoEncoder is not supported in Image to image mode (OpenVINO)"
-            )
-        tick = perf_counter()
-        from state import get_settings
-
-        if (
-            settings.lcm_diffusion_setting.diffusion_task
-            == DiffusionTask.text_to_image.value
-        ):
-            settings.lcm_diffusion_setting.init_image = None
-
-        if save_config:
-            get_settings().save()
-
-        pprint(settings.lcm_diffusion_setting.model_dump())
-        if not settings.lcm_diffusion_setting.lcm_lora:
-            return None
-        self.lcm_text_to_image.init(
-            device,
-            settings.lcm_diffusion_setting,
-        )
-        images = self.lcm_text_to_image.generate(
-            settings.lcm_diffusion_setting,
-            reshape,
-        )
-        elapsed = perf_counter() - tick
-        self._latency = elapsed
-        print(f"Latency : {elapsed:.2f} seconds")
-        if settings.lcm_diffusion_setting.controlnet:
-            if settings.lcm_diffusion_setting.controlnet.enabled:
-                images.append(settings.lcm_diffusion_setting.controlnet._control_image)
-        return images
-
-
-    def save_images(
-        self,
-        images: Any,
-        settings: Settings,
-    ) -> list[str]:
-        saved_images = []
-        if images and settings.generated_images.save_image:
-            saved_images = ImageSaver.save_images(
-                settings.generated_images.path,
-                images=images,
-                lcm_diffusion_setting=settings.lcm_diffusion_setting,
-                format=settings.generated_images.format,
-                jpeg_quality=settings.generated_images.save_image_quality,
-            )
-        return saved_images
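Worth noting in the deleted generation tail above: when ControlNet is enabled, the control image is appended to the returned list, so a caller that saves everything also persists the control image. A minimal caller-side sketch (the `context`, `config`, and `DEVICE` names follow the CLI code below; the unpacking itself is illustrative, not part of the original source):

```python
# Caller-side sketch (illustrative): with ControlNet enabled, the last
# element of the returned list is the control image, not a generated image.
images = context.generate_text_to_image(settings=config, device=DEVICE)
controlnet = config.lcm_diffusion_setting.controlnet
if controlnet and controlnet.enabled:
    *generated, control_image = images  # control image rides along at the end
else:
    generated = images
```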
src/frontend/cli_interactive.py
DELETED
@@ -1,661 +0,0 @@
-from os import path
-from PIL import Image
-from typing import Any
-
-from constants import DEVICE
-from paths import FastStableDiffusionPaths
-from backend.upscale.upscaler import upscale_image
-from backend.upscale.tiled_upscale import generate_upscaled_image
-from frontend.webui.image_variations_ui import generate_image_variations
-from backend.lora import (
-    get_active_lora_weights,
-    update_lora_weights,
-    load_lora_weight,
-)
-from backend.models.lcmdiffusion_setting import (
-    DiffusionTask,
-    ControlNetSetting,
-)
-
-
-_batch_count = 1
-_edit_lora_settings = False
-
-
-def user_value(
-    value_type: type,
-    message: str,
-    default_value: Any,
-) -> Any:
-    try:
-        value = value_type(input(message))
-    except:
-        value = default_value
-    return value
-
-
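`user_value` is the CLI's lenient input helper: any failure of the cast falls back to the default, including an empty line, since `int("")` raises `ValueError`. (The bare `except:` also swallows `KeyboardInterrupt` raised during the cast, a side effect of the original code.) A quick illustration, with hypothetical user input:

```python
# Hypothetical session illustrating user_value's fallback behaviour:
steps = user_value(int, "Inference steps (4): ", 4)
# user types "8"          -> int("8") succeeds          -> steps == 8
# user types "abc" or ""  -> int(...) raises ValueError -> steps == 4
```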
-def interactive_mode(
-    config,
-    context,
-):
-    print("=============================================")
-    print("Welcome to FastSD CPU Interactive CLI")
-    print("=============================================")
-    while True:
-        print("> 1. Text to Image")
-        print("> 2. Image to Image")
-        print("> 3. Image Variations")
-        print("> 4. EDSR Upscale")
-        print("> 5. SD Upscale")
-        print("> 6. Edit default generation settings")
-        print("> 7. Edit LoRA settings")
-        print("> 8. Edit ControlNet settings")
-        print("> 9. Edit negative prompt")
-        print("> 10. Quit")
-        option = user_value(
-            int,
-            "Enter a Diffusion Task number (1): ",
-            1,
-        )
-        if option not in range(1, 11):
-            print("Wrong Diffusion Task number!")
-            exit()
-
-        if option == 1:
-            interactive_txt2img(
-                config,
-                context,
-            )
-        elif option == 2:
-            interactive_img2img(
-                config,
-                context,
-            )
-        elif option == 3:
-            interactive_variations(
-                config,
-                context,
-            )
-        elif option == 4:
-            interactive_edsr(
-                config,
-                context,
-            )
-        elif option == 5:
-            interactive_sdupscale(
-                config,
-                context,
-            )
-        elif option == 6:
-            interactive_settings(
-                config,
-                context,
-            )
-        elif option == 7:
-            interactive_lora(
-                config,
-                context,
-                True,
-            )
-        elif option == 8:
-            interactive_controlnet(
-                config,
-                context,
-                True,
-            )
-        elif option == 9:
-            interactive_negative(
-                config,
-                context,
-            )
-        elif option == 10:
-            exit()
-
-
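The menu is a plain if/elif chain over ten options, and an out-of-range entry terminates the whole program via `exit()` rather than re-prompting. For comparison only, the chain is equivalent to a dispatch table; this sketch is not part of the original source, and `functools.partial` pre-binds the `menu_flag=True` argument that options 7 and 8 pass positionally:

```python
# Dispatch-table sketch (illustrative alternative to the if/elif chain above).
from functools import partial

MENU_HANDLERS = {
    1: interactive_txt2img,
    2: interactive_img2img,
    3: interactive_variations,
    4: interactive_edsr,
    5: interactive_sdupscale,
    6: interactive_settings,
    7: partial(interactive_lora, menu_flag=True),
    8: partial(interactive_controlnet, menu_flag=True),
    9: interactive_negative,
}

def dispatch(option: int, config, context) -> None:
    if option == 10:
        raise SystemExit
    MENU_HANDLERS[option](config, context)
```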
-def interactive_negative(
-    config,
-    context,
-):
-    settings = config.lcm_diffusion_setting
-    print(f"Current negative prompt: '{settings.negative_prompt}'")
-    user_input = input("Write a negative prompt (set guidance > 1.0): ")
-    if user_input == "":
-        return
-    else:
-        settings.negative_prompt = user_input
-
-
-def interactive_controlnet(
-    config,
-    context,
-    menu_flag=False,
-):
-    """
-    @param menu_flag: Indicates whether this function was called from the main
-        interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
-    """
-    settings = config.lcm_diffusion_setting
-    if not settings.controlnet:
-        settings.controlnet = ControlNetSetting()
-
-    current_enabled = settings.controlnet.enabled
-    current_adapter_path = settings.controlnet.adapter_path
-    current_conditioning_scale = settings.controlnet.conditioning_scale
-    current_control_image = settings.controlnet._control_image
-
-    option = input("Enable ControlNet? (y/N): ")
-    settings.controlnet.enabled = True if option.upper() == "Y" else False
-    if settings.controlnet.enabled:
-        option = input(
-            f"Enter ControlNet adapter path ({settings.controlnet.adapter_path}): "
-        )
-        if option != "":
-            settings.controlnet.adapter_path = option
-        settings.controlnet.conditioning_scale = user_value(
-            float,
-            f"Enter ControlNet conditioning scale ({settings.controlnet.conditioning_scale}): ",
-            settings.controlnet.conditioning_scale,
-        )
-        option = input(
-            f"Enter ControlNet control image path (Leave empty to reuse current): "
-        )
-        if option != "":
-            try:
-                new_image = Image.open(option)
-                settings.controlnet._control_image = new_image
-            except (AttributeError, FileNotFoundError) as e:
-                settings.controlnet._control_image = None
-        if (
-            not settings.controlnet.adapter_path
-            or not path.exists(settings.controlnet.adapter_path)
-            or not settings.controlnet._control_image
-        ):
-            print("Invalid ControlNet settings! Disabling ControlNet")
-            settings.controlnet.enabled = False
-
-    if (
-        settings.controlnet.enabled != current_enabled
-        or settings.controlnet.adapter_path != current_adapter_path
-    ):
-        settings.rebuild_pipeline = True
-
-
-def interactive_lora(
-    config,
-    context,
-    menu_flag=False,
-):
-    """
-    @param menu_flag: Indicates whether this function was called from the main
-        interactive CLI menu; _True_ if called from the main menu, _False_ otherwise
-    """
-    if context == None or context.lcm_text_to_image.pipeline == None:
-        print("Diffusion pipeline not initialized, please run a generation task first!")
-        return
-
-    print("> 1. Change LoRA weights")
-    print("> 2. Load new LoRA model")
-    option = user_value(
-        int,
-        "Enter a LoRA option (1): ",
-        1,
-    )
-    if option not in range(1, 3):
-        print("Wrong LoRA option!")
-        return
-
-    if option == 1:
-        update_weights = []
-        active_weights = get_active_lora_weights()
-        for lora in active_weights:
-            weight = user_value(
-                float,
-                f"Enter a new LoRA weight for {lora[0]} ({lora[1]}): ",
-                lora[1],
-            )
-            update_weights.append(
-                (
-                    lora[0],
-                    weight,
-                )
-            )
-        if len(update_weights) > 0:
-            update_lora_weights(
-                context.lcm_text_to_image.pipeline,
-                config.lcm_diffusion_setting,
-                update_weights,
-            )
-    elif option == 2:
-        # Load a new LoRA
-        settings = config.lcm_diffusion_setting
-        settings.lora.fuse = False
-        settings.lora.enabled = False
-        settings.lora.path = input("Enter LoRA model path: ")
-        settings.lora.weight = user_value(
-            float,
-            "Enter a LoRA weight (0.5): ",
-            0.5,
-        )
-        if not path.exists(settings.lora.path):
-            print("Invalid LoRA model path!")
-            return
-        settings.lora.enabled = True
-        load_lora_weight(context.lcm_text_to_image.pipeline, settings)
-
-    if menu_flag:
-        global _edit_lora_settings
-        _edit_lora_settings = False
-        option = input("Edit LoRA settings after every generation? (y/N): ")
-        if option.upper() == "Y":
-            _edit_lora_settings = True
-
-
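Judging by the `lora[0]` / `lora[1]` indexing above, `get_active_lora_weights()` returns `(adapter_name, weight)` pairs, and the reweighting loop rebuilds that list from user input before handing it to `update_lora_weights`. An illustrative shape (names and values hypothetical, not from the source):

```python
# Hypothetical return value of get_active_lora_weights(), inferred from the
# lora[0] / lora[1] indexing in interactive_lora above:
active_weights = [("pixel-art-lora", 0.5), ("detail-tweaker", 0.8)]
# Rebuilding the list with new weights, as the option-1 branch does:
update_weights = [(name, 1.0) for name, _ in active_weights]
```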
-def interactive_settings(
-    config,
-    context,
-):
-    global _batch_count
-    settings = config.lcm_diffusion_setting
-    print("Enter generation settings (leave empty to use current value)")
-    print("> 1. Use LCM")
-    print("> 2. Use LCM-Lora")
-    print("> 3. Use OpenVINO")
-    option = user_value(
-        int,
-        "Select inference model option (1): ",
-        1,
-    )
-    if option not in range(1, 4):
-        print("Wrong inference model option! Falling back to defaults")
-        return
-
-    settings.use_lcm_lora = False
-    settings.use_openvino = False
-    if option == 1:
-        lcm_model_id = input(f"Enter LCM model ID ({settings.lcm_model_id}): ")
-        if lcm_model_id != "":
-            settings.lcm_model_id = lcm_model_id
-    elif option == 2:
-        settings.use_lcm_lora = True
-        lcm_lora_id = input(
-            f"Enter LCM-Lora model ID ({settings.lcm_lora.lcm_lora_id}): "
-        )
-        if lcm_lora_id != "":
-            settings.lcm_lora.lcm_lora_id = lcm_lora_id
-        base_model_id = input(
-            f"Enter Base model ID ({settings.lcm_lora.base_model_id}): "
-        )
-        if base_model_id != "":
-            settings.lcm_lora.base_model_id = base_model_id
-    elif option == 3:
-        settings.use_openvino = True
-        openvino_lcm_model_id = input(
-            f"Enter OpenVINO model ID ({settings.openvino_lcm_model_id}): "
-        )
-        if openvino_lcm_model_id != "":
-            settings.openvino_lcm_model_id = openvino_lcm_model_id
-
-    settings.use_offline_model = True
-    settings.use_tiny_auto_encoder = True
-    option = input("Work offline? (Y/n): ")
-    if option.upper() == "N":
-        settings.use_offline_model = False
-    option = input("Use Tiny Auto Encoder? (Y/n): ")
-    if option.upper() == "N":
-        settings.use_tiny_auto_encoder = False
-
-    settings.image_width = user_value(
-        int,
-        f"Image width ({settings.image_width}): ",
-        settings.image_width,
-    )
-    settings.image_height = user_value(
-        int,
-        f"Image height ({settings.image_height}): ",
-        settings.image_height,
-    )
-    settings.inference_steps = user_value(
-        int,
-        f"Inference steps ({settings.inference_steps}): ",
-        settings.inference_steps,
-    )
-    settings.guidance_scale = user_value(
-        float,
-        f"Guidance scale ({settings.guidance_scale}): ",
-        settings.guidance_scale,
-    )
-    settings.number_of_images = user_value(
-        int,
-        f"Number of images per batch ({settings.number_of_images}): ",
-        settings.number_of_images,
-    )
-    _batch_count = user_value(
-        int,
-        f"Batch count ({_batch_count}): ",
-        _batch_count,
-    )
-    # output_format = user_value(int, f"Output format (PNG)", 1)
-    print(config.lcm_diffusion_setting)
-
-
-def interactive_txt2img(
-    config,
-    context,
-):
-    global _batch_count
-    config.lcm_diffusion_setting.diffusion_task = DiffusionTask.text_to_image.value
-    user_input = input("Write a prompt (write 'exit' to quit): ")
-    while True:
-        if user_input == "exit":
-            return
-        elif user_input == "":
-            user_input = config.lcm_diffusion_setting.prompt
-        config.lcm_diffusion_setting.prompt = user_input
-        for _ in range(0, _batch_count):
-            images = context.generate_text_to_image(
-                settings=config,
-                device=DEVICE,
-            )
-            context.save_images(
-                images,
-                config,
-            )
-        if _edit_lora_settings:
-            interactive_lora(
-                config,
-                context,
-            )
-        user_input = input("Write a prompt: ")
-
-
-def interactive_img2img(
-    config,
-    context,
-):
-    global _batch_count
-    settings = config.lcm_diffusion_setting
-    settings.diffusion_task = DiffusionTask.image_to_image.value
-    steps = settings.inference_steps
-    source_path = input("Image path: ")
-    if source_path == "":
-        print("Error : You need to provide a file in img2img mode")
-        return
-    settings.strength = user_value(
-        float,
-        f"img2img strength ({settings.strength}): ",
-        settings.strength,
-    )
-    settings.inference_steps = int(steps / settings.strength + 1)
-    user_input = input("Write a prompt (write 'exit' to quit): ")
-    while True:
-        if user_input == "exit":
-            settings.inference_steps = steps
-            return
-        settings.init_image = Image.open(source_path)
-        settings.prompt = user_input
-        for _ in range(0, _batch_count):
-            images = context.generate_text_to_image(
-                settings=config,
-                device=DEVICE,
-            )
-            context.save_images(
-                images,
-                config,
-            )
-        new_path = input(f"Image path ({source_path}): ")
-        if new_path != "":
-            source_path = new_path
-        settings.strength = user_value(
-            float,
-            f"img2img strength ({settings.strength}): ",
-            settings.strength,
-        )
-        if _edit_lora_settings:
-            interactive_lora(
-                config,
-                context,
-            )
-        settings.inference_steps = int(steps / settings.strength + 1)
-        user_input = input("Write a prompt: ")
-
-
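The recurring `int(steps / settings.strength + 1)` adjustment compensates for how diffusers-style img2img pipelines consume steps: only roughly `strength * num_inference_steps` denoising steps actually touch the image, so the CLI inflates the configured count to keep the effective step count stable, then restores `steps` on exit. A worked example with illustrative values:

```python
# Step-compensation arithmetic used by interactive_img2img (values illustrative):
steps = 4        # configured inference steps
strength = 0.3   # img2img strength chosen by the user
inflated = int(steps / strength + 1)  # int(14.33...) -> 14
effective = int(inflated * strength)  # int(4.2)      -> ~4 steps applied
```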
-def interactive_variations(
-    config,
-    context,
-):
-    global _batch_count
-    settings = config.lcm_diffusion_setting
-    settings.diffusion_task = DiffusionTask.image_to_image.value
-    steps = settings.inference_steps
-    source_path = input("Image path: ")
-    if source_path == "":
-        print("Error : You need to provide a file in Image variations mode")
-        return
-    settings.strength = user_value(
-        float,
-        f"Image variations strength ({settings.strength}): ",
-        settings.strength,
-    )
-    settings.inference_steps = int(steps / settings.strength + 1)
-    while True:
-        settings.init_image = Image.open(source_path)
-        settings.prompt = ""
-        for i in range(0, _batch_count):
-            generate_image_variations(
-                settings.init_image,
-                settings.strength,
-            )
-        if _edit_lora_settings:
-            interactive_lora(
-                config,
-                context,
-            )
-        user_input = input("Continue in Image variations mode? (Y/n): ")
-        if user_input.upper() == "N":
-            settings.inference_steps = steps
-            return
-        new_path = input(f"Image path ({source_path}): ")
-        if new_path != "":
-            source_path = new_path
-        settings.strength = user_value(
-            float,
-            f"Image variations strength ({settings.strength}): ",
-            settings.strength,
-        )
-        settings.inference_steps = int(steps / settings.strength + 1)
-
-
-def interactive_edsr(
-    config,
-    context,
-):
-    source_path = input("Image path: ")
-    if source_path == "":
-        print("Error : You need to provide a file in EDSR mode")
-        return
-    while True:
-        output_path = FastStableDiffusionPaths.get_upscale_filepath(
-            source_path,
-            2,
-            config.generated_images.format,
-        )
-        result = upscale_image(
-            context,
-            source_path,
-            output_path,
-            2,
-        )
-        user_input = input("Continue in EDSR upscale mode? (Y/n): ")
-        if user_input.upper() == "N":
-            return
-        new_path = input(f"Image path ({source_path}): ")
-        if new_path != "":
-            source_path = new_path
-
-
| 496 | 
         
            -
            def interactive_sdupscale_settings(config):
         
     | 
| 497 | 
         
            -
                steps = config.lcm_diffusion_setting.inference_steps
         
     | 
| 498 | 
         
            -
                custom_settings = {}
         
     | 
| 499 | 
         
            -
                print("> 1. Upscale whole image")
         
     | 
| 500 | 
         
            -
                print("> 2. Define custom tiles (advanced)")
         
     | 
| 501 | 
         
            -
                option = user_value(
         
     | 
| 502 | 
         
            -
                    int,
         
     | 
| 503 | 
         
            -
                    "Select an SD Upscale option (1): ",
         
     | 
| 504 | 
         
            -
                    1,
         
     | 
| 505 | 
         
            -
                )
         
     | 
| 506 | 
         
            -
                if option not in range(1, 3):
         
     | 
| 507 | 
         
            -
                    print("Wrong SD Upscale option!")
         
     | 
| 508 | 
         
            -
                    return
         
     | 
| 509 | 
         
            -
             
     | 
| 510 | 
         
            -
                # custom_settings["source_file"] = args.file
         
     | 
| 511 | 
         
            -
                custom_settings["source_file"] = ""
         
     | 
| 512 | 
         
            -
                new_path = input(f"Input image path ({custom_settings['source_file']}): ")
         
     | 
| 513 | 
         
            -
                if new_path != "":
         
     | 
| 514 | 
         
            -
                    custom_settings["source_file"] = new_path
         
     | 
| 515 | 
         
            -
                if custom_settings["source_file"] == "":
         
     | 
| 516 | 
         
            -
                    print("Error : You need to provide a file in SD Upscale mode")
         
     | 
| 517 | 
         
            -
                    return
         
     | 
| 518 | 
         
            -
                custom_settings["target_file"] = None
         
     | 
| 519 | 
         
            -
                if option == 2:
         
     | 
| 520 | 
         
            -
                    custom_settings["target_file"] = input("Image to patch: ")
         
     | 
| 521 | 
         
            -
                    if custom_settings["target_file"] == "":
         
     | 
| 522 | 
         
            -
                        print("No target file provided, upscaling whole input image instead!")
         
     | 
| 523 | 
         
            -
                        custom_settings["target_file"] = None
         
     | 
| 524 | 
         
            -
                        option = 1
         
     | 
| 525 | 
         
            -
                custom_settings["output_format"] = config.generated_images.format
         
     | 
| 526 | 
         
            -
                custom_settings["strength"] = user_value(
         
     | 
| 527 | 
         
            -
                    float,
         
     | 
| 528 | 
         
            -
                    f"SD Upscale strength ({config.lcm_diffusion_setting.strength}): ",
         
     | 
| 529 | 
         
            -
                    config.lcm_diffusion_setting.strength,
         
     | 
| 530 | 
         
            -
                )
         
     | 
| 531 | 
         
            -
                config.lcm_diffusion_setting.inference_steps = int(
         
     | 
| 532 | 
         
            -
                    steps / custom_settings["strength"] + 1
         
     | 
| 533 | 
         
            -
                )
         
     | 
| 534 | 
         
            -
                if option == 1:
         
     | 
| 535 | 
         
            -
                    custom_settings["scale_factor"] = user_value(
         
     | 
| 536 | 
         
            -
                        float,
         
     | 
| 537 | 
         
            -
                        f"Scale factor (2.0): ",
         
     | 
| 538 | 
         
            -
                        2.0,
         
     | 
| 539 | 
         
            -
                    )
         
     | 
| 540 | 
         
            -
        custom_settings["tile_size"] = user_value(
            int,
            f"Split input image into tiles of the following size, in pixels (256): ",
            256,
        )
        custom_settings["tile_overlap"] = user_value(
            int,
            f"Tile overlap, in pixels (16): ",
            16,
        )
    elif option == 2:
        custom_settings["scale_factor"] = user_value(
            float,
            "Input image to Image-to-patch scale_factor (2.0): ",
            2.0,
        )
        custom_settings["tile_size"] = 256
        custom_settings["tile_overlap"] = 16
    custom_settings["prompt"] = input(
        "Write a prompt describing the input image (optional): "
    )
    custom_settings["tiles"] = []
    if option == 2:
        add_tile = True
        while add_tile:
            print("=== Define custom SD Upscale tile ===")
            tile_x = user_value(
                int,
                "Enter tile's X position: ",
                0,
            )
            tile_y = user_value(
                int,
                "Enter tile's Y position: ",
                0,
            )
            tile_w = user_value(
                int,
                "Enter tile's width (256): ",
                256,
            )
            tile_h = user_value(
                int,
                "Enter tile's height (256): ",
                256,
            )
            tile_scale = user_value(
                float,
                "Enter tile's scale factor (2.0): ",
                2.0,
            )
            tile_prompt = input("Enter tile's prompt (optional): ")
            custom_settings["tiles"].append(
                {
                    "x": tile_x,
                    "y": tile_y,
                    "w": tile_w,
                    "h": tile_h,
                    "mask_box": None,
                    "prompt": tile_prompt,
                    "scale_factor": tile_scale,
                }
            )
            tile_option = input("Do you want to define another tile? (y/N): ")
            if tile_option == "" or tile_option.upper() == "N":
                add_tile = False

    return custom_settings
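For reference, this is the shape of the `custom_settings` dictionary the function above returns, following the option 2 (custom tiles) path; the `source_file` key is set earlier in the function, outside this excerpt, and all concrete values here are illustrative only:

    custom_settings = {
        "source_file": "input.png",  # assumed: set earlier in the function, not shown here
        "scale_factor": 2.0,
        "tile_size": 256,
        "tile_overlap": 16,
        "prompt": "a sharp photo of a castle",
        "tiles": [
            {
                "x": 0,
                "y": 0,
                "w": 256,
                "h": 256,
                "mask_box": None,
                "prompt": "castle tower, detailed stonework",
                "scale_factor": 2.0,
            },
        ],
    }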
         

def interactive_sdupscale(
    config,
    context,
):
    settings = config.lcm_diffusion_setting
    settings.diffusion_task = DiffusionTask.image_to_image.value
    settings.init_image = ""
    source_path = ""
    steps = settings.inference_steps

    while True:
        custom_upscale_settings = None
        option = input("Edit custom SD Upscale settings? (y/N): ")
        if option.upper() == "Y":
            config.lcm_diffusion_setting.inference_steps = steps
            custom_upscale_settings = interactive_sdupscale_settings(config)
            if not custom_upscale_settings:
                return
            source_path = custom_upscale_settings["source_file"]
        else:
            new_path = input(f"Image path ({source_path}): ")
            if new_path != "":
                source_path = new_path
            if source_path == "":
                print("Error : You need to provide a file in SD Upscale mode")
                return
            settings.strength = user_value(
                float,
                f"SD Upscale strength ({settings.strength}): ",
                settings.strength,
            )
            settings.inference_steps = int(steps / settings.strength + 1)

        output_path = FastStableDiffusionPaths.get_upscale_filepath(
            source_path,
            2,
            config.generated_images.format,
        )
        generate_upscaled_image(
            config,
            source_path,
            settings.strength,
            upscale_settings=custom_upscale_settings,
            context=context,
            tile_overlap=32 if settings.use_openvino else 16,
            output_path=output_path,
            image_format=config.generated_images.format,
        )
        user_input = input("Continue in SD Upscale mode? (Y/n): ")
        if user_input.upper() == "N":
            settings.inference_steps = steps
            return
src/frontend/gui/app_window.py
DELETED
@@ -1,595 +0,0 @@
from datetime import datetime

from app_settings import AppSettings
from backend.models.lcmdiffusion_setting import DiffusionTask
from constants import (
    APP_NAME,
    APP_VERSION,
    LCM_DEFAULT_MODEL,
    LCM_DEFAULT_MODEL_OPENVINO,
)
from context import Context
from frontend.gui.image_variations_widget import ImageVariationsWidget
from frontend.gui.upscaler_widget import UpscalerWidget
from frontend.gui.img2img_widget import Img2ImgWidget
from frontend.utils import (
    enable_openvino_controls,
    get_valid_model_id,
    is_reshape_required,
)
from paths import FastStableDiffusionPaths
from PyQt5 import QtCore, QtWidgets
from PyQt5.QtCore import QSize, Qt, QThreadPool, QUrl
from PyQt5.QtGui import QDesktopServices
from PyQt5.QtWidgets import (
    QCheckBox,
    QComboBox,
    QFileDialog,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMainWindow,
    QPushButton,
    QSizePolicy,
    QSlider,
    QSpacerItem,
    QTabWidget,
    QToolButton,
    QVBoxLayout,
    QWidget,
)

from models.interface_types import InterfaceType
from frontend.gui.base_widget import BaseWidget

# DPI scale fix
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
QtWidgets.QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)


class MainWindow(QMainWindow):
    settings_changed = QtCore.pyqtSignal()
    """This signal is used for enabling/disabling the negative prompt field for
    modes that support it; in particular, negative prompt is supported with
    OpenVINO models and in LCM-LoRA mode but not in LCM mode.
    """

    def __init__(self, config: AppSettings):
        super().__init__()
        self.config = config
        # Prevent saved LoRA and ControlNet settings from being used by
        # default; in GUI mode, the user must explicitly enable those
        if self.config.settings.lcm_diffusion_setting.lora:
            self.config.settings.lcm_diffusion_setting.lora.enabled = False
        if self.config.settings.lcm_diffusion_setting.controlnet:
            self.config.settings.lcm_diffusion_setting.controlnet.enabled = False
        self.setWindowTitle(APP_NAME)
        self.setFixedSize(QSize(600, 670))
        self.init_ui()
        self.pipeline = None
        self.threadpool = QThreadPool()
        self.device = "cpu"
        self.previous_width = 0
        self.previous_height = 0
        self.previous_model = ""
        self.previous_num_of_images = 0
        self.context = Context(InterfaceType.GUI)
        self.init_ui_values()
        self.gen_images = []
        self.image_index = 0
        print(f"Output path : {self.config.settings.generated_images.path}")

    def init_ui_values(self):
        self.lcm_model.setEnabled(
            not self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.guidance.setValue(
            int(self.config.settings.lcm_diffusion_setting.guidance_scale * 10)
        )
        self.seed_value.setEnabled(self.config.settings.lcm_diffusion_setting.use_seed)
        self.safety_checker.setChecked(
            self.config.settings.lcm_diffusion_setting.use_safety_checker
        )
        self.use_openvino_check.setChecked(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )
        self.width.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_width)
        )
        self.height.setCurrentText(
            str(self.config.settings.lcm_diffusion_setting.image_height)
        )
        self.inference_steps.setValue(
            int(self.config.settings.lcm_diffusion_setting.inference_steps)
        )
        self.clip_skip.setValue(
            int(self.config.settings.lcm_diffusion_setting.clip_skip)
        )
        self.token_merging.setValue(
            int(self.config.settings.lcm_diffusion_setting.token_merging * 100)
        )
        self.seed_check.setChecked(self.config.settings.lcm_diffusion_setting.use_seed)
        self.seed_value.setText(str(self.config.settings.lcm_diffusion_setting.seed))
        self.use_local_model_folder.setChecked(
            self.config.settings.lcm_diffusion_setting.use_offline_model
        )
        self.results_path.setText(self.config.settings.generated_images.path)
        self.num_images.setValue(
            self.config.settings.lcm_diffusion_setting.number_of_images
        )
        self.use_tae_sd.setChecked(
            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder
        )
        self.use_lcm_lora.setChecked(
            self.config.settings.lcm_diffusion_setting.use_lcm_lora
        )
        self.lcm_model.setCurrentText(
            get_valid_model_id(
                self.config.lcm_models,
                self.config.settings.lcm_diffusion_setting.lcm_model_id,
                LCM_DEFAULT_MODEL,
            )
        )
        self.base_model_id.setCurrentText(
            get_valid_model_id(
                self.config.stable_diffsuion_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id,
            )
        )
        self.lcm_lora_id.setCurrentText(
            get_valid_model_id(
                self.config.lcm_lora_models,
                self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id,
            )
        )
        self.openvino_lcm_model_id.setCurrentText(
            get_valid_model_id(
                self.config.openvino_lcm_models,
                self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id,
                LCM_DEFAULT_MODEL_OPENVINO,
            )
        )
        self.openvino_lcm_model_id.setEnabled(
            self.config.settings.lcm_diffusion_setting.use_openvino
        )
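A note on the slider values above: `init_ui_values` writes `guidance_scale * 10` into the `guidance` slider, and the slider (created below with range 10-20) therefore stores the scale in tenths, covering 1.0-2.0. The matching `update_guidance_label` handler is connected below but lies outside this excerpt; a hypothetical sketch of the inverse mapping it presumably performs:

    def update_guidance_label(self, value: int):
        # Hypothetical sketch; the real handler is not part of this excerpt.
        guidance = value / 10.0  # slider 10..20 maps to guidance scale 1.0..2.0
        self.guidance_value.setText(f"Guidance scale: {guidance}")
        self.config.settings.lcm_diffusion_setting.guidance_scale = guidance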
             
     | 
| 156 | 
         
            -
                def init_ui(self):
         
     | 
| 157 | 
         
            -
                    self.create_main_tab()
         
     | 
| 158 | 
         
            -
                    self.create_settings_tab()
         
     | 
| 159 | 
         
            -
                    self.create_about_tab()
         
     | 
| 160 | 
         
            -
                    self.show()
         
     | 
| 161 | 
         
            -
             
     | 
| 162 | 
         
            -
                def create_main_tab(self):
         
     | 
| 163 | 
         
            -
                    self.tab_widget = QTabWidget(self)
         
     | 
| 164 | 
         
            -
                    self.tab_main = BaseWidget(self.config, self)
         
     | 
| 165 | 
         
            -
                    self.tab_settings = QWidget()
         
     | 
| 166 | 
         
            -
                    self.tab_about = QWidget()
         
     | 
| 167 | 
         
            -
                    self.img2img_tab = Img2ImgWidget(self.config, self)
         
     | 
| 168 | 
         
            -
                    self.variations_tab = ImageVariationsWidget(self.config, self)
         
     | 
| 169 | 
         
            -
                    self.upscaler_tab = UpscalerWidget(self.config, self)
         
     | 
| 170 | 
         
            -
             
     | 
| 171 | 
         
            -
                    # Add main window tabs here
         
     | 
| 172 | 
         
            -
                    self.tab_widget.addTab(self.tab_main, "Text to Image")
         
     | 
| 173 | 
         
            -
                    self.tab_widget.addTab(self.img2img_tab, "Image to Image")
         
     | 
| 174 | 
         
            -
                    self.tab_widget.addTab(self.variations_tab, "Image Variations")
         
     | 
| 175 | 
         
            -
                    self.tab_widget.addTab(self.upscaler_tab, "Upscaler")
         
     | 
| 176 | 
         
            -
                    self.tab_widget.addTab(self.tab_settings, "Settings")
         
     | 
| 177 | 
         
            -
                    self.tab_widget.addTab(self.tab_about, "About")
         
     | 
| 178 | 
         
            -
             
     | 
| 179 | 
         
            -
                    self.setCentralWidget(self.tab_widget)
         
     | 
| 180 | 
         
            -
                    self.use_seed = False
         
     | 
| 181 | 
         
            -
             
     | 
| 182 | 
         
            -
                def create_settings_tab(self):
         
     | 
| 183 | 
         
            -
                    self.lcm_model_label = QLabel("Latent Consistency Model:")
         
     | 
| 184 | 
         
            -
                    # self.lcm_model = QLineEdit(LCM_DEFAULT_MODEL)
         
     | 
| 185 | 
         
            -
                    self.lcm_model = QComboBox(self)
         
     | 
| 186 | 
         
            -
                    self.lcm_model.addItems(self.config.lcm_models)
         
     | 
| 187 | 
         
            -
                    self.lcm_model.currentIndexChanged.connect(self.on_lcm_model_changed)
         
     | 
| 188 | 
         
            -
             
     | 
| 189 | 
         
            -
                    self.use_lcm_lora = QCheckBox("Use LCM LoRA")
         
     | 
| 190 | 
         
            -
                    self.use_lcm_lora.setChecked(False)
         
     | 
| 191 | 
         
            -
                    self.use_lcm_lora.stateChanged.connect(self.use_lcm_lora_changed)
         
     | 
| 192 | 
         
            -
             
     | 
| 193 | 
         
            -
                    self.lora_base_model_id_label = QLabel("Lora base model ID :")
         
     | 
| 194 | 
         
            -
                    self.base_model_id = QComboBox(self)
         
     | 
| 195 | 
         
            -
                    self.base_model_id.addItems(self.config.stable_diffsuion_models)
         
     | 
| 196 | 
         
            -
                    self.base_model_id.currentIndexChanged.connect(self.on_base_model_id_changed)
         
     | 
| 197 | 
         
            -
             
     | 
| 198 | 
         
            -
                    self.lcm_lora_model_id_label = QLabel("LCM LoRA model ID :")
         
     | 
| 199 | 
         
            -
                    self.lcm_lora_id = QComboBox(self)
         
     | 
| 200 | 
         
            -
                    self.lcm_lora_id.addItems(self.config.lcm_lora_models)
         
     | 
| 201 | 
         
            -
                    self.lcm_lora_id.currentIndexChanged.connect(self.on_lcm_lora_id_changed)
         
     | 
| 202 | 
         
            -
             
     | 
| 203 | 
         
            -
                    self.inference_steps_value = QLabel("Number of inference steps: 4")
         
     | 
| 204 | 
         
            -
                    self.inference_steps = QSlider(orientation=Qt.Orientation.Horizontal)
         
     | 
| 205 | 
         
            -
                    self.inference_steps.setMaximum(25)
         
     | 
| 206 | 
         
            -
                    self.inference_steps.setMinimum(1)
         
     | 
| 207 | 
         
            -
                    self.inference_steps.setValue(4)
         
     | 
| 208 | 
         
            -
                    self.inference_steps.valueChanged.connect(self.update_steps_label)
         
     | 
| 209 | 
         
            -
             
     | 
| 210 | 
         
            -
                    self.num_images_value = QLabel("Number of images: 1")
         
     | 
| 211 | 
         
            -
                    self.num_images = QSlider(orientation=Qt.Orientation.Horizontal)
         
     | 
| 212 | 
         
            -
                    self.num_images.setMaximum(100)
         
     | 
| 213 | 
         
            -
                    self.num_images.setMinimum(1)
         
     | 
| 214 | 
         
            -
                    self.num_images.setValue(1)
         
     | 
| 215 | 
         
            -
                    self.num_images.valueChanged.connect(self.update_num_images_label)
         
     | 
| 216 | 
         
            -
             
     | 
| 217 | 
         
            -
                    self.guidance_value = QLabel("Guidance scale: 1")
         
     | 
| 218 | 
         
            -
                    self.guidance = QSlider(orientation=Qt.Orientation.Horizontal)
         
     | 
| 219 | 
         
            -
                    self.guidance.setMaximum(20)
         
     | 
| 220 | 
         
            -
                    self.guidance.setMinimum(10)
         
     | 
| 221 | 
         
            -
                    self.guidance.setValue(10)
         
     | 
| 222 | 
         
            -
                    self.guidance.valueChanged.connect(self.update_guidance_label)
         
     | 
| 223 | 
         
            -
             
     | 
| 224 | 
         
            -
                    self.clip_skip_value = QLabel("CLIP Skip: 1")
         
     | 
| 225 | 
         
            -
                    self.clip_skip = QSlider(orientation=Qt.Orientation.Horizontal)
         
     | 
| 226 | 
         
            -
                    self.clip_skip.setMaximum(12)
         
     | 
| 227 | 
         
            -
                    self.clip_skip.setMinimum(1)
         
     | 
| 228 | 
         
            -
                    self.clip_skip.setValue(1)
         
     | 
| 229 | 
         
            -
                    self.clip_skip.valueChanged.connect(self.update_clip_skip_label)
         
     | 
| 230 | 
         
            -
             
     | 
| 231 | 
         
            -
                    self.token_merging_value = QLabel("Token Merging: 0")
         
     | 
| 232 | 
         
            -
                    self.token_merging = QSlider(orientation=Qt.Orientation.Horizontal)
         
     | 
| 233 | 
         
            -
                    self.token_merging.setMaximum(100)
         
     | 
| 234 | 
         
            -
                    self.token_merging.setMinimum(0)
         
     | 
| 235 | 
         
            -
                    self.token_merging.setValue(0)
         
     | 
| 236 | 
         
            -
                    self.token_merging.valueChanged.connect(self.update_token_merging_label)
         
     | 
| 237 | 
         
            -
             
     | 
| 238 | 
         
            -
                    self.width_value = QLabel("Width :")
         
     | 
| 239 | 
         
            -
                    self.width = QComboBox(self)
         
     | 
| 240 | 
         
            -
                    self.width.addItem("256")
         
     | 
| 241 | 
         
            -
                    self.width.addItem("512")
         
     | 
| 242 | 
         
            -
                    self.width.addItem("768")
         
     | 
| 243 | 
         
            -
                    self.width.addItem("1024")
         
     | 
| 244 | 
         
            -
                    self.width.setCurrentText("512")
         
     | 
| 245 | 
         
            -
                    self.width.currentIndexChanged.connect(self.on_width_changed)
         
     | 
| 246 | 
         
            -
             
     | 
| 247 | 
         
            -
                    self.height_value = QLabel("Height :")
         
     | 
| 248 | 
         
            -
                    self.height = QComboBox(self)
         
     | 
| 249 | 
         
            -
                    self.height.addItem("256")
         
     | 
| 250 | 
         
            -
                    self.height.addItem("512")
         
     | 
| 251 | 
         
            -
                    self.height.addItem("768")
         
     | 
| 252 | 
         
            -
                    self.height.addItem("1024")
         
     | 
| 253 | 
         
            -
                    self.height.setCurrentText("512")
         
     | 
| 254 | 
         
            -
                    self.height.currentIndexChanged.connect(self.on_height_changed)
         
     | 
| 255 | 
         
            -
             
     | 
| 256 | 
         
            -
                    self.seed_check = QCheckBox("Use seed")
         
     | 
| 257 | 
         
            -
                    self.seed_value = QLineEdit()
         
     | 
| 258 | 
         
            -
                    self.seed_value.setInputMask("9999999999")
         
     | 
| 259 | 
         
            -
                    self.seed_value.setText("123123")
         
     | 
| 260 | 
         
            -
                    self.seed_check.stateChanged.connect(self.seed_changed)
         
     | 
| 261 | 
         
            -
             
     | 
| 262 | 
         
            -
                    self.safety_checker = QCheckBox("Use safety checker")
         
     | 
| 263 | 
         
            -
                    self.safety_checker.setChecked(True)
         
     | 
| 264 | 
         
            -
                    self.safety_checker.stateChanged.connect(self.use_safety_checker_changed)
         
     | 
| 265 | 
         
            -
             
     | 
| 266 | 
         
            -
                    self.use_openvino_check = QCheckBox("Use OpenVINO")
         
     | 
| 267 | 
         
            -
                    self.use_openvino_check.setChecked(False)
         
     | 
| 268 | 
         
            -
                    self.openvino_model_label = QLabel("OpenVINO LCM model:")
         
     | 
| 269 | 
         
            -
                    self.use_local_model_folder = QCheckBox(
         
     | 
| 270 | 
         
            -
                        "Use locally cached model or downloaded model folder(offline)"
         
     | 
| 271 | 
         
            -
                    )
         
     | 
| 272 | 
         
            -
                    self.openvino_lcm_model_id = QComboBox(self)
         
     | 
| 273 | 
         
            -
                    self.openvino_lcm_model_id.addItems(self.config.openvino_lcm_models)
         
     | 
| 274 | 
         
            -
                    self.openvino_lcm_model_id.currentIndexChanged.connect(
         
     | 
| 275 | 
         
            -
                        self.on_openvino_lcm_model_id_changed
         
     | 
| 276 | 
         
            -
                    )
         
     | 
| 277 | 
         
            -
             
     | 
| 278 | 
         
            -
                    self.use_openvino_check.setEnabled(enable_openvino_controls())
         
     | 
| 279 | 
         
            -
                    self.use_local_model_folder.setChecked(False)
         
     | 
| 280 | 
         
            -
                    self.use_local_model_folder.stateChanged.connect(self.use_offline_model_changed)
         
     | 
| 281 | 
         
            -
                    self.use_openvino_check.stateChanged.connect(self.use_openvino_changed)
         
     | 
| 282 | 
         
            -
             
     | 
| 283 | 
         
            -
                    self.use_tae_sd = QCheckBox(
         
     | 
| 284 | 
         
            -
                        "Use Tiny Auto Encoder - TAESD (Fast, moderate quality)"
         
     | 
| 285 | 
         
            -
                    )
         
     | 
| 286 | 
         
            -
                    self.use_tae_sd.setChecked(False)
         
     | 
| 287 | 
         
            -
                    self.use_tae_sd.stateChanged.connect(self.use_tae_sd_changed)
         
     | 
| 288 | 
         
            -
             
     | 
| 289 | 
         
            -
                    hlayout = QHBoxLayout()
         
     | 
| 290 | 
         
            -
                    hlayout.addWidget(self.seed_check)
         
     | 
| 291 | 
         
            -
                    hlayout.addWidget(self.seed_value)
         
     | 
| 292 | 
         
            -
                    hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
         
     | 
| 293 | 
         
            -
                    slider_hspacer = QSpacerItem(20, 10, QSizePolicy.Expanding, QSizePolicy.Minimum)
         
     | 
| 294 | 
         
            -
             
     | 
| 295 | 
         
            -
                    self.results_path_label = QLabel("Output path:")
         
     | 
| 296 | 
         
            -
                    self.results_path = QLineEdit()
         
     | 
| 297 | 
         
            -
                    self.results_path.textChanged.connect(self.on_path_changed)
         
     | 
| 298 | 
         
            -
                    self.browse_folder_btn = QToolButton()
         
     | 
| 299 | 
         
            -
                    self.browse_folder_btn.setText("...")
         
     | 
| 300 | 
         
            -
                    self.browse_folder_btn.clicked.connect(self.on_browse_folder)
         
     | 
| 301 | 
         
            -
             
     | 
| 302 | 
         
            -
                    self.reset = QPushButton("Reset All")
         
     | 
| 303 | 
         
            -
                    self.reset.clicked.connect(self.reset_all_settings)
         
     | 
| 304 | 
         
            -
             
     | 
| 305 | 
         
            -
                    vlayout = QVBoxLayout()
         
     | 
| 306 | 
         
            -
                    vspacer = QSpacerItem(20, 20, QSizePolicy.Minimum, QSizePolicy.Expanding)
         
     | 
| 307 | 
         
            -
                    vlayout.addItem(hspacer)
         
     | 
| 308 | 
         
            -
                    vlayout.setSpacing(3)
         
     | 
| 309 | 
         
            -
                    vlayout.addWidget(self.lcm_model_label)
         
     | 
| 310 | 
         
            -
                    vlayout.addWidget(self.lcm_model)
         
     | 
| 311 | 
         
            -
                    vlayout.addWidget(self.use_local_model_folder)
         
     | 
| 312 | 
         
            -
                    vlayout.addWidget(self.use_lcm_lora)
         
     | 
| 313 | 
         
            -
                    vlayout.addWidget(self.lora_base_model_id_label)
         
     | 
| 314 | 
         
            -
                    vlayout.addWidget(self.base_model_id)
         
     | 
| 315 | 
         
            -
                    vlayout.addWidget(self.lcm_lora_model_id_label)
         
     | 
| 316 | 
         
            -
                    vlayout.addWidget(self.lcm_lora_id)
         
     | 
| 317 | 
         
            -
                    vlayout.addWidget(self.use_openvino_check)
         
     | 
| 318 | 
         
            -
                    vlayout.addWidget(self.openvino_model_label)
         
     | 
| 319 | 
         
            -
                    vlayout.addWidget(self.openvino_lcm_model_id)
         
     | 
| 320 | 
         
            -
                    vlayout.addWidget(self.use_tae_sd)
         
     | 
| 321 | 
         
            -
                    vlayout.addItem(slider_hspacer)
         
     | 
| 322 | 
         
            -
                    vlayout.addWidget(self.inference_steps_value)
         
     | 
| 323 | 
         
            -
                    vlayout.addWidget(self.inference_steps)
         
     | 
| 324 | 
         
            -
                    vlayout.addWidget(self.num_images_value)
         
     | 
| 325 | 
         
            -
                    vlayout.addWidget(self.num_images)
         
     | 
| 326 | 
         
            -
                    vlayout.addWidget(self.width_value)
         
     | 
| 327 | 
         
            -
                    vlayout.addWidget(self.width)
         
     | 
| 328 | 
         
            -
                    vlayout.addWidget(self.height_value)
         
     | 
| 329 | 
         
            -
                    vlayout.addWidget(self.height)
         
     | 
| 330 | 
         
-        vlayout.addWidget(self.guidance_value)
-        vlayout.addWidget(self.guidance)
-        vlayout.addWidget(self.clip_skip_value)
-        vlayout.addWidget(self.clip_skip)
-        vlayout.addWidget(self.token_merging_value)
-        vlayout.addWidget(self.token_merging)
-        vlayout.addLayout(hlayout)
-        vlayout.addWidget(self.safety_checker)
-
-        vlayout.addWidget(self.results_path_label)
-        hlayout_path = QHBoxLayout()
-        hlayout_path.addWidget(self.results_path)
-        hlayout_path.addWidget(self.browse_folder_btn)
-        vlayout.addLayout(hlayout_path)
-        self.tab_settings.setLayout(vlayout)
-        hlayout_reset = QHBoxLayout()
-        hspacer = QSpacerItem(20, 20, QSizePolicy.Expanding, QSizePolicy.Minimum)
-        hlayout_reset.addItem(hspacer)
-        hlayout_reset.addWidget(self.reset)
-        vlayout.addLayout(hlayout_reset)
-        vlayout.addItem(vspacer)
-
-    def create_about_tab(self):
-        self.label = QLabel()
-        self.label.setAlignment(Qt.AlignCenter)
-        current_year = datetime.now().year
-        self.label.setText(
-            f"""<h1>FastSD CPU {APP_VERSION}</h1>
-               <h3>(c)2023 - {current_year} Rupesh Sreeraman</h3>
-                <h3>Faster stable diffusion on CPU</h3>
-                 <h3>Based on Latent Consistency Models</h3>
-                <h3>GitHub : https://github.com/rupeshs/fastsdcpu/</h3>"""
-        )
-
-        vlayout = QVBoxLayout()
-        vlayout.addWidget(self.label)
-        self.tab_about.setLayout(vlayout)
-
-    def show_image(self, pixmap):
-        image_width = self.config.settings.lcm_diffusion_setting.image_width
-        image_height = self.config.settings.lcm_diffusion_setting.image_height
-        if image_width > 512 or image_height > 512:
-            new_width = 512 if image_width > 512 else image_width
-            new_height = 512 if image_height > 512 else image_height
-            self.img.setPixmap(
-                pixmap.scaled(
-                    new_width,
-                    new_height,
-                    Qt.KeepAspectRatio,
-                )
-            )
-        else:
-            self.img.setPixmap(pixmap)
-
-    def on_show_next_image(self):
-        if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
-            self.previous_img_btn.setEnabled(True)
-            self.image_index += 1
-            self.show_image(self.gen_images[self.image_index])
-            if self.image_index == len(self.gen_images) - 1:
-                self.next_img_btn.setEnabled(False)
-
-    def on_open_results_folder(self):
-        QDesktopServices.openUrl(
-            QUrl.fromLocalFile(self.config.settings.generated_images.path)
-        )
-
-    def on_show_previous_image(self):
-        if self.image_index != 0:
-            self.next_img_btn.setEnabled(True)
-            self.image_index -= 1
-            self.show_image(self.gen_images[self.image_index])
-            if self.image_index == 0:
-                self.previous_img_btn.setEnabled(False)
-
-    def on_path_changed(self, text):
-        self.config.settings.generated_images.path = text
-
-    def on_browse_folder(self):
-        options = QFileDialog.Options()
-        options |= QFileDialog.ShowDirsOnly
-
-        folder_path = QFileDialog.getExistingDirectory(
-            self, "Select a Folder", "", options=options
-        )
-
-        if folder_path:
-            self.config.settings.generated_images.path = folder_path
-            self.results_path.setText(folder_path)
-
-    def on_width_changed(self, index):
-        width_txt = self.width.itemText(index)
-        self.config.settings.lcm_diffusion_setting.image_width = int(width_txt)
-
-    def on_height_changed(self, index):
-        height_txt = self.height.itemText(index)
-        self.config.settings.lcm_diffusion_setting.image_height = int(height_txt)
-
-    def on_lcm_model_changed(self, index):
-        model_id = self.lcm_model.itemText(index)
-        self.config.settings.lcm_diffusion_setting.lcm_model_id = model_id
-
-    def on_base_model_id_changed(self, index):
-        model_id = self.base_model_id.itemText(index)
-        self.config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = model_id
-
-    def on_lcm_lora_id_changed(self, index):
-        model_id = self.lcm_lora_id.itemText(index)
-        self.config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = model_id
-
-    def on_openvino_lcm_model_id_changed(self, index):
-        model_id = self.openvino_lcm_model_id.itemText(index)
-        self.config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
-
-    def use_openvino_changed(self, state):
-        if state == 2:
-            self.lcm_model.setEnabled(False)
-            self.use_lcm_lora.setEnabled(False)
-            self.lcm_lora_id.setEnabled(False)
-            self.base_model_id.setEnabled(False)
-            self.openvino_lcm_model_id.setEnabled(True)
-            self.config.settings.lcm_diffusion_setting.use_openvino = True
-        else:
-            self.lcm_model.setEnabled(True)
-            self.use_lcm_lora.setEnabled(True)
-            self.lcm_lora_id.setEnabled(True)
-            self.base_model_id.setEnabled(True)
-            self.openvino_lcm_model_id.setEnabled(False)
-            self.config.settings.lcm_diffusion_setting.use_openvino = False
-        self.settings_changed.emit()
-
-    def use_tae_sd_changed(self, state):
-        if state == 2:
-            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = True
-        else:
-            self.config.settings.lcm_diffusion_setting.use_tiny_auto_encoder = False
-
-    def use_offline_model_changed(self, state):
-        if state == 2:
-            self.config.settings.lcm_diffusion_setting.use_offline_model = True
-        else:
-            self.config.settings.lcm_diffusion_setting.use_offline_model = False
-
-    def use_lcm_lora_changed(self, state):
-        if state == 2:
-            self.lcm_model.setEnabled(False)
-            self.lcm_lora_id.setEnabled(True)
-            self.base_model_id.setEnabled(True)
-            self.config.settings.lcm_diffusion_setting.use_lcm_lora = True
-        else:
-            self.lcm_model.setEnabled(True)
-            self.lcm_lora_id.setEnabled(False)
-            self.base_model_id.setEnabled(False)
-            self.config.settings.lcm_diffusion_setting.use_lcm_lora = False
-        self.settings_changed.emit()
-
-    def update_clip_skip_label(self, value):
-        self.clip_skip_value.setText(f"CLIP Skip: {value}")
-        self.config.settings.lcm_diffusion_setting.clip_skip = value
-
-    def update_token_merging_label(self, value):
-        val = round(int(value) / 100, 1)
-        self.token_merging_value.setText(f"Token Merging: {val}")
-        self.config.settings.lcm_diffusion_setting.token_merging = val
-
-    def use_safety_checker_changed(self, state):
-        if state == 2:
-            self.config.settings.lcm_diffusion_setting.use_safety_checker = True
-        else:
-            self.config.settings.lcm_diffusion_setting.use_safety_checker = False
-
-    def update_steps_label(self, value):
-        self.inference_steps_value.setText(f"Number of inference steps: {value}")
-        self.config.settings.lcm_diffusion_setting.inference_steps = value
-
-    def update_num_images_label(self, value):
-        self.num_images_value.setText(f"Number of images: {value}")
-        self.config.settings.lcm_diffusion_setting.number_of_images = value
-
-    def update_guidance_label(self, value):
-        val = round(int(value) / 10, 1)
-        self.guidance_value.setText(f"Guidance scale: {val}")
-        self.config.settings.lcm_diffusion_setting.guidance_scale = val
-
-    def seed_changed(self, state):
-        if state == 2:
-            self.seed_value.setEnabled(True)
-            self.config.settings.lcm_diffusion_setting.use_seed = True
-        else:
-            self.seed_value.setEnabled(False)
-            self.config.settings.lcm_diffusion_setting.use_seed = False
-
-    def get_seed_value(self) -> int:
-        use_seed = self.config.settings.lcm_diffusion_setting.use_seed
-        seed_value = int(self.seed_value.text()) if use_seed else -1
-        return seed_value
-
-    # def text_to_image(self):
-    #    self.img.setText("Please wait...")
-    #    worker = ImageGeneratorWorker(self.generate_image)
-    #    self.threadpool.start(worker)
-
-    def closeEvent(self, event):
-        self.config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
-        print(self.config.settings.lcm_diffusion_setting)
-        print("Saving settings")
-        self.config.save()
-
-    def reset_all_settings(self):
-        self.use_local_model_folder.setChecked(False)
-        self.width.setCurrentText("512")
-        self.height.setCurrentText("512")
-        self.inference_steps.setValue(4)
-        self.guidance.setValue(10)
-        self.clip_skip.setValue(1)
-        self.token_merging.setValue(0)
-        self.use_openvino_check.setChecked(False)
-        self.seed_check.setChecked(False)
-        self.safety_checker.setChecked(False)
-        self.results_path.setText(FastStableDiffusionPaths().get_results_path())
-        self.use_tae_sd.setChecked(False)
-        self.use_lcm_lora.setChecked(False)
-
-    def prepare_generation_settings(self, config):
-        """Populate config settings with the values set by the user in the GUI"""
-        config.settings.lcm_diffusion_setting.seed = self.get_seed_value()
-        config.settings.lcm_diffusion_setting.lcm_lora.lcm_lora_id = (
-            self.lcm_lora_id.currentText()
-        )
-        config.settings.lcm_diffusion_setting.lcm_lora.base_model_id = (
-            self.base_model_id.currentText()
-        )
-
-        if config.settings.lcm_diffusion_setting.use_openvino:
-            model_id = self.openvino_lcm_model_id.currentText()
-            config.settings.lcm_diffusion_setting.openvino_lcm_model_id = model_id
-        else:
-            model_id = self.lcm_model.currentText()
-            config.settings.lcm_diffusion_setting.lcm_model_id = model_id
-
-        config.reshape_required = False
-        config.model_id = model_id
-        if config.settings.lcm_diffusion_setting.use_openvino:
-            # Detect dimension change
-            config.reshape_required = is_reshape_required(
-                self.previous_width,
-                config.settings.lcm_diffusion_setting.image_width,
-                self.previous_height,
-                config.settings.lcm_diffusion_setting.image_height,
-                self.previous_model,
-                model_id,
-                self.previous_num_of_images,
-                config.settings.lcm_diffusion_setting.number_of_images,
-            )
-        config.settings.lcm_diffusion_setting.diffusion_task = (
-            DiffusionTask.text_to_image.value
-        )
-
-    def store_dimension_settings(self):
-        """These values are only needed for OpenVINO model reshape"""
-        self.previous_width = self.config.settings.lcm_diffusion_setting.image_width
-        self.previous_height = self.config.settings.lcm_diffusion_setting.image_height
-        self.previous_model = self.config.model_id
-        self.previous_num_of_images = (
-            self.config.settings.lcm_diffusion_setting.number_of_images
-        )
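Note on the slider handlers above: Qt sliders only emit integers, so the deleted callbacks store scaled floats — update_guidance_label divides the raw value by 10 and update_token_merging_label divides it by 100 before writing into lcm_diffusion_setting. Below is a minimal standalone sketch of the same integer-to-float slider pattern, assuming only PyQt5; the widget names are illustrative, not taken from the app.

import sys

from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QLabel, QSlider, QVBoxLayout, QWidget

app = QApplication(sys.argv)
window = QWidget()

# Integer slider 10..300 stands for guidance scale 1.0..30.0,
# mirroring update_guidance_label's value / 10 mapping.
label = QLabel("Guidance scale: 1.0")
slider = QSlider(Qt.Horizontal)
slider.setRange(10, 300)

def on_changed(value: int) -> None:
    # Same conversion the deleted handler used: round(value / 10, 1)
    label.setText(f"Guidance scale: {round(value / 10, 1)}")

slider.valueChanged.connect(on_changed)

layout = QVBoxLayout(window)
layout.addWidget(label)
layout.addWidget(slider)
window.show()
sys.exit(app.exec_())

Keeping the scaling inside the handler means the settings object always holds the user-facing float, so nothing downstream needs to know about the slider's integer range.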
src/frontend/gui/base_widget.py
DELETED
@@ -1,199 +0,0 @@
-from PIL.ImageQt import ImageQt
-from PyQt5 import QtCore
-from PyQt5.QtCore import QSize, Qt, QUrl
-from PyQt5.QtGui import (
-    QDesktopServices,
-    QPixmap,
-)
-from PyQt5.QtWidgets import (
-    QApplication,
-    QHBoxLayout,
-    QLabel,
-    QPushButton,
-    QSizePolicy,
-    QTextEdit,
-    QToolButton,
-    QVBoxLayout,
-    QWidget,
-)
-
-from app_settings import AppSettings
-from constants import DEVICE
-from frontend.gui.image_generator_worker import ImageGeneratorWorker
-
-
-class ImageLabel(QLabel):
-    """Defines a simple QLabel widget"""
-
-    changed = QtCore.pyqtSignal()
-
-    def __init__(self, text: str):
-        super().__init__(text)
-        self.setAlignment(Qt.AlignCenter)
-        self.resize(512, 512)
-        self.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.MinimumExpanding)
-        self.sizeHint = QSize(512, 512)
-        self.setAcceptDrops(False)
-
-    def show_image(self, pixmap: QPixmap = None):
-        """Updates the widget pixmap"""
-        if pixmap is None or pixmap.isNull():
-            return
-        self.current_pixmap = pixmap
-        self.changed.emit()
-
-        # Resize the pixmap to the widget dimensions
-        image_width = self.current_pixmap.width()
-        image_height = self.current_pixmap.height()
-        if image_width > 512 or image_height > 512:
-            new_width = 512 if image_width > 512 else image_width
-            new_height = 512 if image_height > 512 else image_height
-            self.setPixmap(
-                self.current_pixmap.scaled(
-                    new_width,
-                    new_height,
-                    Qt.KeepAspectRatio,
-                )
-            )
-        else:
-            self.setPixmap(self.current_pixmap)
-
-
-class BaseWidget(QWidget):
-    def __init__(self, config: AppSettings, parent):
-        super().__init__()
-        self.config = config
-        self.gen_images = []
-        self.image_index = 0
-        self.parent = parent
-
-        # Initialize GUI widgets
-        self.prev_btn = QToolButton()
-        self.prev_btn.setText("<")
-        self.prev_btn.clicked.connect(self.on_show_previous_image)
-        self.img = ImageLabel("<<Image>>")
-        self.next_btn = QToolButton()
-        self.next_btn.setText(">")
-        self.next_btn.clicked.connect(self.on_show_next_image)
-        self.prompt = QTextEdit()
-        self.prompt.setPlaceholderText("A fantasy landscape")
-        self.prompt.setAcceptRichText(False)
-        self.prompt.setFixedHeight(40)
-        self.neg_prompt = QTextEdit()
-        self.neg_prompt.setPlaceholderText("")
-        self.neg_prompt.setAcceptRichText(False)
-        self.neg_prompt_label = QLabel("Negative prompt (Set guidance scale > 1.0):")
-        self.neg_prompt.setFixedHeight(35)
-        self.neg_prompt.setEnabled(False)
-        self.generate = QPushButton("Generate")
-        self.generate.clicked.connect(self.generate_click)
-        self.browse_results = QPushButton("...")
-        self.browse_results.setFixedWidth(30)
-        self.browse_results.clicked.connect(self.on_open_results_folder)
-        self.browse_results.setToolTip("Open output folder")
-
-        # Create the image navigation layout
-        ilayout = QHBoxLayout()
-        ilayout.addWidget(self.prev_btn)
-        ilayout.addWidget(self.img)
-        ilayout.addWidget(self.next_btn)
-
-        # Create the generate button layout
-        hlayout = QHBoxLayout()
-        hlayout.addWidget(self.neg_prompt)
-        hlayout.addWidget(self.generate)
-        hlayout.addWidget(self.browse_results)
-
-        # Create the actual widget layout
-        vlayout = QVBoxLayout()
-        vlayout.addLayout(ilayout)
-        # vlayout.addItem(self.vspacer)
-        vlayout.addWidget(self.prompt)
-        vlayout.addWidget(self.neg_prompt_label)
-        vlayout.addLayout(hlayout)
-        self.setLayout(vlayout)
-
-        self.parent.settings_changed.connect(self.on_settings_changed)
-
-    def generate_image(self):
-        self.parent.prepare_generation_settings(self.config)
-        self.config.settings.lcm_diffusion_setting.prompt = self.prompt.toPlainText()
-        self.config.settings.lcm_diffusion_setting.negative_prompt = (
-            self.neg_prompt.toPlainText()
-        )
-        images = self.parent.context.generate_text_to_image(
-            self.config.settings,
-            self.config.reshape_required,
-            DEVICE,
-        )
-        self.parent.context.save_images(
-            images,
-            self.config.settings,
-        )
-        self.prepare_images(images)
-        self.after_generation()
-
-    def prepare_images(self, images):
-        """Prepares the generated images to be displayed in the Qt widget"""
-        self.image_index = 0
-        self.gen_images = []
-        for img in images:
-            im = ImageQt(img).copy()
-            pixmap = QPixmap.fromImage(im)
-            self.gen_images.append(pixmap)
-
-        if len(self.gen_images) > 1:
-            self.next_btn.setEnabled(True)
-            self.prev_btn.setEnabled(False)
-        else:
-            self.next_btn.setEnabled(False)
-            self.prev_btn.setEnabled(False)
-
-        self.img.show_image(pixmap=self.gen_images[0])
-
-    def on_show_next_image(self):
-        if self.image_index != len(self.gen_images) - 1 and len(self.gen_images) > 0:
-            self.prev_btn.setEnabled(True)
-            self.image_index += 1
-            self.img.show_image(pixmap=self.gen_images[self.image_index])
-            if self.image_index == len(self.gen_images) - 1:
-                self.next_btn.setEnabled(False)
-
-    def on_show_previous_image(self):
-        if self.image_index != 0:
-            self.next_btn.setEnabled(True)
-            self.image_index -= 1
-            self.img.show_image(pixmap=self.gen_images[self.image_index])
-            if self.image_index == 0:
-                self.prev_btn.setEnabled(False)
-
-    def on_open_results_folder(self):
-        QDesktopServices.openUrl(
-            QUrl.fromLocalFile(self.config.settings.generated_images.path)
-        )
-
-    def generate_click(self):
-        self.img.setText("Please wait...")
-        self.before_generation()
-        worker = ImageGeneratorWorker(self.generate_image)
-        self.parent.threadpool.start(worker)
-
-    def before_generation(self):
-        """Call this function before running a generation task"""
-        self.img.setEnabled(False)
-        self.generate.setEnabled(False)
-        self.browse_results.setEnabled(False)
-
-    def after_generation(self):
-        """Call this function after running a generation task"""
-        self.img.setEnabled(True)
-        self.generate.setEnabled(True)
-        self.browse_results.setEnabled(True)
-        self.parent.store_dimension_settings()
-
-    def on_settings_changed(self):
-        self.neg_prompt.setEnabled(
-            self.config.settings.lcm_diffusion_setting.use_openvino
-            or self.config.settings.lcm_diffusion_setting.use_lcm_lora
-        )
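The prepare_images method above shows the standard PIL-to-Qt handoff: wrap the PIL image in ImageQt, copy it so the pixel buffer is owned by Qt rather than the original PIL object, then build a QPixmap for display. Below is a minimal sketch of that conversion in isolation, assuming PyQt5 and a Pillow version whose ImageQt still supports PyQt5 (support was removed in Pillow 10).

import sys

from PIL import Image
from PIL.ImageQt import ImageQt
from PyQt5.QtGui import QPixmap
from PyQt5.QtWidgets import QApplication, QLabel

app = QApplication(sys.argv)

# Stand-in for a generated image; the app would get this from the pipeline.
pil_image = Image.new("RGB", (256, 256), color="navy")

# Copy the ImageQt wrapper so the pixel data outlives the PIL object,
# then convert to a QPixmap for display in any QLabel-based widget.
qimage = ImageQt(pil_image).copy()
pixmap = QPixmap.fromImage(qimage)

label = QLabel()
label.setPixmap(pixmap)
label.show()
sys.exit(app.exec_())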
src/frontend/gui/image_generator_worker.py
    DELETED
@@ -1,37 +0,0 @@
-from PyQt5.QtCore import (
-    pyqtSlot,
-    QRunnable,
-    pyqtSignal,
-    pyqtSlot,
-)
-from PyQt5.QtCore import QObject
-import traceback
-import sys
-
-
-class WorkerSignals(QObject):
-    finished = pyqtSignal()
-    error = pyqtSignal(tuple)
-    result = pyqtSignal(object)
-
-
-class ImageGeneratorWorker(QRunnable):
-    def __init__(self, fn, *args, **kwargs):
-        super(ImageGeneratorWorker, self).__init__()
-        self.fn = fn
-        self.args = args
-        self.kwargs = kwargs
-        self.signals = WorkerSignals()
-
-    @pyqtSlot()
-    def run(self):
-        try:
-            result = self.fn(*self.args, **self.kwargs)
-        except:
-            traceback.print_exc()
-            exctype, value = sys.exc_info()[:2]
-            self.signals.error.emit((exctype, value, traceback.format_exc()))
-        else:
-            self.signals.result.emit(result)
-        finally:
-            self.signals.finished.emit()
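The deleted worker is a standard QRunnable wrapper: it runs an arbitrary callable on a pool thread and reports back through Qt signals so the GUI stays responsive during generation. A rough sketch of how such a worker is driven (hypothetical caller code, not part of this diff; the generate, on_result, and on_error names are illustrative assumptions):

from PyQt5.QtCore import QThreadPool

def generate(prompt):
    # Stand-in for the real image generation call.
    return [f"image for {prompt}"]

def on_result(images):
    print("got", len(images), "images")

def on_error(err):
    # WorkerSignals.error emits (exctype, value, traceback string).
    exctype, value, tb_text = err
    print("generation failed:", value)

worker = ImageGeneratorWorker(generate, "a cat")
worker.signals.result.connect(on_result)
worker.signals.error.connect(on_error)
worker.signals.finished.connect(lambda: print("done"))
QThreadPool.globalInstance().start(worker)

Because run() catches every exception and routes it through the error signal, a failed generation never kills the pool thread, and finished always fires so the UI can re-enable its controls.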
src/frontend/gui/image_variations_widget.py
    DELETED
@@ -1,35 +0,0 @@
-from PIL import Image
-from PyQt5.QtWidgets import QApplication
-
-from app_settings import AppSettings
-from backend.models.lcmdiffusion_setting import DiffusionTask
-from frontend.gui.img2img_widget import Img2ImgWidget
-from frontend.webui.image_variations_ui import generate_image_variations
-
-
-class ImageVariationsWidget(Img2ImgWidget):
-    def __init__(self, config: AppSettings, parent):
-        super().__init__(config, parent)
-        # Hide prompt and negative prompt widgets
-        self.prompt.hide()
-        self.neg_prompt_label.hide()
-        self.neg_prompt.setEnabled(False)
-
-    def generate_image(self):
-        self.parent.prepare_generation_settings(self.config)
-        self.config.settings.lcm_diffusion_setting.diffusion_task = (
-            DiffusionTask.image_to_image.value
-        )
-        self.config.settings.lcm_diffusion_setting.prompt = ""
-        self.config.settings.lcm_diffusion_setting.negative_prompt = ""
-        self.config.settings.lcm_diffusion_setting.init_image = Image.open(
-            self.img_path.text()
-        )
-        self.config.settings.lcm_diffusion_setting.strength = self.strength.value() / 10
-
-        images = generate_image_variations(
-            self.config.settings.lcm_diffusion_setting.init_image,
-            self.config.settings.lcm_diffusion_setting.strength,
-        )
-        self.prepare_images(images)
-        self.after_generation()
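The deleted generate_image clears both prompts, loads the init image from the path field, and maps the widget's 0-10 strength slider onto the 0.0-1.0 strength the pipeline expects before delegating. A minimal headless sketch of the same call pattern (hypothetical helper, assuming only the generate_image_variations import shown above):

from PIL import Image
from frontend.webui.image_variations_ui import generate_image_variations

def variations_from_file(img_path: str, slider_value: int = 3):
    # Mirror of ImageVariationsWidget.generate_image without the GUI:
    # slider 0-10 maps to strength 0.0-1.0.
    init_image = Image.open(img_path)
    strength = slider_value / 10
    return generate_image_variations(init_image, strength)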