Spaces:
Running
on
Zero
Running
on
Zero
Upload tagger.py
Browse files
tagger.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
from PIL import Image
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
-
import spaces
|
| 5 |
-
|
| 6 |
from transformers import (
|
| 7 |
AutoImageProcessor,
|
| 8 |
AutoModelForImageClassification,
|
| 9 |
)
|
| 10 |
|
|
|
|
| 11 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 12 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 13 |
|
|
@@ -49,6 +49,34 @@ DANBOORU_TO_E621_RATING_MAP = {
|
|
| 49 |
}
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def to_list(s):
    """Split a comma-separated string into a list of stripped, non-empty items.

    Args:
        s: Comma-separated text, e.g. ``"1girl, solo, long hair"``.

    Returns:
        The items in their original order with surrounding whitespace
        removed. Empty or whitespace-only items (from an empty input,
        doubled commas, or a trailing comma) are dropped.
    """
    # Filter on each stripped item rather than on the whole input string, so
    # that inputs such as "a,,b" or "a, " do not yield empty tags.
    return [item for item in (part.strip() for part in s.split(",")) if item]
|
| 54 |
|
|
@@ -110,7 +138,7 @@ def select_random_character(series: str, character: str):
|
|
| 110 |
def danbooru_to_e621(dtag, e621_dict):
|
| 111 |
def d_to_e(match, e621_dict):
|
| 112 |
dtag = match.group(0)
|
| 113 |
-
etag = e621_dict.get(dtag
|
| 114 |
if etag:
|
| 115 |
return etag
|
| 116 |
else:
|
|
@@ -134,7 +162,7 @@ def convert_danbooru_to_e621_prompt(input_prompt: str = "", prompt_type: str = "
|
|
| 134 |
|
| 135 |
e621_dict = danbooru_to_e621_dict
|
| 136 |
for tag in tags:
|
| 137 |
-
tag = tag
|
| 138 |
tag = danbooru_to_e621(tag, e621_dict)
|
| 139 |
if tag in PEOPLE_TAGS:
|
| 140 |
people_tags.append(tag)
|
|
@@ -162,6 +190,7 @@ def translate_prompt(prompt: str = ""):
|
|
| 162 |
translated_prompt = translator.translate(prompt, src='auto', dest='en').text
|
| 163 |
return translated_prompt
|
| 164 |
except Exception as e:
|
|
|
|
| 165 |
return prompt
|
| 166 |
|
| 167 |
def is_japanese(s):
|
|
@@ -194,6 +223,7 @@ def translate_prompt_to_ja(prompt: str = ""):
|
|
| 194 |
translated_prompt = translator.translate(prompt, src='en', dest='ja').text
|
| 195 |
return translated_prompt
|
| 196 |
except Exception as e:
|
|
|
|
| 197 |
return prompt
|
| 198 |
|
| 199 |
def is_japanese(s):
|
|
@@ -219,7 +249,7 @@ def translate_prompt_to_ja(prompt: str = ""):
|
|
| 219 |
def tags_to_ja(itag, dict):
|
| 220 |
def t_to_j(match, dict):
|
| 221 |
tag = match.group(0)
|
| 222 |
-
ja = dict.get(tag
|
| 223 |
if ja:
|
| 224 |
return ja
|
| 225 |
else:
|
|
@@ -238,7 +268,7 @@ def convert_tags_to_ja(input_prompt: str = ""):
|
|
| 238 |
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
| 239 |
dict = tags_to_ja_dict
|
| 240 |
for tag in tags:
|
| 241 |
-
tag = tag
|
| 242 |
tag = tags_to_ja(tag, dict)
|
| 243 |
out_tags.append(tag)
|
| 244 |
|
|
@@ -365,7 +395,7 @@ def remove_specific_prompt(input_prompt: str = "", keep_tags: str = "all"):
|
|
| 365 |
|
| 366 |
group_dict = tag_group_dict
|
| 367 |
for tag in tags:
|
| 368 |
-
tag = tag
|
| 369 |
if tag in PEOPLE_TAGS:
|
| 370 |
people_tags.append(tag)
|
| 371 |
elif is_necessary(tag, keep_tags, group_dict):
|
|
@@ -393,7 +423,7 @@ def sort_taglist(tags: list[str]):
|
|
| 393 |
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
| 394 |
|
| 395 |
for tag in tags:
|
| 396 |
-
tag = tag
|
| 397 |
if tag in PEOPLE_TAGS:
|
| 398 |
people_tags.append(tag)
|
| 399 |
elif tag in rating_set:
|
|
@@ -494,12 +524,13 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
|
|
| 494 |
output_series_tag = output_series_list[0]
|
| 495 |
else:
|
| 496 |
output_series_tag = ""
|
| 497 |
-
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
| 498 |
|
| 499 |
|
| 500 |
-
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
|
|
|
|
| 501 |
if not "Use WD Tagger" in algo and len(algo) != 0:
|
| 502 |
-
return
|
| 503 |
return predict_tags(image, general_threshold, character_threshold)
|
| 504 |
|
| 505 |
|
|
|
|
| 1 |
from PIL import Image
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
+
import spaces
|
|
|
|
| 5 |
from transformers import (
|
| 6 |
AutoImageProcessor,
|
| 7 |
AutoModelForImageClassification,
|
| 8 |
)
|
| 9 |
|
| 10 |
+
|
| 11 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 12 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 13 |
|
|
|
|
| 49 |
}
|
| 50 |
|
| 51 |
|
| 52 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# Emoticon tags whose underscores are part of the tag itself and must not be
# converted to spaces.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "<o>_<o>",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]


def replace_underline(x: str):
    """Return the tag stripped of whitespace with underscores turned into spaces.

    Tags listed in ``kaomojis`` keep their underscores.

    Args:
        x: A single tag, possibly padded with whitespace.

    Returns:
        The stripped tag; underscores are replaced by spaces unless the
        stripped tag is one of the ``kaomojis``.
    """
    tag = x.strip()
    # Check membership on the stripped tag so that a padded kaomoji such as
    # " ^_^ " is still recognised and left intact.
    return tag if tag in kaomojis else tag.replace("_", " ")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def to_list(s):
    """Split a comma-separated string into a list of stripped, non-empty items.

    Args:
        s: Comma-separated text, e.g. ``"1girl, solo, long hair"``.

    Returns:
        The items in their original order with surrounding whitespace
        removed. Empty or whitespace-only items (from an empty input,
        doubled commas, or a trailing comma) are dropped.
    """
    # Filter on each stripped item rather than on the whole input string, so
    # that inputs such as "a,,b" or "a, " do not yield empty tags.
    return [item for item in (part.strip() for part in s.split(",")) if item]
|
| 82 |
|
|
|
|
| 138 |
def danbooru_to_e621(dtag, e621_dict):
|
| 139 |
def d_to_e(match, e621_dict):
|
| 140 |
dtag = match.group(0)
|
| 141 |
+
etag = e621_dict.get(replace_underline(dtag), "")
|
| 142 |
if etag:
|
| 143 |
return etag
|
| 144 |
else:
|
|
|
|
| 162 |
|
| 163 |
e621_dict = danbooru_to_e621_dict
|
| 164 |
for tag in tags:
|
| 165 |
+
tag = replace_underline(tag)
|
| 166 |
tag = danbooru_to_e621(tag, e621_dict)
|
| 167 |
if tag in PEOPLE_TAGS:
|
| 168 |
people_tags.append(tag)
|
|
|
|
| 190 |
translated_prompt = translator.translate(prompt, src='auto', dest='en').text
|
| 191 |
return translated_prompt
|
| 192 |
except Exception as e:
|
| 193 |
+
print(e)
|
| 194 |
return prompt
|
| 195 |
|
| 196 |
def is_japanese(s):
|
|
|
|
| 223 |
translated_prompt = translator.translate(prompt, src='en', dest='ja').text
|
| 224 |
return translated_prompt
|
| 225 |
except Exception as e:
|
| 226 |
+
print(e)
|
| 227 |
return prompt
|
| 228 |
|
| 229 |
def is_japanese(s):
|
|
|
|
| 249 |
def tags_to_ja(itag, dict):
|
| 250 |
def t_to_j(match, dict):
|
| 251 |
tag = match.group(0)
|
| 252 |
+
ja = dict.get(replace_underline(tag), "")
|
| 253 |
if ja:
|
| 254 |
return ja
|
| 255 |
else:
|
|
|
|
| 268 |
tags_to_ja_dict = load_dict_from_csv('all_tags_ja_ext.csv')
|
| 269 |
dict = tags_to_ja_dict
|
| 270 |
for tag in tags:
|
| 271 |
+
tag = replace_underline(tag)
|
| 272 |
tag = tags_to_ja(tag, dict)
|
| 273 |
out_tags.append(tag)
|
| 274 |
|
|
|
|
| 395 |
|
| 396 |
group_dict = tag_group_dict
|
| 397 |
for tag in tags:
|
| 398 |
+
tag = replace_underline(tag)
|
| 399 |
if tag in PEOPLE_TAGS:
|
| 400 |
people_tags.append(tag)
|
| 401 |
elif is_necessary(tag, keep_tags, group_dict):
|
|
|
|
| 423 |
rating_set = set(DANBOORU_TO_E621_RATING_MAP.keys()) | set(DANBOORU_TO_E621_RATING_MAP.values())
|
| 424 |
|
| 425 |
for tag in tags:
|
| 426 |
+
tag = replace_underline(tag)
|
| 427 |
if tag in PEOPLE_TAGS:
|
| 428 |
people_tags.append(tag)
|
| 429 |
elif tag in rating_set:
|
|
|
|
| 524 |
output_series_tag = output_series_list[0]
|
| 525 |
else:
|
| 526 |
output_series_tag = ""
|
| 527 |
+
return output_series_tag, ", ".join(character.keys()), prompt, gr.update(interactive=True)
|
| 528 |
|
| 529 |
|
| 530 |
+
def predict_tags_wd(image: Image.Image, input_tags: str, algo: list[str], general_threshold: float = 0.3,
                    character_threshold: float = 0.8, input_series: str = "", input_character: str = ""):
    """Run the WD tagger on an image unless the selected algorithms exclude it.

    Args:
        image: Image to tag; forwarded to ``predict_tags``.
        input_tags: Current tag text, returned unchanged when tagging is skipped.
        algo: Selected algorithm names. Tagging runs when this is empty or
            contains ``"Use WD Tagger"``.
        general_threshold: Forwarded to ``predict_tags``.
        character_threshold: Forwarded to ``predict_tags``.
        input_series: Current series text, returned unchanged when tagging is skipped.
        input_character: Current character text, returned unchanged when tagging is skipped.

    Returns:
        A 4-tuple ``(series, character, tags, gr.update(interactive=True))``:
        either the inputs passed through untouched, or the result of
        ``predict_tags``.
    """
    # Skip only when some algorithms are selected and WD Tagger is not among
    # them; an empty selection falls through to tagging.
    if algo and "Use WD Tagger" not in algo:
        return input_series, input_character, input_tags, gr.update(interactive=True)
    return predict_tags(image, general_threshold, character_threshold)
|
| 535 |
|
| 536 |
|