Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import argparse | |
| import os | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import onnxruntime as rt | |
| import pandas as pd | |
| from PIL import Image | |
# Choices offered by the placeholder generation UI: checkpoints, VAEs,
# ControlNet conditioning types and sampling methods.
models = [f"Model {letter}" for letter in "ABC"]
vae = [f"VAE {letter}" for letter in "ABC"]
controlnet_types = [
    "Canny",
    "Depth",
    "Normal",
    "Pose",
]
schedulers = [
    "Euler",
    "LMS",
    "DDIM",
]
# Placeholder back-end functions: they only echo their inputs until real
# pipelines are wired in.
def load_model(selected_model):
    """Return a status message confirming that *selected_model* was loaded.

    Placeholder: nothing is actually loaded; the message is in Indonesian.
    """
    return "Model {} telah dimuat.".format(selected_model)
def generate_image(prompt, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model):
    """Placeholder text-to-image generator.

    Args:
        prompt: Positive prompt text.
        neg_prompt: Negative prompt text.
        width, height, scheduler, num_steps, cfg_scale, seed: Generation
            settings; accepted for interface compatibility but not used yet.
        num_images: How many outputs to produce. Any value accepted by ``int()``
            works, e.g. ``2`` or ``2.0``.
        model: Name of the selected model.

    Returns:
        ``(captions, params)``: a list with one descriptive string per requested
        image, and a dict echoing ``prompt`` and ``neg_prompt``.

    NOTE(review): the UI routes ``captions`` into a ``gr.Gallery``, which expects
    images or image file paths. Replace them with real images once generation
    exists.
    """
    # Gradio slider/number fields may deliver floats (e.g. 2.0); range() needs an int.
    count = int(num_images)
    captions = [
        f"Gambar {i + 1} untuk prompt '{prompt}' dengan model '{model}'"
        for i in range(count)
    ]
    return captions, {"prompt": prompt, "neg_prompt": neg_prompt}
def process_image(image, prompt, neg_prompt, model):
    """Placeholder image-to-image step.

    Returns a status string (Indonesian) naming *prompt* and *model*; *image*
    and *neg_prompt* are accepted but not used yet.
    """
    template = "Proses gambar dengan prompt '{}' dan model '{}'"
    return template.format(prompt, model)
def controlnet_process(image, controlnet_type, model):
    """Placeholder ControlNet step.

    Returns a status string (Indonesian) naming *controlnet_type* and *model*;
    *image* is accepted but not used yet.
    """
    template = "Proses gambar dengan ControlNet '{}' dan model '{}'"
    return template.format(controlnet_type, model)
def controlnet_process_func(image, controlnet_type, model):
    """UI-facing entry point for the ControlNet tab.

    Delegates to ``controlnet_process`` and returns its result unchanged; update
    this wrapper as the real pipeline takes shape.
    """
    result = controlnet_process(image, controlnet_type, model)
    return result
def intpaint_func(image, controlnet_type, model):
    """UI-facing entry point for inpainting (placeholder).

    For now it simply reuses ``controlnet_process`` and returns its result
    unchanged; update as needed once a real inpainting pipeline exists.
    """
    result = controlnet_process(image, controlnet_type, model)
    return result
# --- WD tagger (image -> tags) configuration -------------------------------

# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files fetched from each tagger repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the face and must not become spaces.
# Source: https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the tagger's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests parse without touching ``sys.argv``.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--score-slider-step", type=float, default=0.05,
        help="step size of the threshold sliders",
    )
    parser.add_argument(
        "--score-general-threshold", type=float, default=0.35,
        help="initial threshold for general tags",
    )
    parser.add_argument(
        "--score-character-threshold", type=float, default=0.85,
        help="initial threshold for character tags",
    )
    parser.add_argument(
        "--share", action="store_true",
        help="ask Gradio for a public share link",
    )
    return parser.parse_args(argv)
def load_labels(dataframe) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tagger label table into display names and per-category row indexes.

    Args:
        dataframe: pandas DataFrame with at least ``name`` and ``category``
            columns (the repo's ``selected_tags.csv``).

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
        ``tag_names`` replaces underscores with spaces, except for the entries
        in the module-level ``kaomojis`` list, which are kept verbatim. The
        index lists hold the row positions whose ``category`` is 9, 0 and 4
        respectively.
    """
    # Build a set once for O(1) membership tests across the whole tag list.
    keep_verbatim = set(kaomojis)
    tag_names = [
        name if name in keep_verbatim else name.replace("_", " ")
        for name in dataframe["name"]
    ]

    categories = dataframe["category"]
    # .tolist() yields plain Python ints rather than numpy integer scalars.
    rating_indexes = np.where(categories == 9)[0].tolist()
    general_indexes = np.where(categories == 0)[0].tolist()
    character_indexes = np.where(categories == 4)[0].tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sort the scores in descending order, find the largest gap between
    consecutive scores, and return the midpoint of that gap.

    Args:
        probs: 1-D array-like of scores (a numpy array or a plain list).

    Returns:
        The threshold as a float. With fewer than two scores there is no gap to
        cut, so ``0.0`` is returned and every positive score passes. Callers can
        apply their own floor on top of this value.
    """
    scores = np.asarray(probs, dtype=float)
    if scores.size < 2:
        return 0.0
    descending = np.sort(scores)[::-1]
    gaps = descending[:-1] - descending[1:]
    cut = int(gaps.argmax())
    return float((descending[cut] + descending[cut + 1]) / 2)
class Predictor:
    """Tag images with a WD tagger ONNX model downloaded from the Hugging Face Hub.

    Models are loaded lazily. Only the most recently requested repository is
    kept; asking ``predict`` for a different repo replaces it.
    """

    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model of *model_repo* via ``hf_hub_download``.

        Returns:
            ``(csv_path, model_path)`` as local file paths.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make *model_repo* the active model; a no-op if it is already active.

        Sets ``tag_names``, the rating/general/character index lists, the
        ``model`` inference session and ``model_target_size``.
        """
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)
        tag_names, rating_indexes, general_indexes, character_indexes = load_labels(
            pd.read_csv(csv_path)
        )
        model = rt.InferenceSession(model_path)
        # Input is (batch, height, width, channels), matching prepare_image's
        # output; only the height is used, i.e. a square input is assumed.
        _, height, _width, _ = model.get_inputs()[0].shape

        # Commit only after every step succeeded. Otherwise a failed load could
        # leave the new repo's tag lists paired with the previous session, and a
        # retry of the previous repo would short-circuit above with mismatched labels.
        self.tag_names = tag_names
        self.rating_indexes = rating_indexes
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, image):
        """Turn a PIL image into a one-image model input batch.

        Steps:
        1. Flatten transparency onto a white background.
        2. Pad to a centred white square.
        3. Resize to ``model_target_size`` with bicubic filtering.

        Returns:
            A float32 array of shape ``(1, size, size, 3)``, values 0-255, in
            BGR channel order.
        """
        target_size = self.model_target_size

        # alpha_composite needs both images in RGBA. The UI supplies RGBA, but
        # direct callers may pass RGB/L images.
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad image to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(
        self,
        image,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
    ):
        """Tag *image* with the model from *model_repo*.

        Args:
            image: PIL image to tag.
            model_repo: Hub repository id of the tagger to use.
            general_thresh: Minimum score for general tags; ignored when
                *general_mcut_enabled* is true.
            general_mcut_enabled: Derive the general threshold with MCut instead.
            character_thresh: Minimum score for character tags; ignored when
                *character_mcut_enabled* is true.
            character_mcut_enabled: Derive the character threshold with MCut,
                floored at 0.15.

        Returns:
            ``(tag_string, rating, character_res, general_res)``:
            - ``tag_string``: the selected general tags, highest score first,
              comma-separated, with parentheses backslash-escaped;
            - ``rating``: every rating label mapped to its score;
            - ``character_res`` and ``general_res``: selected tags mapped to
              their scores.

        Raises:
            ValueError: if *image* is None (e.g. Submit pressed with no upload).
        """
        if image is None:
            raise ValueError("No input image provided.")

        self.load_model(model_repo)
        batch = self.prepare_image(image)

        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: batch})[0]

        labels = list(zip(self.tag_names, preds[0].astype(float)))

        # Ratings: all rating labels are returned with their scores (no argmax here).
        rating = dict(labels[i] for i in self.rating_indexes)

        # General tags: keep those scoring above the (possibly MCut-derived) threshold.
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_thresh = mcut_threshold(np.array([score for _, score in general_names]))
        general_res = {name: score for name, score in general_names if score > general_thresh}

        # Character tags: same rule, but an MCut threshold is floored at 0.15.
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_thresh = mcut_threshold(np.array([score for _, score in character_names]))
            character_thresh = max(0.15, character_thresh)
        character_res = {
            name: score for name, score in character_names if score > character_thresh
        }

        ranked_general = sorted(general_res, key=general_res.get, reverse=True)
        # Raw strings: "\(" is an invalid escape sequence (SyntaxWarning on 3.12+).
        # Escaping is presumably so the string can be pasted into prompts where
        # parentheses carry meaning.
        tag_string = ", ".join(ranked_general).replace("(", r"\(").replace(")", r"\)")

        return tag_string, rating, character_res, general_res
# Command-line options, parsed once at import; the Describe tab reads its
# slider defaults from them.
args = parse_args()

# One shared tagger instance for the whole app.
predictor = Predictor()

# Tagger repositories offered in the Describe tab: v3 models first, then v2.
dropdown_list = [
    # Dataset v3
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    EVA02_LARGE_MODEL_DSV3_REPO,
    # Dataset v2
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
]
with gr.Blocks(css="style.css") as app:
    # Model / VAE selectors sit outside the tabs so every tab shares them.
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=models, label="Model", value="Model B")
        vae_dropdown = gr.Dropdown(choices=vae, label="VAE", value="VAE C")

    # Prompt and negative prompt, with the shared Generate button beside them.
    with gr.Row():
        # Gradio documents `scale` as an integer. An 8:1 split keeps the button
        # roughly as narrow as the original fractional scale (0.13) intended.
        with gr.Column(scale=8):
            prompt_input = gr.Textbox(
                label="Prompt", placeholder="Masukkan prompt teks", lines=2, elem_id="prompt-input"
            )
            neg_prompt_input = gr.Textbox(
                label="Neg Prompt", placeholder="Masukkan negasi prompt", lines=2, elem_id="neg-prompt-input"
            )
        generate_button = gr.Button("Generate", elem_id="generate-button", scale=1)

    # NOTE(review): the placeholder back-end functions return text. gr.Image and
    # gr.Gallery outputs treat a str as an image file path, so those outputs
    # need real images before they render anything.

    # Text-to-Image tab
    with gr.Tab("Text-to-Image"):
        with gr.Row():
            with gr.Column():
                # Generation settings
                scheduler_input = gr.Dropdown(choices=schedulers, label="Sampling method", value=schedulers[0])
                num_steps_input = gr.Slider(minimum=1, maximum=100, step=1, label="Sampling steps", value=20)
                width_input = gr.Slider(minimum=128, maximum=2048, step=128, label="Width", value=512)
                height_input = gr.Slider(minimum=128, maximum=2048, step=128, label="Height", value=512)
                cfg_scale_input = gr.Slider(minimum=1, maximum=20, step=1, label="CFG Scale", value=7)
                seed_input = gr.Number(label="Seed", value=-1)
                batch_size = gr.Slider(minimum=1, maximum=24, step=1, label="Batch size", value=1)
                batch_count = gr.Slider(minimum=1, maximum=24, step=1, label="Batch Count", value=1)
            with gr.Column():
                # Gallery for the generated images
                output_gallery = gr.Gallery(label="Hasil Gambar")
                # Text output below the gallery
                output_text = gr.Textbox(label="Output JSON", placeholder="Hasil dalam format JSON", lines=2)

        def update_images(prompt, neg_prompt, width, height, scheduler, num_steps,
                          images_per_batch, batches, cfg_scale, seed, model, vae=None):
            """Generate click handler: map the UI values onto ``generate_image``.

            Gradio passes the ``inputs`` components positionally, in list order.
            The total image count is batch size x batch count. *vae* is accepted
            so the VAE dropdown can be wired in; the placeholder ignores it.
            """
            num_images = int(images_per_batch) * int(batches)
            return generate_image(prompt, neg_prompt, width, height, scheduler, num_steps,
                                  num_images, cfg_scale, seed, model)

        # Exactly one input per handler parameter, in the same order.
        generate_button.click(
            fn=update_images,
            inputs=[prompt_input, neg_prompt_input, width_input, height_input, scheduler_input,
                    num_steps_input, batch_size, batch_count, cfg_scale_input, seed_input,
                    model_dropdown, vae_dropdown],
            outputs=[output_gallery, output_text],
        )

    # Image-to-Image tab
    with gr.Tab("Image-to-Image"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Unggah Gambar")
                prompt_input_i2i = gr.Textbox(label="Prompt", placeholder="Masukkan prompt teks", lines=2)
                neg_prompt_input_i2i = gr.Textbox(label="Neg Prompt", placeholder="Masukkan negasi prompt", lines=2)
                generate_button_i2i = gr.Button("Proses Gambar")
            with gr.Column():
                output_image_i2i = gr.Image(label="Hasil Gambar")

        def process_image_func(image, prompt, neg_prompt, model, vae=None):
            """Image-to-Image click handler; *vae* is accepted but not used yet."""
            return process_image(image, prompt, neg_prompt, model)

        generate_button_i2i.click(
            fn=process_image_func,
            inputs=[image_input, prompt_input_i2i, neg_prompt_input_i2i, model_dropdown, vae_dropdown],
            outputs=output_image_i2i,
        )

    # ControlNet tab
    with gr.Tab("ControlNet"):
        with gr.Row():
            with gr.Column():
                controlnet_dropdown = gr.Dropdown(choices=controlnet_types, label="Pilih Tipe ControlNet")
                controlnet_image_input = gr.Image(label="Unggah Gambar untuk ControlNet")
                controlnet_button = gr.Button("Proses dengan ControlNet")
            with gr.Column():
                controlnet_output_image = gr.Image(label="Hasil ControlNet")

        def controlnet_ui(image, controlnet_type, model, vae=None):
            """ControlNet click handler; drops *vae*, which the back end does not take."""
            return controlnet_process_func(image, controlnet_type, model)

        controlnet_button.click(
            fn=controlnet_ui,
            inputs=[controlnet_image_input, controlnet_dropdown, model_dropdown, vae_dropdown],
            outputs=controlnet_output_image,
        )

    # Inpainting tab
    with gr.Tab("Inpainting"):
        with gr.Row():
            with gr.Column():
                inpaint_image = gr.ImageMask(sources=["upload"], layers=False, transforms=[], format="png", label="base image", show_label=True)
                btn = gr.Button("Inpaint!", elem_id="run_button")
                prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
                negative_prompt = gr.Textbox(label="negative_prompt", placeholder="Your negative prompt", info="what you don't want to see in the image")
                guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
                steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps")
                strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
                # The default must be one of `choices`; "EulerDiscreteScheduler"
                # is not in `schedulers`.
                scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value=schedulers[0])
            with gr.Column():
                image_out = gr.Image(label="Output", elem_id="output-img")

        def inpaint_ui(inpaint_input, prompt_text, negative_text, guidance, num_steps,
                       denoise_strength, sampler, model):
            """Inpaint click handler.

            ``intpaint_func`` takes ``(image, controlnet_type, model)``, so the
            inpainting-specific settings are accepted here but not used yet.
            """
            return intpaint_func(inpaint_input, "Inpainting", model)

        btn.click(
            fn=inpaint_ui,
            inputs=[inpaint_image, prompt, negative_prompt, guidance_scale, steps, strength,
                    scheduler, model_dropdown],
            outputs=[image_out],
        )

    # Describe tab (WD tagger)
    with gr.Tab("Describe"):
        with gr.Row():
            with gr.Column():
                image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                submit_button = gr.Button(value="Submit", variant="primary", size="lg")
                model_repo = gr.Dropdown(dropdown_list, value=SWINV2_MODEL_DSV3_REPO, label="Model")
                # Each slider/checkbox pair shares a row; `scale` sets their 3:1 widths.
                with gr.Row():
                    general_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold", scale=3)
                    general_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                with gr.Row():
                    character_thresh = gr.Slider(0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold", scale=3)
                    character_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                clear_button = gr.ClearButton(
                    components=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled],
                    variant="secondary",
                    size="lg",
                )
            with gr.Column():
                sorted_general_strings = gr.Textbox(label="Output (string)")
                rating = gr.Label(label="Rating")
                character_res = gr.Label(label="Output (characters)")
                general_res = gr.Label(label="Output (tags)")
                # The Clear button also resets the outputs.
                clear_button.add([sorted_general_strings, rating, character_res, general_res])

        submit_button.click(
            predictor.predict,
            inputs=[image, model_repo, general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled],
            outputs=[sorted_general_strings, rating, character_res, general_res],
        )
# Launch the interface. The parsed `--share` flag was previously ignored; it
# now asks Gradio for a public share link (default off, as before).
app.launch(share=args.share)