Upload 8 files
Browse files
- tagger/fl2sd3longcap.py +9 -3
- tagger/tagger.py +11 -4
tagger/fl2sd3longcap.py
CHANGED
|
@@ -8,9 +8,13 @@ import subprocess
|
|
| 8 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 9 |
|
| 10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
-
fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to(device).eval()
|
| 12 |
-
fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def fl_modify_caption(caption: str) -> str:
|
| 16 |
"""
|
|
@@ -41,7 +45,7 @@ def fl_modify_caption(caption: str) -> str:
|
|
| 41 |
return modified_caption if modified_caption != caption else caption
|
| 42 |
|
| 43 |
|
| 44 |
-
@spaces.GPU
|
| 45 |
def fl_run_example(image):
|
| 46 |
task_prompt = "<DESCRIPTION>"
|
| 47 |
prompt = task_prompt + "Describe this image in great detail."
|
|
@@ -50,6 +54,7 @@ def fl_run_example(image):
|
|
| 50 |
if image.mode != "RGB":
|
| 51 |
image = image.convert("RGB")
|
| 52 |
|
|
|
|
| 53 |
inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
|
| 54 |
generated_ids = fl_model.generate(
|
| 55 |
input_ids=inputs["input_ids"],
|
|
@@ -57,6 +62,7 @@ def fl_run_example(image):
|
|
| 57 |
max_new_tokens=1024,
|
| 58 |
num_beams=3
|
| 59 |
)
|
|
|
|
| 60 |
generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 61 |
parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
|
| 62 |
return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
|
|
|
|
| 8 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 9 |
|
| 10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
try:
|
| 13 |
+
fl_model = AutoModelForCausalLM.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True).to("cpu").eval()
|
| 14 |
+
fl_processor = AutoProcessor.from_pretrained('gokaygokay/Florence-2-SD3-Captioner', trust_remote_code=True)
|
| 15 |
+
except Exception as e:
|
| 16 |
+
print(e)
|
| 17 |
+
fl_model = fl_processor = None
|
| 18 |
|
| 19 |
def fl_modify_caption(caption: str) -> str:
|
| 20 |
"""
|
|
|
|
| 45 |
return modified_caption if modified_caption != caption else caption
|
| 46 |
|
| 47 |
|
| 48 |
+
@spaces.GPU(duration=30)
|
| 49 |
def fl_run_example(image):
|
| 50 |
task_prompt = "<DESCRIPTION>"
|
| 51 |
prompt = task_prompt + "Describe this image in great detail."
|
|
|
|
| 54 |
if image.mode != "RGB":
|
| 55 |
image = image.convert("RGB")
|
| 56 |
|
| 57 |
+
fl_model.to(device)
|
| 58 |
inputs = fl_processor(text=prompt, images=image, return_tensors="pt").to(device)
|
| 59 |
generated_ids = fl_model.generate(
|
| 60 |
input_ids=inputs["input_ids"],
|
|
|
|
| 62 |
max_new_tokens=1024,
|
| 63 |
num_beams=3
|
| 64 |
)
|
| 65 |
+
fl_model.to("cpu")
|
| 66 |
generated_text = fl_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 67 |
parsed_answer = fl_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
|
| 68 |
return fl_modify_caption(parsed_answer["<DESCRIPTION>"])
|
tagger/tagger.py
CHANGED
|
@@ -12,10 +12,15 @@ from pathlib import Path
|
|
| 12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
| 21 |
return (
|
|
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
|
| 506 |
return ", ".join(all_tags)
|
| 507 |
|
| 508 |
|
| 509 |
-
@spaces.GPU()
|
| 510 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
| 511 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
| 512 |
|
|
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
|
|
| 514 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
| 515 |
|
| 516 |
# get probabilities
|
|
|
|
| 517 |
results = {
|
| 518 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
| 519 |
}
|
|
|
|
| 520 |
# rating, character, general
|
| 521 |
rating, character, general = postprocess_results(
|
| 522 |
results, general_threshold, character_threshold
|
|
|
|
| 12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 14 |
|
| 15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
+
default_device = device
|
|
|
|
| 17 |
|
| 18 |
+
try:
|
| 19 |
+
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
|
| 20 |
+
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(e)
|
| 23 |
+
wd_model = wd_processor = None
|
| 24 |
|
| 25 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
| 26 |
return (
|
|
|
|
| 511 |
return ", ".join(all_tags)
|
| 512 |
|
| 513 |
|
| 514 |
+
@spaces.GPU(duration=30)
|
| 515 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
| 516 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
| 517 |
|
|
|
|
| 519 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
| 520 |
|
| 521 |
# get probabilities
|
| 522 |
+
if device != default_device: wd_model.to(device=device)
|
| 523 |
results = {
|
| 524 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
| 525 |
}
|
| 526 |
+
if device != default_device: wd_model.to(device=default_device)
|
| 527 |
# rating, character, general
|
| 528 |
rating, character, general = postprocess_results(
|
| 529 |
results, general_threshold, character_threshold
|