Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| import gradio as gr | |
| import torch | |
| import spaces | |
| torch.jit.script = lambda f: f | |
| import timm | |
| import time | |
| from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download | |
| from safetensors.torch import load_file | |
| from share_btn import community_icon_html, loading_icon_html, share_js | |
| from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler | |
| import lora | |
| import copy | |
| import json | |
| import gc | |
| import random | |
| from urllib.parse import quote | |
| import gdown | |
| import os | |
| import re | |
| import requests | |
| import diffusers | |
| from diffusers.utils import load_image | |
| from diffusers.models import ControlNetModel | |
| from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, UNet2DConditionModel | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from insightface.app import FaceAnalysis | |
| from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps | |
| from controlnet_aux import ZoeDetector | |
| from compel import Compel, ReturnedEmbeddingsType | |
| from gradio_imageslider import ImageSlider | |
| #from gradio_imageslider import ImageSlider | |
# Catalogue of curated SDXL LoRAs offered in the gallery.
with open("sdxl_loras.json", "r") as file:
    data = json.load(file)

# Normalise every catalogue entry. The first six keys are mandatory (a missing
# one raises KeyError); the remaining ones fall back to a default value.
sdxl_loras_raw = []
for item in data:
    entry = {
        key: item[key]
        for key in ("image", "title", "repo", "trigger_word", "weights", "is_compatible")
    }
    entry["is_pivotal"] = item.get("is_pivotal", False)
    entry["text_embedding_weights"] = item.get("text_embedding_weights", None)
    entry["likes"] = item.get("likes", 0)
    entry["downloads"] = item.get("downloads", 0)
    entry["is_nc"] = item.get("is_nc", False)
    entry["new"] = item.get("new", False)
    sdxl_loras_raw.append(entry)

# Per-LoRA default slider values / prompt templates, matched later on "model".
with open("defaults_data.json", "r") as file:
    lora_defaults = json.load(file)

device = "cuda"

# Download every catalogue LoRA once at start-up and keep both the local file
# path and the loaded tensors, keyed by repository id.
state_dicts = {}
for item in sdxl_loras_raw:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name)
    else:
        # NOTE(review): torch.load unpickles the file; only acceptable for
        # trusted repositories.
        state_dict = torch.load(saved_name)
    state_dicts[item["repo"]] = dict(saved_name=saved_name, state_dict=state_dict)

# Entries flagged as "new" are dropped from the list shown in the gallery.
sdxl_loras_raw = list(
    filter(lambda lora_entry: lora_entry.get("new") != True, sdxl_loras_raw)
)
# download models
# InstantID ControlNet (config + weights), the InstantID IP-adapter and the
# LCM LoRA are all fetched into /data/checkpoints.
hf_hub_download(
    repo_id="InstantX/InstantID",
    filename="ControlNetModel/config.json",
    local_dir="/data/checkpoints",
)
hf_hub_download(
    repo_id="InstantX/InstantID",
    filename="ControlNetModel/diffusion_pytorch_model.safetensors",
    local_dir="/data/checkpoints",
)
hf_hub_download(
    repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="/data/checkpoints"
)
hf_hub_download(
    repo_id="latent-consistency/lcm-lora-sdxl",
    filename="pytorch_lora_weights.safetensors",
    local_dir="/data/checkpoints",
)
# download antelopev2
#if not os.path.exists("/data/antelopev2.zip"):
# gdown.download(url="https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view?usp=sharing", output="/data/", quiet=False, fuzzy=True)
# os.system("unzip /data/antelopev2.zip -d /data/models/")
# The face-analysis model files are mirrored from a Hub repository into
# /data/models/antelopev2; FaceAnalysis below is pointed at root='/data'.
antelope_download = snapshot_download(repo_id="DIAMONIK7777/antelopev2", local_dir="/data/models/antelopev2")
print(antelope_download)
# Face detection/embedding is configured with the CPU execution provider only.
app = FaceAnalysis(name='antelopev2', root='/data', providers=['CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
# prepare models under ./checkpoints
face_adapter = f'/data/checkpoints/ip-adapter.bin'
controlnet_path = f'/data/checkpoints/ControlNetModel'
# load IdentityNet
# Each loading stage below is timed with time.time() and the duration printed.
st = time.time()
identitynet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
zoedepthnet = ControlNetModel.from_pretrained("diffusers/controlnet-zoe-depth-sdxl-1.0",torch_dtype=torch.float16)
et = time.time()
elapsed_time = et - st
print('Loading ControlNet took: ', elapsed_time, 'seconds')
st = time.time()
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
et = time.time()
elapsed_time = et - st
print('Loading VAE took: ', elapsed_time, 'seconds')
st = time.time()
# The pipeline receives two ControlNets in a fixed order: [identity, depth].
# generate_image passes control images and conditioning scales in that order.
pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained("frankjoshua/albedobaseXL_v21",
                                                                 vae=vae,
                                                                 controlnet=[identitynet, zoedepthnet],
                                                                 torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.set_ip_adapter_scale(0.8)
et = time.time()
elapsed_time = et - st
print('Loading pipeline took: ', elapsed_time, 'seconds')
st = time.time()
# Compel builds prompt embeddings from both SDXL tokenizers/text encoders;
# only the second encoder is asked for a pooled output.
compel = Compel(tokenizer=[pipe.tokenizer, pipe.tokenizer_2] , text_encoder=[pipe.text_encoder, pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
et = time.time()
elapsed_time = et - st
print('Loading Compel took: ', elapsed_time, 'seconds')
st = time.time()
zoe = ZoeDetector.from_pretrained("lllyasviel/Annotators")
et = time.time()
elapsed_time = et - st
print('Loading Zoe took: ', elapsed_time, 'seconds')
zoe.to(device)
pipe.to(device)
# Module-level bookkeeping read and written by generate_image: which LoRA was
# last used for a generation and whether a LoRA is currently fused.
last_lora = ""
last_fused = False
# JavaScript snippet handed to demo.load(js=...).
# NOTE(review): the snippet references `element`, which it never defines, and
# looks up an element with id 'button' — confirm it still does anything useful.
js = '''
var button = document.getElementById('button');
// Add a click event listener to the button
button.addEventListener('click', function() {
element.classList.add('selected');
});
'''
# Directory where custom (user-supplied) LoRA files are downloaded.
lora_archive = "/data"
def update_selection(selected_state: gr.SelectData, sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative, is_new=False):
    """Handle a click on a gallery item.

    Builds the "Selected: ..." markdown line for the chosen LoRA and, when
    ``lora_defaults`` has an entry whose "model" equals the chosen repo,
    replaces the slider/negative-prompt values with that entry's defaults
    (otherwise the incoming values are passed through unchanged).

    Returns, in order: markdown text, prompt placeholder update,
    face_strength, image_strength, weight, depth_control_scale, negative and
    the (possibly index-mutated) ``selected_state``.
    """
    chosen = sdxl_loras[selected_state.index]
    lora_repo = chosen["repo"]
    # Not used further; the lookup is kept so the function behaves as before.
    weight_name = chosen["weights"]
    license_note = '(non-commercial LoRA, `cc-by-nc`)' if chosen['is_nc'] else ''
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ {license_note}"
    new_placeholder = "Type a prompt to use your selected LoRA"

    # No early exit: if several defaults entries match, the last one wins.
    for defaults in lora_defaults:
        if defaults["model"] != lora_repo:
            continue
        face_strength = defaults.get("face_strength", 0.85)
        image_strength = defaults.get("image_strength", 0.15)
        weight = defaults.get("weight", 0.9)
        depth_control_scale = defaults.get("depth_control_scale", 0.8)
        negative = defaults.get("negative", "")

    if is_new:
        # Selections flagged as new get a negated index; index 0 cannot be
        # negated, so it is mapped to the sentinel -9999.
        current_index = selected_state.index
        selected_state.index = -9999 if current_index == 0 else current_index * -1

    placeholder_update = gr.update(placeholder=new_placeholder)
    return (
        updated_text,
        placeholder_update,
        face_strength,
        image_strength,
        weight,
        depth_control_scale,
        negative,
        selected_state,
    )
def center_crop_image_as_square(img):
    """Return the centred square crop of ``img``.

    The square's side is the shorter of the image's two dimensions. The crop
    box is computed with true division, so its coordinates may be fractional.
    """
    side = min(img.size)
    full_width, full_height = img.width, img.height
    crop_box = (
        (full_width - side) / 2,   # left
        (full_height - side) / 2,  # top
        (full_width + side) / 2,   # right
        (full_height + side) / 2,  # bottom
    )
    return img.crop(crop_box)
def check_selected(selected_state, custom_lora):
    """Raise a ``gr.Error`` unless a gallery style or a custom LoRA is set."""
    if selected_state or custom_lora:
        return
    raise gr.Error("You must select a style")
def merge_incompatible_lora(full_path_lora, lora_scale):
    """Merge a LoRA file directly into the pipeline's text encoder and UNet.

    Args:
        full_path_lora: Path of the LoRA weights file. It may carry its own
            multiplier as ``"<path>;<multiplier>"``.
        lora_scale: Multiplier used when the path does not carry one.

    Raises:
        ValueError: If the text after ``;`` is not a valid float.
    """
    # Split an optional ";multiplier" suffix off the path. The network must
    # be created from the bare path, not from the unsplit string.
    weights_file, separator, multiplier_text = full_path_lora.partition(";")
    multiplier = float(multiplier_text) if separator else lora_scale

    lora_model, weights_sd = lora.create_network_from_weights(
        multiplier,
        weights_file,
        pipe.vae,
        pipe.text_encoder,
        pipe.unet,
        for_inference=True,
    )
    lora_model.merge_to(
        pipe.text_encoder, pipe.unet, weights_sd, torch.float16, "cuda"
    )
    # Drop the references to the merged network and its weights right away.
    del weights_sd
    del lora_model
def generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, loaded_state_dict, lora_scale, sdxl_loras, selected_state_index, st):
    """Generate one stylised image with the module-level pipeline.

    Args:
        prompt: Positive prompt text, encoded with ``compel``.
        negative: Negative prompt text; falsy means no negative embeddings.
        face_emb: Face embedding passed to the pipeline as ``image_embeds``.
        face_image: Source image; used for the depth map and as img2img input.
        face_kps: Key-point image used as the first control image.
        image_strength: The pipeline runs with ``strength = 1 - image_strength``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        face_strength: Conditioning scale of the first (identity) ControlNet.
        depth_control_scale: Conditioning scale of the second (depth) ControlNet.
        repo_name: Identifier of the requested LoRA; compared with the
            module-level ``last_lora`` to decide whether to (re)load.
        loaded_state_dict: Value handed to ``pipe.load_lora_weights``.
        lora_scale: Scale used when fusing the LoRA.
        sdxl_loras: Gallery LoRA list, consulted for pivotal-tuning metadata.
        selected_state_index: Index of the selected entry in ``sdxl_loras``.
        st: ``time.time()`` taken by the caller just before the call.

    Returns:
        The first image produced by the pipeline.
    """
    global last_fused, last_lora
    print(loaded_state_dict)
    et = time.time()
    elapsed_time = et - st
    print('Getting into the decorated function took: ', elapsed_time, 'seconds')
    print("Last LoRA: ", last_lora)
    print("Current LoRA: ", repo_name)
    print("Last fused: ", last_fused)

    # prepare face zoe
    st = time.time()
    with torch.no_grad():
        image_zoe = zoe(face_image)
    width, height = face_kps.size
    # PIL's resize expects (width, height); the two used to be swapped, which
    # only went unnoticed because the inputs are square crops.
    images = [face_kps, image_zoe.resize((width, height))]
    et = time.time()
    elapsed_time = et - st
    print('Zoe Depth calculations took: ', elapsed_time, 'seconds')

    # Only touch the pipeline weights when a different LoRA is requested.
    if last_lora != repo_name:
        if last_fused:
            st = time.time()
            pipe.unfuse_lora()
            pipe.unload_lora_weights()
            pipe.unload_textual_inversion()
            et = time.time()
            elapsed_time = et - st
            print('Unfuse and unload LoRA took: ', elapsed_time, 'seconds')
        st = time.time()
        pipe.load_lora_weights(loaded_state_dict)
        pipe.fuse_lora(lora_scale)
        et = time.time()
        elapsed_time = et - st
        print('Fuse and load LoRA took: ', elapsed_time, 'seconds')
        last_fused = True

        # Pivotal-tuning metadata belongs to the gallery entry. Apply it only
        # when that entry is the LoRA being loaded: a custom LoRA has a
        # different repo_name and must not inherit another entry's embeddings.
        gallery_entry = sdxl_loras[selected_state_index]
        is_pivotal = gallery_entry["repo"] == repo_name and gallery_entry["is_pivotal"]
        if is_pivotal:
            # Add the textual inversion embeddings from pivotal tuning models
            text_embedding_name = gallery_entry["text_embedding_weights"]
            embedding_path = hf_hub_download(repo_id=repo_name, filename=text_embedding_name, repo_type="model")
            state_dict_embedding = load_file(embedding_path)
            # Two key layouts are accepted: clip_l/clip_g or text_encoders_0/1.
            pipe.load_textual_inversion(state_dict_embedding["clip_l" if "clip_l" in state_dict_embedding else "text_encoders_0"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
            pipe.load_textual_inversion(state_dict_embedding["clip_g" if "clip_g" in state_dict_embedding else "text_encoders_1"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)

    print("Processing prompt...")
    st = time.time()
    conditioning, pooled = compel(prompt)
    if negative:
        negative_conditioning, negative_pooled = compel(negative)
    else:
        negative_conditioning, negative_pooled = None, None
    et = time.time()
    elapsed_time = et - st
    print('Prompt processing took: ', elapsed_time, 'seconds')

    print("Processing image...")
    st = time.time()
    image = pipe(
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        negative_prompt_embeds=negative_conditioning,
        negative_pooled_prompt_embeds=negative_pooled,
        width=1024,
        height=1024,
        image_embeds=face_emb,
        image=face_image,
        strength=1-image_strength,
        control_image=images,
        num_inference_steps=20,
        guidance_scale=guidance_scale,
        # Same order as control_image: [key points, depth].
        controlnet_conditioning_scale=[face_strength, depth_control_scale],
    ).images[0]
    et = time.time()
    elapsed_time = et - st
    print('Image processing took: ', elapsed_time, 'seconds')

    # Recorded only after a successful generation.
    last_lora = repo_name
    return image
def run_lora(face_image, prompt, negative, lora_scale, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, sdxl_loras, custom_lora, progress=gr.Progress(track_tqdm=True)):
    """Validate the request, extract the face and generate the styled image.

    Args:
        face_image: Uploaded picture (cropped to a centred square here).
        prompt: Subject description typed by the user.
        negative: Negative prompt; an empty string is treated as none.
        lora_scale: LoRA weight forwarded to ``generate_image``.
        selected_state: Gallery selection (object with ``.index``) or None.
        face_strength, image_strength, guidance_scale, depth_control_scale:
            Generation settings forwarded to ``generate_image``.
        sdxl_loras: Gallery LoRA list indexed by the selection.
        custom_lora: ``[path, trigger_word]`` of a custom LoRA, or None.
        progress: Gradio progress tracker.

    Returns:
        ``((cropped_face_image, generated_image), update)`` where ``update``
        makes the share group visible.

    Raises:
        gr.Error: If no style is selected, no picture was uploaded, or no
            face can be found in the picture.
    """
    print("Custom LoRA: ", custom_lora)
    custom_lora_path = custom_lora[0] if custom_lora else None
    # Validate before any gallery entry is indexed with the -1 fallback.
    if not selected_state and not custom_lora_path:
        raise gr.Error("You must select a style")
    if face_image is None:
        raise gr.Error("Please upload a picture of yourself first")
    selected_state_index = selected_state.index if selected_state else -1

    st = time.time()
    face_image = center_crop_image_as_square(face_image)
    try:
        face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
        # Only use the largest face. The bbox is presumably (x1, y1, x2, y2),
        # so the area is width * height; the height term needs parentheses.
        face_info = max(
            face_info,
            key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]),
        )
        face_emb = face_info['embedding']
        face_kps = draw_kps(face_image, face_info['kps'])
    except Exception as err:
        # An empty detection list makes max() raise, which lands here too.
        raise gr.Error("No face found in your image. Only face images work here. Try again") from err
    et = time.time()
    elapsed_time = et - st
    print('Cropping and calculating face embeds took: ', elapsed_time, 'seconds')

    st = time.time()
    if custom_lora_path:
        # Custom LoRA: append its trigger word, if any. Gallery prompt
        # templates belong to gallery LoRAs and are never applied here.
        if custom_lora[1]:
            prompt = f"{prompt} {custom_lora[1]}"
    else:
        for lora_list in lora_defaults:
            if lora_list["model"] == sdxl_loras[selected_state_index]["repo"]:
                prompt_full = lora_list.get("prompt", None)
                if prompt_full:
                    prompt = prompt_full.replace("<subject>", prompt)
                    print("Prompt:", prompt)
    if prompt == "":
        prompt = "a person"
    print(f"Executing prompt: {prompt}")
    if negative == "":
        negative = None
    print("Custom Loaded LoRA: ", custom_lora_path)

    if custom_lora_path:
        repo_name = custom_lora_path
        full_path_lora = custom_lora_path
    else:
        repo_name = sdxl_loras[selected_state_index]["repo"]
        full_path_lora = state_dicts[repo_name]["saved_name"]
    print("Full path LoRA ", full_path_lora)
    et = time.time()
    elapsed_time = et - st
    print('Small content processing took: ', elapsed_time, 'seconds')

    st = time.time()
    image = generate_image(prompt, negative, face_emb, face_image, face_kps, image_strength, guidance_scale, face_strength, depth_control_scale, repo_name, full_path_lora, lora_scale, sdxl_loras, selected_state_index, st)
    return (face_image, image), gr.update(visible=True)


# Attribute flag set on the handler (presumably read by the ZeroGPU runtime).
run_lora.zerogpu = True
def shuffle_gallery(sdxl_loras):
    """Shuffle ``sdxl_loras`` in place and return gallery items plus the list.

    Returns a pair: the ``(image, title)`` tuples in the new order and the
    same (now shuffled) list object that was passed in.
    """
    random.shuffle(sdxl_loras)
    gallery_items = []
    for entry in sdxl_loras:
        gallery_items.append((entry["image"], entry["title"]))
    return gallery_items, sdxl_loras
def classify_gallery(sdxl_loras):
    """Return gallery items and LoRA entries ordered by likes, highest first.

    Entries without a "likes" key count as 0. The input list is not modified;
    a sorted copy is returned alongside the ``(image, title)`` tuples.
    """
    sorted_gallery = list(sdxl_loras)
    sorted_gallery.sort(key=lambda entry: entry.get("likes", 0), reverse=True)
    gallery_items = [(entry["image"], entry["title"]) for entry in sorted_gallery]
    return gallery_items, sorted_gallery
def swap_gallery(order, sdxl_loras):
    """Reorder the gallery: shuffled for "random", by likes for anything else."""
    reorder = shuffle_gallery if order == "random" else classify_gallery
    return reorder(sdxl_loras)
def deselect():
    """Return a Gallery component whose selection is cleared."""
    cleared_gallery = gr.Gallery(selected_index=None)
    return cleared_gallery
def get_huggingface_safetensors(link):
    """Download the ``*.safetensors`` LoRA of a Hugging Face repository.

    Args:
        link: Repository id in the form ``owner/name``.

    Returns:
        ``(title, local_path, trigger_word, image_url)`` where ``title`` is
        the repository name, ``local_path`` is the file inside
        ``lora_archive``, ``trigger_word`` is the model card's
        ``instance_prompt`` and ``image_url`` is a preview image URL or None.

    Raises:
        Exception: If ``link`` is not an ``owner/name`` id, the repository
            cannot be listed/downloaded, or it contains no ``*.safetensors``.
    """
    error_message = "You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA"
    split_link = link.split("/")
    # Fail explicitly instead of reaching the return with unbound names.
    if len(split_link) != 2:
        gr.Warning(error_message)
        raise Exception(error_message)

    model_card = ModelCard.load(link)
    # A card may declare "widget" as an empty list; treat that as no preview.
    widgets = model_card.data.get("widget") or [{}]
    image_path = widgets[0].get("output", {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "")
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

    fs = HfFileSystem()
    safetensors_name = None
    try:
        list_of_files = fs.ls(link, detail=False)
        for file in list_of_files:
            if file.endswith(".safetensors"):
                # Flatten the repository path into a single local file name.
                # If several weight files exist, the last one listed is used.
                safetensors_name = file.replace("/", "_")
                local_path = f"{lora_archive}/{safetensors_name}"
                if not os.path.exists(local_path):
                    fs.get_file(file, lpath=local_path)
            # Fall back to the first image file when the card gave no preview.
            if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                image_elements = file.split("/")
                image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
    except Exception as err:
        gr.Warning(error_message)
        raise Exception(error_message) from err

    if safetensors_name is None:
        gr.Warning(error_message)
        raise Exception(error_message)
    return split_link[1], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
def get_civitai_safetensors(link):
    """Download the SDXL 1.0 LoRA of a CivitAI model page.

    Args:
        link: CivitAI URL containing ``models/<id>``.

    Returns:
        ``(model_name, local_path, trigger_word, image_url)``; ``image_url``
        is the first image with ``nsfwLevel == 1`` of the chosen version, or
        None.

    Raises:
        Exception: If the URL has no model id, the API or download request
            fails, the model is flagged as adult content, is not a LoRA, or
            has no "SDXL 1.0" version.
    """
    link_split = link.split("civitai.com/")
    pattern = re.compile(r'models\/(\d+)')
    # A URL with nothing after the host has no second part to search.
    regex_match = pattern.search(link_split[1]) if len(link_split) > 1 else None
    if regex_match:
        civitai_model_id = regex_match.group(1)
    else:
        gr.Warning("No CivitAI model id found in your URL")
        raise Exception("No CivitAI model id found in your URL")

    # Only append the token when one is configured; otherwise the literal
    # string "None" would be sent as the token.
    civitai_token = os.getenv('CIVITAI_TOKEN')
    token_suffix = f"?token={civitai_token}" if civitai_token else ""
    request_timeout = 60  # seconds; avoids hanging forever on a stalled server

    model_request_url = f"https://civitai.com/api/v1/models/{civitai_model_id}{token_suffix}"
    x = requests.get(model_request_url, timeout=request_timeout)
    if x.status_code != 200:
        raise Exception("Invalid CivitAI URL")
    model_data = x.json()

    # Direct indexing on purpose: a response without these fields is rejected
    # (KeyError) rather than silently accepted.
    if model_data["nsfw"] == True or model_data["nsfwLevel"] > 20:
        gr.Warning("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
        raise Exception("The model is tagged by CivitAI as adult content and cannot be used in this shared environment.")
    elif model_data["type"] != "LORA":
        gr.Warning("The model isn't tagged at CivitAI as a LoRA")
        raise Exception("The model isn't tagged at CivitAI as a LoRA")

    model_link_download = None
    image_url = None
    trigger_word = ""
    # Use the first version whose base model is SDXL 1.0.
    for model in model_data["modelVersions"]:
        if model["baseModel"] != "SDXL 1.0":
            continue
        model_link_download = f"{model['downloadUrl']}/{token_suffix}" if token_suffix else model['downloadUrl']
        # The file name comes from the remote API: keep only its final path
        # component so it cannot point outside lora_archive.
        safetensors_name = os.path.basename(model["files"][0]["name"])
        local_path = f"{lora_archive}/{safetensors_name}"
        if not os.path.exists(local_path):
            safetensors_file_request = requests.get(model_link_download, timeout=request_timeout)
            if safetensors_file_request.status_code != 200:
                raise Exception("Invalid CivitAI download link")
            with open(local_path, 'wb') as file:
                file.write(safetensors_file_request.content)
        # "trainedWords" may be missing or an empty list.
        trained_words = model.get("trainedWords") or [""]
        trigger_word = trained_words[0]
        for image in model.get("images", []):
            if image.get("nsfwLevel") == 1:
                image_url = image["url"]
                break
        break

    if not model_link_download:
        gr.Warning("We couldn't find a SDXL LoRA on the model you've sent")
        raise Exception("We couldn't find a SDXL LoRA on the model you've sent")
    return model_data["name"], f"{lora_archive}/{safetensors_name}", trigger_word, image_url
def check_custom_model(link):
    """Resolve a user-supplied LoRA reference and download it.

    Args:
        link: A Hugging Face repository id (``owner/name``), a
            ``https://huggingface.co/...`` URL or a ``https://civitai.com/...``
            URL (with or without ``www.``).

    Returns:
        The ``(title, local_path, trigger_word, image_url)`` tuple produced
        by the matching downloader.

    Raises:
        ValueError: If ``link`` is an https URL of an unsupported host, or a
            Hugging Face URL without a repository path.
    """
    # Anything that is not an https URL is treated as a Hub repository id.
    if not link.startswith("https://"):
        return get_huggingface_safetensors(link)

    if link.startswith(("https://huggingface.co", "https://www.huggingface.co")):
        link_split = link.split("huggingface.co/")
        if len(link_split) < 2 or not link_split[1]:
            raise ValueError("The Hugging Face URL doesn't include a repository")
        return get_huggingface_safetensors(link_split[1])

    if link.startswith(("https://civitai.com", "https://www.civitai.com")):
        return get_civitai_safetensors(link)

    # Fail explicitly rather than returning None, which callers would then
    # try to unpack into four values.
    raise ValueError("Only Hugging Face and CivitAI links are supported")
def show_loading_widget():
    """Return an update that makes the target component visible."""
    make_visible = gr.update(visible=True)
    return make_visible
def load_custom_lora(link):
    """Load a custom LoRA from ``link`` and build its info card.

    Args:
        link: Text typed into the custom-model box; empty means "no custom
            LoRA".

    Returns:
        A 6-tuple matching the event outputs, in order: card visibility
        update, card HTML (or error text), remove-button visibility update,
        ``[path, trigger_word]`` (or None), gallery update, title text/update.
    """
    from html import escape

    if not link:
        return gr.update(visible=False), "", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)

    try:
        title, path, trigger_word, image = check_custom_model(link)
        # Title, trigger word and image URL come from remote model metadata:
        # escape them before interpolating into HTML.
        image_tag = f'<img src="{escape(image, quote=True)}" />' if image else ""
        if trigger_word:
            trigger_note = "Using: <code><b>" + escape(str(trigger_word)) + "</b></code> as the trigger word"
        else:
            trigger_note = "No trigger word found. If there's a trigger word, include it in your prompt"
        card = f'''
        <div class="custom_lora_card">
            <span>Loaded custom LoRA:</span>
            <div class="card_internal">
                {image_tag}
                <div>
                    <h3>{escape(str(title))}</h3>
                    <small>{trigger_note}<br></small>
                </div>
            </div>
        </div>
        '''
        return gr.update(visible=True), card, gr.update(visible=True), [path, trigger_word], gr.Gallery(selected_index=None), f"Custom: {path}"
    except Exception as err:
        # Keep the UI message generic, but leave a trace of the real cause.
        print("Failed to load custom LoRA: ", err)
        gr.Warning("Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content")
        return gr.update(visible=True), "Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA or a LoRA with mature content", gr.update(visible=False), None, gr.update(visible=True), gr.update(visible=True)
def remove_custom_lora():
    """Reset the custom-LoRA widgets.

    Returns an empty string, two hide updates and None, in that order.
    """
    hide_first = gr.update(visible=False)
    hide_second = gr.update(visible=False)
    return "", hide_first, hide_second, None
# Gradio UI: layout, event wiring and launch.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened listing — confirm it against the intended page layout.
with gr.Blocks(css="custom.css") as demo:
    # Per-session copy of the LoRA catalogue; demo.load re-sorts it by likes.
    gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
    title = gr.HTML(
        """<h1 style="display: flex; width: 330px; margin: 0 auto">
<img style='width: 120px;margin-right: 20px' src="https://i.imgur.com/DVoGw04.png">
<span>Face to All SDXL<br>
<small style="font-size: 13px;display: block;font-weight: normal;opacity: 0.75;">
🧨 diffusers InstantID + ControlNet<br> inspired by fofr's
<a href="https://github.com/fofr/cog-face-to-many" target="_blank">face-to-many</a>
</small>
</span>
</h1>""",
        elem_id="title",
    )
    # selected_state holds the last gallery selection (written by
    # update_selection); custom_loaded_lora holds [path, trigger_word] or None.
    selected_state = gr.State()
    custom_loaded_lora = gr.State()
    with gr.Row(elem_id="main_app"):
        # Left column: photo upload, style gallery and custom LoRA input.
        with gr.Column(scale=4, elem_id="box_column"):
            with gr.Group(elem_id="gallery_box"):
                photo = gr.Image(label="Upload a picture of yourself", interactive=True, type="pil", height=300)
                selected_loras = gr.Gallery(label="Selected LoRAs", height=80, show_share_button=False, visible=False, elem_id="gallery_selected", )
                #order_gallery = gr.Radio(choices=["random", "likes"], value="random", label="Order by", elem_id="order_radio")
                #new_gallery = gr.Gallery(
                # label="New LoRAs",
                # elem_id="gallery_new",
                # columns=3,
                # value=[(item["image"], item["title"]) for item in sdxl_loras_raw_new], allow_preview=False, show_share_button=False)
                # The gallery starts empty; demo.load fills it at page load.
                gallery = gr.Gallery(
                    #value=[(item["image"], item["title"]) for item in sdxl_loras],
                    label="Pick a style from the gallery",
                    allow_preview=False,
                    columns=4,
                    elem_id="gallery",
                    show_share_button=False,
                    height=550
                )
                custom_model = gr.Textbox(label="or enter a custom Hugging Face or CivitAI SDXL LoRA", placeholder="Paste Hugging Face or CivitAI model path...")
                custom_model_card = gr.HTML(visible=False)
                custom_model_button = gr.Button("Remove custom LoRA", visible=False)
        # Right column: prompt, run button, result and advanced settings.
        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, info="Describe your subject (optional)", value="a person", elem_id="prompt")
                button = gr.Button("Run", elem_id="run_button")
            result = ImageSlider(
                interactive=False, label="Generated Image", elem_id="result-image", position=0.1
            )
            # Hidden until run_lora returns an update that makes it visible.
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
            # Slider defaults mirror the fallbacks used in update_selection.
            with gr.Accordion("Advanced options", open=False):
                negative = gr.Textbox(label="Negative Prompt")
                weight = gr.Slider(0, 10, value=0.9, step=0.1, label="LoRA weight")
                face_strength = gr.Slider(0, 2, value=0.85, step=0.01, label="Face strength", info="Higher values increase the face likeness but reduce the creative liberty of the models")
                image_strength = gr.Slider(0, 1, value=0.15, step=0.01, label="Image strength", info="Higher values increase the similarity with the structure/colors of the original photo")
                guidance_scale = gr.Slider(0, 50, value=7, step=0.1, label="Guidance Scale")
                depth_control_scale = gr.Slider(0, 1, value=0.8, step=0.01, label="Zoe Depth ControlNet strenght")
            prompt_title = gr.Markdown(
                value="### Click on a LoRA in the gallery to select it",
                visible=True,
                elem_id="selected_lora",
            )
    #order_gallery.change(
    # fn=swap_gallery,
    # inputs=[order_gallery, gr_sdxl_loras],
    # outputs=[gallery, gr_sdxl_loras],
    # queue=False
    #)
    # load_custom_lora returns six values; custom_model_card is listed twice
    # because it receives both a visibility update and its HTML content.
    custom_model.input(
        fn=load_custom_lora,
        inputs=[custom_model],
        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
    )
    custom_model_button.click(
        fn=remove_custom_lora,
        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
    )
    # The gr.SelectData argument of update_selection is not listed in inputs.
    gallery.select(
        fn=update_selection,
        inputs=[gr_sdxl_loras, face_strength, image_strength, weight, depth_control_scale, negative],
        outputs=[prompt_title, prompt, face_strength, image_strength, weight, depth_control_scale, negative, selected_state],
        show_progress=False
    )
    #new_gallery.select(
    # fn=update_selection,
    # inputs=[gr_sdxl_loras_new, gr.State(True)],
    # outputs=[prompt_title, prompt, prompt, selected_state, gallery],
    # queue=False,
    # show_progress=False
    #)
    # Submitting the prompt and clicking Run share one flow: check_selected
    # validates first and run_lora only runs if that step succeeds.
    prompt.submit(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
        outputs=[result, share_group],
    )
    button.click(
        fn=check_selected,
        inputs=[selected_state, custom_loaded_lora],
        show_progress=False
    ).success(
        fn=run_lora,
        inputs=[photo, prompt, negative, weight, selected_state, face_strength, image_strength, guidance_scale, depth_control_scale, gr_sdxl_loras, custom_loaded_lora],
        outputs=[result, share_group],
    )
    # Pure client-side handler: no Python function, only the share_js script.
    share_button.click(None, [], [], js=share_js)
    # On page load, fill the gallery sorted by likes and store that order.
    demo.load(fn=classify_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], js=js)
demo.queue(default_concurrency_limit=None)
# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
