# Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
from torchvision.utils import save_image
from PIL import Image
from pytorch_lightning import seed_everything
import subprocess
from collections import OrderedDict
import re
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import os
import requests
from io import BytesIO
from annotator.util import resize_image, HWC3, resize_points, get_bounding_box, save_input_to_file
from safetensors.torch import load_file
from collections import defaultdict
from diffusers import StableDiffusionControlNetPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from utils.stable_diffusion_controlnet import ControlNetModel2
from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline, \
    StableDiffusionControlNetInpaintMixingPipeline, prepare_mask_image
# need the latest transformers
# pip install git+https://github.com/huggingface/transformers.git
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import PIL.Image

# Segment-Anything init.
# pip install git+https://github.com/facebookresearch/segment-anything.git
try:
    from segment_anything import (
        sam_model_registry,
        SamAutomaticMaskGenerator,
        SamPredictor,
    )
except ImportError:
    print("segment_anything not installed")
    result = subprocess.run(
        [
            "pip",
            "install",
            "git+https://github.com/facebookresearch/segment-anything.git",
        ],
        check=True,
    )
    print(f"Install segment_anything {result}")
    from segment_anything import (
        sam_model_registry,
        SamAutomaticMaskGenerator,
        SamPredictor,
    )

if not os.path.exists("./models/sam_vit_h_4b8939.pth"):
    result = subprocess.run(
        [
            "wget",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            "-P",
            "models",
        ],
        check=True,
    )
    print(f"Download sam_vit_h_4b8939.pth {result}")

device = "cuda" if torch.cuda.is_available() else "cpu"

config_dict = OrderedDict(
    [
        ("LAION Pretrained(v0-4)-SD15", "shgao/edit-anything-v0-4-sd15"),
        ("LAION Pretrained(v0-4)-SD21", "shgao/edit-anything-v0-4-sd21"),
        ("LAION Pretrained(v0-3)-SD21", "shgao/edit-anything-v0-3"),
        ("SAM Pretrained(v0-1)-SD21", "shgao/edit-anything-v0-1-1"),
    ]
)


def init_sam_model(sam_generator=None, mask_predictor=None):
    if sam_generator is not None and mask_predictor is not None:
        return sam_generator, mask_predictor
    sam_checkpoint = "models/sam_vit_h_4b8939.pth"
    model_type = "default"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam_generator = (
        SamAutomaticMaskGenerator(sam) if sam_generator is None else sam_generator
    )
    mask_predictor = SamPredictor(sam) if mask_predictor is None else mask_predictor
    return sam_generator, mask_predictor
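
# Usage sketch (comments only, not executed; variable names are illustrative):
# build the shared SAM generator/predictor once and pass them into
# EditAnythingLoraModel below so the large ViT-H checkpoint is only loaded once.
#
#   sam_generator, mask_predictor = init_sam_model()
#   model = EditAnythingLoraModel(sam_generator=sam_generator,
#                                 mask_predictor=mask_predictor)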


def init_blip_processor():
    blip_processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    return blip_processor


def init_blip_model():
    blip_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto"
    )
    return blip_model
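
# Captioning sketch (comments only, not executed; assumes a PIL image `img` and a
# CUDA device). This mirrors what get_blip2_text() does inside the model class:
#
#   processor = init_blip_processor()
#   blip = init_blip_model()
#   inputs = processor(img, return_tensors="pt").to(device, torch.float16)
#   ids = blip.generate(**inputs, max_new_tokens=50)
#   caption = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()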


def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
    # https://github.com/huggingface/diffusers/issues/2136
    """Build prompt embeddings for prompts longer than the tokenizer's max length.

    Both prompts are tokenized without truncation, the shorter one is padded to the
    longer one's length, and the token ids are encoded in chunks of
    ``tokenizer.model_max_length`` that are concatenated along the sequence dimension.

    :param pipeline: pipeline exposing ``tokenizer`` and ``text_encoder``
    :param prompt: positive prompt string
    :param negative_prompt: negative prompt string
    :param device: device for the token id tensors
    :return: (prompt_embeds, negative_prompt_embeds)
    """
    max_length = pipeline.tokenizer.model_max_length

    input_ids = pipeline.tokenizer(
        prompt, return_tensors="pt", truncation=False
    ).input_ids.to(device)
    negative_ids = pipeline.tokenizer(
        negative_prompt, return_tensors="pt", truncation=False
    ).input_ids.to(device)
    # Pad the shorter prompt so both share the same token length.
    shape_max_length = max(input_ids.shape[-1], negative_ids.shape[-1])
    if input_ids.shape[-1] > negative_ids.shape[-1]:
        negative_ids = pipeline.tokenizer(
            negative_prompt,
            truncation=False,
            padding="max_length",
            max_length=shape_max_length,
            return_tensors="pt",
        ).input_ids.to(device)
    else:
        input_ids = pipeline.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=False,
            padding="max_length",
            max_length=shape_max_length,
        ).input_ids.to(device)

    # Encode in chunks of the model's max length and concatenate the embeddings.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, shape_max_length, max_length):
        concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length])[0])
    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
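
# Long-prompt sketch (comments only, not executed; assumes an existing Stable
# Diffusion pipeline `pipe`). Because embeddings are built in chunks of
# tokenizer.model_max_length tokens (77 for the CLIP text encoder used by SD),
# prompts longer than that limit still work:
#
#   prompt_embeds, negative_embeds = get_pipeline_embeds(
#       pipe, "a very long prompt, ...", "lowres, bad anatomy", "cuda"
#   )
#   images = pipe(prompt_embeds=prompt_embeds,
#                 negative_prompt_embeds=negative_embeds).images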


def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # load LoRA weight from .safetensors
    print('device: {}'.format(device))
    if isinstance(checkpoint_path, str):
        state_dict = load_file(checkpoint_path, device=device)
        updates = defaultdict(dict)
        for key, value in state_dict.items():
            # it is suggested to print out the key, it usually will be something like below
            # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
            layer, elem = key.split(".", 1)
            updates[layer][elem] = value
        # directly update weight in diffusers model
        for layer, elems in updates.items():
            if "text" in layer:
                layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                curr_layer = pipeline.text_encoder
            else:
                layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
                curr_layer = pipeline.unet
            # find the target layer
            temp_name = layer_infos.pop(0)
            while len(layer_infos) > -1:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                    if len(layer_infos) > 0:
                        temp_name = layer_infos.pop(0)
                    elif len(layer_infos) == 0:
                        break
                except Exception:
                    if len(temp_name) > 0:
                        temp_name += "_" + layer_infos.pop(0)
                    else:
                        temp_name = layer_infos.pop(0)
            # get elements for this layer
            weight_up = elems["lora_up.weight"].to(dtype)
            weight_down = elems["lora_down.weight"].to(dtype)
            alpha = elems["alpha"]
            if alpha:
                alpha = alpha.item() / weight_up.shape[1]
            else:
                alpha = 1.0
            # update weight
            if len(weight_up.shape) == 4:
                curr_layer.weight.data += (
                    multiplier
                    * alpha
                    * torch.mm(
                        weight_up.squeeze(3).squeeze(2),
                        weight_down.squeeze(3).squeeze(2),
                    )
                    .unsqueeze(2)
                    .unsqueeze(3)
                )
            else:
                curr_layer.weight.data += (
                    multiplier * alpha * torch.mm(weight_up, weight_down)
                )
    else:
        for ckptpath in checkpoint_path:
            state_dict = load_file(ckptpath, device=device)
            updates = defaultdict(dict)
            for key, value in state_dict.items():
                # it is suggested to print out the key, it usually will be something like below
                # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
                layer, elem = key.split(".", 1)
                updates[layer][elem] = value
            # directly update weight in diffusers model
            for layer, elems in updates.items():
                if "text" in layer:
                    layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
                    curr_layer = pipeline.text_encoder
                else:
                    layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
                    curr_layer = pipeline.unet
                # find the target layer
                temp_name = layer_infos.pop(0)
                while len(layer_infos) > -1:
                    try:
                        curr_layer = curr_layer.__getattr__(temp_name)
                        if len(layer_infos) > 0:
                            temp_name = layer_infos.pop(0)
                        elif len(layer_infos) == 0:
                            break
                    except Exception:
                        if len(temp_name) > 0:
                            temp_name += "_" + layer_infos.pop(0)
                        else:
                            temp_name = layer_infos.pop(0)
                # get elements for this layer
                weight_up = elems["lora_up.weight"].to(dtype)
                weight_down = elems["lora_down.weight"].to(dtype)
                alpha = elems["alpha"]
                if alpha:
                    alpha = alpha.item() / weight_up.shape[1]
                else:
                    alpha = 1.0
                # update weight
                if len(weight_up.shape) == 4:
                    curr_layer.weight.data += (
                        multiplier
                        * alpha
                        * torch.mm(
                            weight_up.squeeze(3).squeeze(2),
                            weight_down.squeeze(3).squeeze(2),
                        )
                        .unsqueeze(2)
                        .unsqueeze(3)
                    )
                else:
                    curr_layer.weight.data += (
                        multiplier * alpha * torch.mm(weight_up, weight_down)
                    )
    return pipeline
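
# LoRA merge sketch (comments only, not executed; the path is a placeholder).
# The LoRA deltas are added to the UNet/text-encoder weights in place, so undoing
# a multiplier requires rebuilding the pipeline:
#
#   pipe = load_lora_weights(pipe, ["path/to/lora.safetensors"], 1.0, "cpu", torch.float32)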


def make_inpaint_condition(image, image_mask):
    image = image / 255.0
    assert (
        image.shape[0:1] == image_mask.shape[0:1]
    ), "image and image_mask must have the same image size"
    image[image_mask > 128] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
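
# Shape sketch (comments only, not executed): `image` is an (H, W, 3) array in
# [0, 255], `image_mask` an array of the same spatial size; pixels where the mask
# exceeds 128 are set to -1, and the result is a 1x3xHxW float tensor fed to the
# inpainting ControlNet.
#
#   cond = make_inpaint_condition(img, mask)
#   # cond.shape == (1, 3, img.shape[0], img.shape[1])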


def obtain_generation_model(
    base_model_path,
    lora_model_path,
    controlnet_path,
    generation_only=False,
    extra_inpaint=True,
    lora_weight=1.0,
):
    controlnet = []
    controlnet.append(
        ControlNetModel2.from_pretrained(controlnet_path, torch_dtype=torch.float16)
    )  # sam control
    if (not generation_only) and extra_inpaint:  # inpainting control
        print("Warning: the ControlNet-based inpainting model only supports SD1.5 for now.")
        controlnet.append(
            ControlNetModel.from_pretrained(
                "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
            )  # inpainting controlnet
        )

    if generation_only and extra_inpaint:
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
    else:
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
    if lora_model_path is not None:
        pipe = load_lora_weights(
            pipe, [lora_model_path], lora_weight, "cpu", torch.float32
        )
    # speed up the diffusion process with a faster scheduler and memory optimizations
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # remove the following line if xformers is not installed
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return pipe
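
# Pipeline construction sketch (comments only, not executed; paths are examples):
#
#   pipe = obtain_generation_model(
#       "runwayml/stable-diffusion-v1-5",        # base model
#       None,                                    # optional LoRA .safetensors path
#       "shgao/edit-anything-v0-4-sd15",         # SAM-conditioned ControlNet
#       generation_only=False,
#       extra_inpaint=True,
#   )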


def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
    controlnet = ControlNetModel2.from_pretrained(
        "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
    )  # tile controlnet
    if (
        base_model_path == "runwayml/stable-diffusion-v1-5"
        or base_model_path == "stabilityai/stable-diffusion-2-inpainting"
    ):
        print("base_model_path", base_model_path)
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
    else:
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            base_model_path,
            controlnet=controlnet,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
    if lora_model_path is not None:
        pipe = load_lora_weights(
            pipe, [lora_model_path], lora_weight, "cpu", torch.float32
        )
    # speed up the diffusion process with a faster scheduler and memory optimizations
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # remove the following line if xformers is not installed
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return pipe
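
# Refiner sketch (comments only, not executed): the tile pipeline is only used as a
# second refinement pass inside EditAnythingLoraModel.process(); built standalone:
#
#   tile_pipe = obtain_tile_model("runwayml/stable-diffusion-v1-5", None)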


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
    full_img = None

    # for ann in sorted_anns:
    for i in range(len(sorted_anns)):
        ann = anns[i]
        m = ann["segmentation"]
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255
    # anno encoding from https://github.com/LUSSeg/ImageNet-S
    res = np.zeros((map.shape[0], map.shape[1], 3))
    res[:, :, 0] = map % 256
    res[:, :, 1] = map // 256
    res = res.astype(np.float32)
    full_img = Image.fromarray(np.uint8(full_img))
    return full_img, res
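
# Decoding sketch (comments only, not executed): `res` stores a 1-based mask index
# per pixel, split across the first two channels (ImageNet-S style encoding), so the
# index map can be recovered with:
#
#   ids = res[:, :, 0] + res[:, :, 1] * 256   # 0 means "not covered by any mask"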


class EditAnythingLoraModel:
    def __init__(
        self,
        base_model_path="../chilloutmix_NiPrunedFp32Fix",
        lora_model_path="../40806/mix4",
        use_blip=True,
        blip_processor=None,
        blip_model=None,
        sam_generator=None,
        controlmodel_name="LAION Pretrained(v0-4)-SD15",
        # used when the base model is not an inpainting model.
        extra_inpaint=True,
        tile_model=None,
        lora_weight=1.0,
        alpha_mixing=None,
        mask_predictor=None,
    ):
        self.device = device
        self.use_blip = use_blip

        # Diffusion init using diffusers.
        self.default_controlnet_path = config_dict[controlmodel_name]
        self.base_model_path = base_model_path
        self.lora_model_path = lora_model_path
        self.defalut_enable_all_generate = False
        self.extra_inpaint = extra_inpaint
        self.last_ref_infer = False
        self.pipe = obtain_generation_model(
            base_model_path,
            lora_model_path,
            self.default_controlnet_path,
            generation_only=False,
            extra_inpaint=extra_inpaint,
            lora_weight=lora_weight,
        )
        # self.pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.bin")

        # Segment-Anything init.
        self.sam_generator, self.mask_predictor = init_sam_model(
            sam_generator, mask_predictor
        )
        # BLIP2 init.
        if use_blip:
            if blip_processor is not None:
                self.blip_processor = blip_processor
            else:
                self.blip_processor = init_blip_processor()
            if blip_model is not None:
                self.blip_model = blip_model
            else:
                self.blip_model = init_blip_model()
        # tile model init.
        if tile_model is not None:
            self.tile_pipe = tile_model
        else:
            self.tile_pipe = obtain_tile_model(
                base_model_path, lora_model_path, lora_weight=lora_weight
            )

    def get_blip2_text(self, image):
        inputs = self.blip_processor(image, return_tensors="pt").to(
            self.device, torch.float16
        )
        generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
        generated_text = self.blip_processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return generated_text

    def get_sam_control(self, image):
        masks = self.sam_generator.generate(image)
        full_img, res = show_anns(masks)
        return full_img, res

    def get_click_mask(self, image, clicked_points):
        self.mask_predictor.set_image(image)
        # Separate the points and labels
        points, labels = zip(*[(point[:2], point[2]) for point in clicked_points])
        # Convert the points and labels to numpy arrays
        input_point = np.array(points)
        input_label = np.array(labels)
        masks, _, _ = self.mask_predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        return masks
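
    # Click-prompt sketch (comments only, not executed): each clicked point is an
    # (x, y, label) tuple with label 1 for foreground and 0 for background, matching
    # the format built by process_image_click() below.
    #
    #   mask = model.get_click_mask(img, [(120, 80, 1), (40, 200, 0)])
    #   # mask has shape (1, H, W) because multimask_output=False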

    def process_image_click(
        self,
        original_image: gr.Image,
        point_prompt: gr.Radio,
        clicked_points: gr.State,
        image_resolution,
        evt: gr.SelectData,
    ):
        # Get the clicked coordinates
        clicked_coords = evt.index
        x, y = clicked_coords
        label = point_prompt
        lab = 1 if label == "Foreground Point" else 0
        clicked_points.append((x, y, lab))

        input_image = np.array(original_image, dtype=np.uint8)
        H, W, C = input_image.shape
        input_image = HWC3(input_image)
        img = resize_image(input_image, image_resolution)

        # Update the clicked_points
        resized_points = resize_points(
            clicked_points, input_image.shape, image_resolution
        )
        mask_click_np = self.get_click_mask(img, resized_points)

        # Convert mask_click_np to HWC format
        mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
        mask_image = HWC3(mask_click_np.astype(np.uint8))
        mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
        # mask_image = Image.fromarray(mask_image_tmp)

        # Draw circles for all clicked points
        edited_image = input_image
        for x, y, lab in clicked_points:
            # Set the circle color based on the label
            color = (255, 0, 0) if lab == 1 else (0, 0, 255)
            # Draw the circle
            edited_image = cv2.circle(edited_image, (x, y), 20, color, -1)

        # Set the opacity for the mask_image and edited_image
        opacity_mask = 0.75
        opacity_edited = 1.0

        # Combine the edited_image and the mask_image using cv2.addWeighted()
        overlay_image = cv2.addWeighted(
            edited_image,
            opacity_edited,
            (mask_image * np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
            opacity_mask,
            0,
        )

        return (
            Image.fromarray(overlay_image),
            clicked_points,
            Image.fromarray(mask_image),
        )

    # for debug use
    def process(
        self,
        source_image,
        enable_all_generate,
        mask_image,
        control_scale,
        enable_auto_prompt,
        a_prompt,
        n_prompt,
        num_samples,
        image_resolution,
        detect_resolution,
        ddim_steps,
        guess_mode,
        scale,
        seed,
        eta,
        enable_tile=True,
        refine_alignment_ratio=None,
        refine_image_resolution=None,
        alpha_weight=0.5,
        use_scale_map=False,
        condition_model=None,
        ref_image=None,
        attention_auto_machine_weight=1.0,
        gn_auto_machine_weight=1.0,
        style_fidelity=0.5,
        reference_attn=True,
        reference_adain=True,
        ref_prompt=None,
        ref_sam_scale=None,
        ref_inpaint_scale=None,
        ref_auto_prompt=False,
        ref_textinv=True,
        ref_textinv_path=None,
        ref_scale=None,
    ):
        if condition_model is None or condition_model == "EditAnything":
            this_controlnet_path = self.default_controlnet_path
        else:
            this_controlnet_path = condition_model
        input_image = (
            source_image["image"]
            if isinstance(source_image, dict)
            else np.array(source_image, dtype=np.uint8)
        )
        if mask_image is None:
            if enable_all_generate != self.defalut_enable_all_generate:
                self.pipe = obtain_generation_model(
                    self.base_model_path,
                    self.lora_model_path,
                    this_controlnet_path,
                    enable_all_generate,
                    self.extra_inpaint,
                )
                self.defalut_enable_all_generate = enable_all_generate
            if enable_all_generate:
                mask_image = (
                    np.ones((input_image.shape[0], input_image.shape[1], 3)) * 255
                )
            else:
                mask_image = source_image["mask"]
        else:
            mask_image = np.array(mask_image, dtype=np.uint8)

        if self.default_controlnet_path != this_controlnet_path:
            print(
                "To Use:",
                this_controlnet_path,
                "Current:",
                self.default_controlnet_path,
            )
            print("Change condition model to:", this_controlnet_path)
            self.pipe = obtain_generation_model(
                self.base_model_path,
                self.lora_model_path,
                this_controlnet_path,
                enable_all_generate,
                self.extra_inpaint,
            )
            self.default_controlnet_path = this_controlnet_path
            torch.cuda.empty_cache()

        if self.last_ref_infer:
            print("Redefine the model to overwrite the ref mode")
            self.pipe = obtain_generation_model(
                self.base_model_path,
                self.lora_model_path,
                this_controlnet_path,
                enable_all_generate,
                self.extra_inpaint,
            )
            self.last_ref_infer = False
        if ref_image is not None:
            ref_mask = ref_image["mask"]
            ref_image = ref_image["image"]
            if ref_auto_prompt or ref_textinv:
                bbox = get_bounding_box(
                    np.array(ref_mask) / 255
                )  # reverse the mask to make 1 the chosen region
                cropped_ref_mask = ref_mask.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
                cropped_ref_image = ref_image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
                # cropped_ref_image.save("debug.jpg")
                cropped_ref_image = np.array(cropped_ref_image) * (
                    np.array(cropped_ref_mask)[:, :, :3] / 255.0
                )
                cropped_ref_image = Image.fromarray(cropped_ref_image.astype("uint8"))
                if ref_auto_prompt:
                    generated_prompt = self.get_blip2_text(cropped_ref_image)
                    ref_prompt += generated_prompt
                    a_prompt += generated_prompt
                    print("Generated ref text:", ref_prompt)
                    print("Generated input text:", a_prompt)
            self.last_ref_infer = True
            # ref_image = cropped_ref_image
            # ref_mask = cropped_ref_mask
            if ref_textinv:
                try:
                    self.pipe.load_textual_inversion(ref_textinv_path)
                    print("Load textinv embedding from:", ref_textinv_path)
                except Exception:
                    print("No textinvert embeddings found.")
                    ref_data_path = "./utils/tmp/textinv/img"
                    if not os.path.exists(ref_data_path):
                        os.makedirs(ref_data_path)
                    cropped_ref_image.save(os.path.join(ref_data_path, "ref.png"))
                    print("Ref image region is saved to:", ref_data_path)
                    print(
                        "Please finetune with run_texutal_inversion.sh in the utils folder to get the textinvert embeddings."
                    )
        else:
            ref_mask = None
        with torch.no_grad():
            if self.use_blip and enable_auto_prompt:
                print("Generating text:")
                blip2_prompt = self.get_blip2_text(input_image)
                print("Generated text:", blip2_prompt)
                if len(a_prompt) > 0:
                    a_prompt = blip2_prompt + "," + a_prompt
                else:
                    a_prompt = blip2_prompt

            input_image = HWC3(input_image)
            img = resize_image(input_image, image_resolution)
            H, W, C = img.shape

            print("Generating SAM seg:")
            # the default SAM model is trained with 1024 size.
            full_segmask, detected_map = self.get_sam_control(
                resize_image(input_image, detect_resolution)
            )
            detected_map = HWC3(detected_map.astype(np.uint8))
            detected_map = cv2.resize(
                detected_map, (W, H), interpolation=cv2.INTER_LINEAR
            )

            control = torch.from_numpy(detected_map.copy()).float().cuda()
            control = control.unsqueeze(dim=0)
            control = einops.rearrange(control, "b h w c -> b c h w").clone()

            mask_imag_ori = HWC3(mask_image.astype(np.uint8))
            mask_image_tmp = cv2.resize(
                mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR
            )
            mask_image = Image.fromarray(mask_image_tmp)

            if seed == -1:
                seed = random.randint(0, 65535)
            seed_everything(seed)
            generator = torch.manual_seed(seed)

            postive_prompt = a_prompt
            negative_prompt = n_prompt
            prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
                self.pipe, postive_prompt, negative_prompt, "cuda"
            )
            if enable_all_generate and self.extra_inpaint:
                if ref_image is not None:
                    print("Not supported yet.")
                    return
                x_samples = self.pipe(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_images_per_prompt=num_samples,
                    num_inference_steps=ddim_steps,
                    generator=generator,
                    height=H,
                    width=W,
                    image=[control.type(torch.float16)],
                    controlnet_conditioning_scale=[float(control_scale)],
                    guidance_scale=scale,
                    guess_mode=guess_mode,
                ).images
            else:
                multi_condition_image = []
                multi_condition_scale = []
                multi_condition_image.append(control.type(torch.float16))
                multi_condition_scale.append(float(control_scale))
                ref_multi_condition_scale = []
                if ref_image is not None:
                    ref_multi_condition_scale.append(float(ref_sam_scale))
                if self.extra_inpaint:
                    inpaint_image = make_inpaint_condition(img, mask_image_tmp)
                    multi_condition_image.append(inpaint_image.type(torch.float16))
                    multi_condition_scale.append(1.0)
                    if ref_image is not None:
                        ref_multi_condition_scale.append(float(ref_inpaint_scale))
                if use_scale_map:
                    scale_map_tmp = source_image["mask"]
                    tmp = HWC3(scale_map_tmp.astype(np.uint8))
                    scale_map_tmp = cv2.resize(
                        tmp, (W, H), interpolation=cv2.INTER_LINEAR
                    )
                    scale_map_tmp = Image.fromarray(scale_map_tmp)
                    controlnet_conditioning_scale_map = (
                        1.0 - prepare_mask_image(scale_map_tmp).float()
                    )
                    print("scale map:", controlnet_conditioning_scale_map.size())
                else:
                    controlnet_conditioning_scale_map = None
                if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline):
                    x_samples = self.pipe(
                        image=img,
                        mask_image=mask_image,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=negative_prompt_embeds,
                        num_images_per_prompt=num_samples,
                        num_inference_steps=ddim_steps,
                        generator=generator,
                        controlnet_conditioning_image=multi_condition_image,
                        height=H,
                        width=W,
                        controlnet_conditioning_scale=multi_condition_scale,
                        guidance_scale=scale,
                        alpha_weight=alpha_weight,
                        controlnet_conditioning_scale_map=controlnet_conditioning_scale_map,
                    ).images
                else:
                    x_samples = self.pipe(
                        image=img,
                        mask_image=mask_image,
                        prompt_embeds=prompt_embeds,
                        negative_prompt_embeds=negative_prompt_embeds,
                        num_images_per_prompt=num_samples,
                        num_inference_steps=ddim_steps,
                        generator=generator,
                        controlnet_conditioning_image=multi_condition_image,
                        height=H,
                        width=W,
                        controlnet_conditioning_scale=multi_condition_scale,
                        guidance_scale=scale,
                        ref_image=ref_image,
                        ref_mask=ref_mask,
                        ref_prompt=ref_prompt,
                        attention_auto_machine_weight=attention_auto_machine_weight,
                        gn_auto_machine_weight=gn_auto_machine_weight,
                        style_fidelity=style_fidelity,
                        reference_attn=reference_attn,
                        reference_adain=reference_adain,
                        ref_controlnet_conditioning_scale=ref_multi_condition_scale,
                        guess_mode=guess_mode,
                        ref_scale=ref_scale,
                    ).images
            results = [x_samples[i] for i in range(num_samples)]

            results_tile = []
            if enable_tile:
                prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
                    self.tile_pipe, postive_prompt, negative_prompt, "cuda"
                )
                for i in range(num_samples):
                    img_tile = PIL.Image.fromarray(
                        resize_image(np.array(x_samples[i]), refine_image_resolution)
                    )
                    if i == 0:
                        mask_image_tile = cv2.resize(
                            mask_imag_ori,
                            (img_tile.size[0], img_tile.size[1]),
                            interpolation=cv2.INTER_LINEAR,
                        )
                        mask_image_tile = Image.fromarray(mask_image_tile)
                    if isinstance(
                        self.pipe, StableDiffusionControlNetInpaintMixingPipeline
                    ):
                        x_samples_tile = self.tile_pipe(
                            image=img_tile,
                            mask_image=mask_image_tile,
                            prompt_embeds=prompt_embeds,
                            negative_prompt_embeds=negative_prompt_embeds,
                            num_images_per_prompt=1,
                            num_inference_steps=ddim_steps,
                            generator=generator,
                            controlnet_conditioning_image=img_tile,
                            height=img_tile.size[1],
                            width=img_tile.size[0],
                            controlnet_conditioning_scale=1.0,
                            alignment_ratio=refine_alignment_ratio,
                            guidance_scale=scale,
                            alpha_weight=alpha_weight,
                            controlnet_conditioning_scale_map=controlnet_conditioning_scale_map,
                        ).images
                    else:
                        x_samples_tile = self.tile_pipe(
                            image=img_tile,
                            mask_image=mask_image_tile,
                            prompt_embeds=prompt_embeds,
                            negative_prompt_embeds=negative_prompt_embeds,
                            num_images_per_prompt=1,
                            num_inference_steps=ddim_steps,
                            generator=generator,
                            controlnet_conditioning_image=img_tile,
                            height=img_tile.size[1],
                            width=img_tile.size[0],
                            controlnet_conditioning_scale=1.0,
                            alignment_ratio=refine_alignment_ratio,
                            guidance_scale=scale,
                            guess_mode=guess_mode,
                        ).images
                    results_tile += x_samples_tile

        return results_tile, results, [full_segmask, mask_image], postive_prompt


def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")
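
# End-to-end sketch (comments only, not executed; model paths are examples and the
# URL is a placeholder):
#
#   model = EditAnythingLoraModel(
#       base_model_path="runwayml/stable-diffusion-v1-5",
#       lora_model_path=None,
#       controlmodel_name="LAION Pretrained(v0-4)-SD15",
#   )
#   image = download_image("https://example.com/dog.png")
#   # model.process(...) is typically wired to Gradio events by the app script
#   # that imports this module.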