Spaces:
Running
Running
Added a "Remove tags" feature to exclude specified tags from the output.
Browse files
app.py
CHANGED
|
@@ -134,7 +134,7 @@ class Timer:
|
|
| 134 |
elapsed = curr_time - prev_time
|
| 135 |
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 136 |
prev_time = curr_time
|
| 137 |
-
|
| 138 |
if is_clear_checkpoints:
|
| 139 |
self.checkpoints = [("Start", time.perf_counter())]
|
| 140 |
|
|
@@ -151,7 +151,7 @@ class Timer:
|
|
| 151 |
elapsed = curr_time - prev_time
|
| 152 |
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 153 |
prev_time = curr_time
|
| 154 |
-
|
| 155 |
total_time = self.checkpoints[-1][1] - self.start_time
|
| 156 |
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
| 157 |
|
|
@@ -252,7 +252,7 @@ class Llama3Reorganize:
|
|
| 252 |
import ctranslate2
|
| 253 |
import transformers
|
| 254 |
try:
|
| 255 |
-
print('\n\nLoading model:
|
| 256 |
kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
|
| 257 |
kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
|
| 258 |
self.roleSystem = {"role": "system", "content": self.system_prompt}
|
|
@@ -270,12 +270,11 @@ class Llama3Reorganize:
|
|
| 270 |
try:
|
| 271 |
import torch
|
| 272 |
if torch.cuda.is_available():
|
| 273 |
-
if
|
| 274 |
self.Model.unload_model()
|
| 275 |
-
|
| 276 |
-
if getattr(self, "Tokenizer", None) is not None:
|
| 277 |
del self.Tokenizer
|
| 278 |
-
if
|
| 279 |
del self.Model
|
| 280 |
import gc
|
| 281 |
gc.collect()
|
|
@@ -283,14 +282,13 @@ class Llama3Reorganize:
|
|
| 283 |
torch.cuda.empty_cache()
|
| 284 |
except Exception as e:
|
| 285 |
print(traceback.format_exc())
|
| 286 |
-
print("\tcuda empty cache, error: "
|
| 287 |
print("release vram end.")
|
| 288 |
except Exception as e:
|
| 289 |
print(traceback.format_exc())
|
| 290 |
-
print("Error release vram: "
|
| 291 |
|
| 292 |
def reorganize(self, text: str, max_length: int = 400):
|
| 293 |
-
output = None
|
| 294 |
result = None
|
| 295 |
try:
|
| 296 |
input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
|
|
@@ -298,19 +296,18 @@ class Llama3Reorganize:
|
|
| 298 |
output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
|
| 299 |
target = output[0]
|
| 300 |
result = self.Tokenizer.decode(target.sequences_ids[0])
|
| 301 |
-
|
| 302 |
if len(result) > 2:
|
| 303 |
-
if result[0] == "
|
| 304 |
result = result[1:-1]
|
| 305 |
-
elif result[0] == "'" and result[
|
| 306 |
result = result[1:-1]
|
| 307 |
-
elif result[0] ==
|
| 308 |
result = result[1:-1]
|
| 309 |
-
elif result[0] ==
|
| 310 |
result = result[1:-1]
|
| 311 |
except Exception as e:
|
| 312 |
print(traceback.format_exc())
|
| 313 |
-
print("Error reorganize text: "
|
| 314 |
|
| 315 |
return result
|
| 316 |
|
|
@@ -339,28 +336,19 @@ class Predictor:
|
|
| 339 |
|
| 340 |
tags_df = pd.read_csv(csv_path)
|
| 341 |
sep_tags = load_labels(tags_df)
|
| 342 |
-
|
| 343 |
-
self.tag_names = sep_tags[0]
|
| 344 |
-
self.rating_indexes = sep_tags[1]
|
| 345 |
-
self.general_indexes = sep_tags[2]
|
| 346 |
-
self.character_indexes = sep_tags[3]
|
| 347 |
-
|
| 348 |
model = rt.InferenceSession(model_path)
|
| 349 |
-
_, height,
|
| 350 |
self.model_target_size = height
|
| 351 |
-
|
| 352 |
self.last_loaded_repo = model_repo
|
| 353 |
self.model = model
|
| 354 |
|
| 355 |
def prepare_image(self, path):
|
| 356 |
-
image = Image.open(path)
|
| 357 |
-
image = image.convert("RGBA")
|
| 358 |
-
target_size = self.model_target_size
|
| 359 |
-
|
| 360 |
canvas = Image.new("RGBA", image.size, (255, 255, 255))
|
| 361 |
canvas.alpha_composite(image)
|
| 362 |
image = canvas.convert("RGB")
|
| 363 |
-
|
| 364 |
# Pad image to square
|
| 365 |
image_shape = image.size
|
| 366 |
max_dim = max(image_shape)
|
|
@@ -369,14 +357,14 @@ class Predictor:
|
|
| 369 |
|
| 370 |
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
|
| 371 |
padded_image.paste(image, (pad_left, pad_top))
|
| 372 |
-
|
| 373 |
# Resize
|
| 374 |
-
if max_dim !=
|
| 375 |
padded_image = padded_image.resize(
|
| 376 |
-
(
|
| 377 |
Image.BICUBIC,
|
| 378 |
)
|
| 379 |
-
|
| 380 |
# Convert to numpy array
|
| 381 |
image_array = np.asarray(padded_image, dtype=np.float32)
|
| 382 |
|
|
@@ -404,6 +392,7 @@ class Predictor:
|
|
| 404 |
llama3_reorganize_model_repo,
|
| 405 |
additional_tags_prepend,
|
| 406 |
additional_tags_append,
|
|
|
|
| 407 |
tag_results,
|
| 408 |
progress=gr.Progress()
|
| 409 |
):
|
|
@@ -413,7 +402,7 @@ class Predictor:
|
|
| 413 |
|
| 414 |
gallery_len = len(gallery)
|
| 415 |
print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
|
| 416 |
-
|
| 417 |
timer = Timer() # Create a timer
|
| 418 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
| 419 |
progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
|
|
@@ -423,7 +412,7 @@ class Predictor:
|
|
| 423 |
current_progress += 1 / progressTotal
|
| 424 |
progress(current_progress, desc="Initialize wd model finished")
|
| 425 |
timer.checkpoint(f"Initialize wd model")
|
| 426 |
-
|
| 427 |
# Result
|
| 428 |
txt_infos = []
|
| 429 |
output_dir = tempfile.mkdtemp()
|
|
@@ -439,14 +428,15 @@ class Predictor:
|
|
| 439 |
current_progress += 1 / progressTotal
|
| 440 |
progress(current_progress, desc="Initialize llama3 model finished")
|
| 441 |
timer.checkpoint(f"Initialize llama3 model")
|
| 442 |
-
|
| 443 |
timer.report()
|
| 444 |
|
| 445 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 446 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
|
|
|
| 447 |
if prepend_list and append_list:
|
| 448 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 449 |
-
|
| 450 |
# Dictionary to track counters for each filename
|
| 451 |
name_counters = defaultdict(int)
|
| 452 |
for idx, value in enumerate(gallery):
|
|
@@ -467,11 +457,11 @@ class Predictor:
|
|
| 467 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 468 |
|
| 469 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
| 470 |
-
|
| 471 |
# First 4 labels are actually ratings: pick one with argmax
|
| 472 |
ratings_names = [labels[i] for i in self.rating_indexes]
|
| 473 |
rating = dict(ratings_names)
|
| 474 |
-
|
| 475 |
# Then we have general tags: pick any where prediction confidence > threshold
|
| 476 |
general_names = [labels[i] for i in self.general_indexes]
|
| 477 |
|
|
@@ -479,7 +469,7 @@ class Predictor:
|
|
| 479 |
general_probs = np.array([x[1] for x in general_names])
|
| 480 |
general_thresh = mcut_threshold(general_probs)
|
| 481 |
general_res = dict([x for x in general_names if x[1] > general_thresh])
|
| 482 |
-
|
| 483 |
# Everything else is characters: pick any where prediction confidence > threshold
|
| 484 |
character_names = [labels[i] for i in self.character_indexes]
|
| 485 |
|
|
@@ -503,7 +493,12 @@ class Predictor:
|
|
| 503 |
final_tags_list = prepend_list + sorted_general_list + append_list
|
| 504 |
if characters_merge_enabled:
|
| 505 |
final_tags_list = character_list + final_tags_list
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
|
| 508 |
classified_tags, unclassified_tags = classify_tags(final_tags_list)
|
| 509 |
|
|
@@ -553,23 +548,24 @@ class Predictor:
|
|
| 553 |
# Get file name from lookup
|
| 554 |
taggers_zip.write(info["path"], arcname=info["name"])
|
| 555 |
download.append(downloadZipPath)
|
| 556 |
-
|
| 557 |
if llama3_reorganize:
|
| 558 |
llama3_reorganize.release_vram()
|
| 559 |
-
|
| 560 |
progress(1, desc="Image processing completed")
|
| 561 |
timer.report_all()
|
| 562 |
print("Image prediction is complete.")
|
| 563 |
|
| 564 |
return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
|
| 565 |
-
|
| 566 |
-
#
|
| 567 |
def predict_from_text(
|
| 568 |
self,
|
| 569 |
text_files,
|
| 570 |
llama3_reorganize_model_repo,
|
| 571 |
additional_tags_prepend,
|
| 572 |
additional_tags_append,
|
|
|
|
| 573 |
progress=gr.Progress()
|
| 574 |
):
|
| 575 |
if not text_files:
|
|
@@ -583,7 +579,7 @@ class Predictor:
|
|
| 583 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
|
| 584 |
progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
|
| 585 |
current_progress = 0
|
| 586 |
-
|
| 587 |
txt_infos = []
|
| 588 |
output_dir = tempfile.mkdtemp()
|
| 589 |
last_processed_string = ""
|
|
@@ -600,6 +596,7 @@ class Predictor:
|
|
| 600 |
|
| 601 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 602 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
|
|
|
| 603 |
if prepend_list and append_list:
|
| 604 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 605 |
|
|
@@ -608,7 +605,7 @@ class Predictor:
|
|
| 608 |
try:
|
| 609 |
file_path = file_obj.name
|
| 610 |
file_name_base = os.path.splitext(os.path.basename(file_path))[0]
|
| 611 |
-
|
| 612 |
name_counters[file_name_base] += 1
|
| 613 |
if name_counters[file_name_base] > 1:
|
| 614 |
output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"
|
|
@@ -617,16 +614,22 @@ class Predictor:
|
|
| 617 |
|
| 618 |
with open(file_path, 'r', encoding='utf-8') as f:
|
| 619 |
original_content = f.read()
|
| 620 |
-
|
| 621 |
# Process tags
|
| 622 |
tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
|
| 623 |
-
|
| 624 |
if prepend_list:
|
| 625 |
tags_list = [item for item in tags_list if item not in prepend_list]
|
| 626 |
if append_list:
|
| 627 |
tags_list = [item for item in tags_list if item not in append_list]
|
| 628 |
|
| 629 |
final_tags_list = prepend_list + tags_list + append_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
processed_string = ", ".join(final_tags_list)
|
| 631 |
|
| 632 |
current_progress += progressRatio / progressTotal
|
|
@@ -645,7 +648,7 @@ class Predictor:
|
|
| 645 |
current_progress += progressRatio / progressTotal
|
| 646 |
progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 647 |
timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 648 |
-
|
| 649 |
txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
|
| 650 |
txt_infos.append({"path": txt_file_path, "name": output_file_name})
|
| 651 |
last_processed_string = processed_string
|
|
@@ -671,7 +674,7 @@ class Predictor:
|
|
| 671 |
progress(1, desc="Text processing completed")
|
| 672 |
timer.report_all() # Print all recorded times
|
| 673 |
print("Text processing is complete.")
|
| 674 |
-
|
| 675 |
# Return values in the same structure as the image path, with placeholders for unused outputs
|
| 676 |
return download, last_processed_string, "{}", "", "", "", "{}", {}
|
| 677 |
|
|
@@ -679,9 +682,8 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
|
|
| 679 |
if not selected_state:
|
| 680 |
return selected_state
|
| 681 |
|
| 682 |
-
tag_result =
|
| 683 |
-
|
| 684 |
-
tag_result = tag_results[selected_state.value["image"]["path"]]
|
| 685 |
|
| 686 |
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
|
| 687 |
|
|
@@ -690,7 +692,7 @@ def append_gallery(gallery: list, image: str):
|
|
| 690 |
gallery = []
|
| 691 |
if not image:
|
| 692 |
return gallery, None
|
| 693 |
-
|
| 694 |
gallery.append(image)
|
| 695 |
|
| 696 |
return gallery, None
|
|
@@ -712,14 +714,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
|
|
| 712 |
return gallery
|
| 713 |
|
| 714 |
try:
|
| 715 |
-
|
| 716 |
# Remove the selected image from the gallery
|
| 717 |
-
if
|
| 718 |
-
gallery.remove(
|
| 719 |
except (ValueError, SyntaxError):
|
| 720 |
# Handle cases where the string is not a valid literal
|
| 721 |
print(f"Warning: Could not parse selected_image string: {selected_image}")
|
| 722 |
-
|
| 723 |
return gallery
|
| 724 |
|
| 725 |
|
|
@@ -751,32 +753,33 @@ def main():
|
|
| 751 |
SWINV2_MODEL_IS_DSV1_REPO,
|
| 752 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
| 753 |
]
|
| 754 |
-
|
| 755 |
llama_list = [
|
| 756 |
META_LLAMA_3_3B_REPO,
|
| 757 |
META_LLAMA_3_8B_REPO,
|
| 758 |
]
|
| 759 |
-
|
| 760 |
-
#
|
| 761 |
def run_prediction(
|
| 762 |
input_type, gallery, text_files, model_repo, general_thresh,
|
| 763 |
general_mcut_enabled, character_thresh, character_mcut_enabled,
|
| 764 |
characters_merge_enabled, llama3_reorganize_model_repo,
|
| 765 |
-
additional_tags_prepend, additional_tags_append,
|
|
|
|
| 766 |
):
|
| 767 |
if input_type == 'Image':
|
| 768 |
return predictor.predict_from_images(
|
| 769 |
gallery, model_repo, general_thresh, general_mcut_enabled,
|
| 770 |
character_thresh, character_mcut_enabled, characters_merge_enabled,
|
| 771 |
llama3_reorganize_model_repo, additional_tags_prepend,
|
| 772 |
-
additional_tags_append, tag_results, progress
|
| 773 |
)
|
| 774 |
else: # 'Text file (.txt)'
|
| 775 |
# For text files, some parameters are not used, but we must return
|
| 776 |
# a tuple of the same size. `predict_from_text` handles this.
|
| 777 |
return predictor.predict_from_text(
|
| 778 |
text_files, llama3_reorganize_model_repo,
|
| 779 |
-
additional_tags_prepend, additional_tags_append, progress
|
| 780 |
)
|
| 781 |
|
| 782 |
with gr.Blocks(title=TITLE, css=css) as demo:
|
|
@@ -793,7 +796,7 @@ def main():
|
|
| 793 |
value='Image',
|
| 794 |
label="Input Type"
|
| 795 |
)
|
| 796 |
-
|
| 797 |
# Group for image inputs, initially visible
|
| 798 |
with gr.Column(visible=True) as image_inputs_group:
|
| 799 |
with gr.Column(variant="panel"):
|
|
@@ -803,8 +806,8 @@ def main():
|
|
| 803 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
| 804 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
| 805 |
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
| 806 |
-
|
| 807 |
-
#
|
| 808 |
with gr.Column(visible=False) as text_inputs_group:
|
| 809 |
text_files_input = gr.Files(
|
| 810 |
label="Upload .txt files",
|
|
@@ -813,24 +816,6 @@ def main():
|
|
| 813 |
height=500
|
| 814 |
)
|
| 815 |
|
| 816 |
-
# NEW: Logic to show/hide input groups based on radio selection
|
| 817 |
-
def change_input_type(input_type):
|
| 818 |
-
is_image = (input_type == 'Image')
|
| 819 |
-
return {
|
| 820 |
-
image_inputs_group: gr.update(visible=is_image),
|
| 821 |
-
text_inputs_group: gr.update(visible=not is_image),
|
| 822 |
-
# Also update visibility of image-specific settings
|
| 823 |
-
model_repo: gr.update(visible=is_image),
|
| 824 |
-
general_thresh_row: gr.update(visible=is_image),
|
| 825 |
-
character_thresh_row: gr.update(visible=is_image),
|
| 826 |
-
characters_merge_enabled: gr.update(visible=is_image),
|
| 827 |
-
categorized: gr.update(visible=is_image),
|
| 828 |
-
rating: gr.update(visible=is_image),
|
| 829 |
-
character_res: gr.update(visible=is_image),
|
| 830 |
-
general_res: gr.update(visible=is_image),
|
| 831 |
-
unclassified: gr.update(visible=is_image),
|
| 832 |
-
}
|
| 833 |
-
|
| 834 |
# Image-specific settings
|
| 835 |
model_repo = gr.Dropdown(
|
| 836 |
dropdown_list,
|
|
@@ -883,6 +868,10 @@ def main():
|
|
| 883 |
with gr.Row():
|
| 884 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
| 885 |
additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 886 |
with gr.Row():
|
| 887 |
clear = gr.ClearButton(
|
| 888 |
components=[
|
|
@@ -897,6 +886,7 @@ def main():
|
|
| 897 |
llama3_reorganize_model_repo,
|
| 898 |
additional_tags_prepend,
|
| 899 |
additional_tags_append,
|
|
|
|
| 900 |
],
|
| 901 |
variant="secondary",
|
| 902 |
size="lg",
|
|
@@ -935,7 +925,25 @@ def main():
|
|
| 935 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
| 936 |
# Event to remove a selected image from the gallery
|
| 937 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 938 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 939 |
# Connect the radio button to the visibility function
|
| 940 |
input_type_radio.change(
|
| 941 |
fn=change_input_type,
|
|
@@ -946,7 +954,7 @@ def main():
|
|
| 946 |
categorized, rating, character_res, general_res, unclassified
|
| 947 |
]
|
| 948 |
)
|
| 949 |
-
|
| 950 |
# submit click now calls the wrapper function
|
| 951 |
submit.click(
|
| 952 |
fn=run_prediction,
|
|
@@ -963,6 +971,7 @@ def main():
|
|
| 963 |
llama3_reorganize_model_repo,
|
| 964 |
additional_tags_prepend,
|
| 965 |
additional_tags_append,
|
|
|
|
| 966 |
tag_results,
|
| 967 |
],
|
| 968 |
outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
|
|
|
|
| 134 |
elapsed = curr_time - prev_time
|
| 135 |
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 136 |
prev_time = curr_time
|
| 137 |
+
|
| 138 |
if is_clear_checkpoints:
|
| 139 |
self.checkpoints = [("Start", time.perf_counter())]
|
| 140 |
|
|
|
|
| 151 |
elapsed = curr_time - prev_time
|
| 152 |
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
| 153 |
prev_time = curr_time
|
| 154 |
+
|
| 155 |
total_time = self.checkpoints[-1][1] - self.start_time
|
| 156 |
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
| 157 |
|
|
|
|
| 252 |
import ctranslate2
|
| 253 |
import transformers
|
| 254 |
try:
|
| 255 |
+
print(f'\n\nLoading model: {self.modelPath}\n\n')
|
| 256 |
kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
|
| 257 |
kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
|
| 258 |
self.roleSystem = {"role": "system", "content": self.system_prompt}
|
|
|
|
| 270 |
try:
|
| 271 |
import torch
|
| 272 |
if torch.cuda.is_available():
|
| 273 |
+
if hasattr(self, "Model") and hasattr(self.Model, "unload_model"):
|
| 274 |
self.Model.unload_model()
|
| 275 |
+
if hasattr(self, "Tokenizer"):
|
|
|
|
| 276 |
del self.Tokenizer
|
| 277 |
+
if hasattr(self, "Model"):
|
| 278 |
del self.Model
|
| 279 |
import gc
|
| 280 |
gc.collect()
|
|
|
|
| 282 |
torch.cuda.empty_cache()
|
| 283 |
except Exception as e:
|
| 284 |
print(traceback.format_exc())
|
| 285 |
+
print(f"\tcuda empty cache, error: {e}")
|
| 286 |
print("release vram end.")
|
| 287 |
except Exception as e:
|
| 288 |
print(traceback.format_exc())
|
| 289 |
+
print(f"Error release vram: {e}")
|
| 290 |
|
| 291 |
def reorganize(self, text: str, max_length: int = 400):
|
|
|
|
| 292 |
result = None
|
| 293 |
try:
|
| 294 |
input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
|
|
|
|
| 296 |
output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
|
| 297 |
target = output[0]
|
| 298 |
result = self.Tokenizer.decode(target.sequences_ids[0])
|
|
|
|
| 299 |
if len(result) > 2:
|
| 300 |
+
if result[0] == '"' and result[-1] == '"':
|
| 301 |
result = result[1:-1]
|
| 302 |
+
elif result[0] == "'" and result[-1] == "'":
|
| 303 |
result = result[1:-1]
|
| 304 |
+
elif result[0] == '「' and result[-1] == '」':
|
| 305 |
result = result[1:-1]
|
| 306 |
+
elif result[0] == '『' and result[-1] == '』':
|
| 307 |
result = result[1:-1]
|
| 308 |
except Exception as e:
|
| 309 |
print(traceback.format_exc())
|
| 310 |
+
print(f"Error reorganize text: {e}")
|
| 311 |
|
| 312 |
return result
|
| 313 |
|
|
|
|
| 336 |
|
| 337 |
tags_df = pd.read_csv(csv_path)
|
| 338 |
sep_tags = load_labels(tags_df)
|
| 339 |
+
self.tag_names, self.rating_indexes, self.general_indexes, self.character_indexes = sep_tags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
model = rt.InferenceSession(model_path)
|
| 341 |
+
_, height, _, _ = model.get_inputs()[0].shape
|
| 342 |
self.model_target_size = height
|
|
|
|
| 343 |
self.last_loaded_repo = model_repo
|
| 344 |
self.model = model
|
| 345 |
|
| 346 |
def prepare_image(self, path):
|
| 347 |
+
image = Image.open(path).convert("RGBA")
|
|
|
|
|
|
|
|
|
|
| 348 |
canvas = Image.new("RGBA", image.size, (255, 255, 255))
|
| 349 |
canvas.alpha_composite(image)
|
| 350 |
image = canvas.convert("RGB")
|
| 351 |
+
|
| 352 |
# Pad image to square
|
| 353 |
image_shape = image.size
|
| 354 |
max_dim = max(image_shape)
|
|
|
|
| 357 |
|
| 358 |
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
|
| 359 |
padded_image.paste(image, (pad_left, pad_top))
|
| 360 |
+
|
| 361 |
# Resize
|
| 362 |
+
if max_dim != self.model_target_size:
|
| 363 |
padded_image = padded_image.resize(
|
| 364 |
+
(self.model_target_size, self.model_target_size),
|
| 365 |
Image.BICUBIC,
|
| 366 |
)
|
| 367 |
+
|
| 368 |
# Convert to numpy array
|
| 369 |
image_array = np.asarray(padded_image, dtype=np.float32)
|
| 370 |
|
|
|
|
| 392 |
llama3_reorganize_model_repo,
|
| 393 |
additional_tags_prepend,
|
| 394 |
additional_tags_append,
|
| 395 |
+
tags_to_remove,
|
| 396 |
tag_results,
|
| 397 |
progress=gr.Progress()
|
| 398 |
):
|
|
|
|
| 402 |
|
| 403 |
gallery_len = len(gallery)
|
| 404 |
print(f"Predict from images: load model: {model_repo}, gallery length: {gallery_len}")
|
| 405 |
+
|
| 406 |
timer = Timer() # Create a timer
|
| 407 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
| 408 |
progressTotal = gallery_len + (1 if llama3_reorganize_model_repo else 0) + 1 # +1 for model load
|
|
|
|
| 412 |
current_progress += 1 / progressTotal
|
| 413 |
progress(current_progress, desc="Initialize wd model finished")
|
| 414 |
timer.checkpoint(f"Initialize wd model")
|
| 415 |
+
|
| 416 |
# Result
|
| 417 |
txt_infos = []
|
| 418 |
output_dir = tempfile.mkdtemp()
|
|
|
|
| 428 |
current_progress += 1 / progressTotal
|
| 429 |
progress(current_progress, desc="Initialize llama3 model finished")
|
| 430 |
timer.checkpoint(f"Initialize llama3 model")
|
| 431 |
+
|
| 432 |
timer.report()
|
| 433 |
|
| 434 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 435 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
| 436 |
+
remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
|
| 437 |
if prepend_list and append_list:
|
| 438 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 439 |
+
|
| 440 |
# Dictionary to track counters for each filename
|
| 441 |
name_counters = defaultdict(int)
|
| 442 |
for idx, value in enumerate(gallery):
|
|
|
|
| 457 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 458 |
|
| 459 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
| 460 |
+
|
| 461 |
# First 4 labels are actually ratings: pick one with argmax
|
| 462 |
ratings_names = [labels[i] for i in self.rating_indexes]
|
| 463 |
rating = dict(ratings_names)
|
| 464 |
+
|
| 465 |
# Then we have general tags: pick any where prediction confidence > threshold
|
| 466 |
general_names = [labels[i] for i in self.general_indexes]
|
| 467 |
|
|
|
|
| 469 |
general_probs = np.array([x[1] for x in general_names])
|
| 470 |
general_thresh = mcut_threshold(general_probs)
|
| 471 |
general_res = dict([x for x in general_names if x[1] > general_thresh])
|
| 472 |
+
|
| 473 |
# Everything else is characters: pick any where prediction confidence > threshold
|
| 474 |
character_names = [labels[i] for i in self.character_indexes]
|
| 475 |
|
|
|
|
| 493 |
final_tags_list = prepend_list + sorted_general_list + append_list
|
| 494 |
if characters_merge_enabled:
|
| 495 |
final_tags_list = character_list + final_tags_list
|
| 496 |
+
|
| 497 |
+
# Apply removal logic
|
| 498 |
+
if remove_list:
|
| 499 |
+
remove_set = set(remove_list)
|
| 500 |
+
final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
|
| 501 |
+
|
| 502 |
sorted_general_strings = ", ".join(final_tags_list).replace("(", "\(").replace(")", "\)")
|
| 503 |
classified_tags, unclassified_tags = classify_tags(final_tags_list)
|
| 504 |
|
|
|
|
| 548 |
# Get file name from lookup
|
| 549 |
taggers_zip.write(info["path"], arcname=info["name"])
|
| 550 |
download.append(downloadZipPath)
|
| 551 |
+
|
| 552 |
if llama3_reorganize:
|
| 553 |
llama3_reorganize.release_vram()
|
| 554 |
+
|
| 555 |
progress(1, desc="Image processing completed")
|
| 556 |
timer.report_all()
|
| 557 |
print("Image prediction is complete.")
|
| 558 |
|
| 559 |
return download, last_sorted_general_strings, last_classified_tags, last_rating, last_character_res, last_general_res, last_unclassified_tags, tag_results
|
| 560 |
+
|
| 561 |
+
# Method to process text files
|
| 562 |
def predict_from_text(
|
| 563 |
self,
|
| 564 |
text_files,
|
| 565 |
llama3_reorganize_model_repo,
|
| 566 |
additional_tags_prepend,
|
| 567 |
additional_tags_append,
|
| 568 |
+
tags_to_remove,
|
| 569 |
progress=gr.Progress()
|
| 570 |
):
|
| 571 |
if not text_files:
|
|
|
|
| 579 |
progressRatio = 0.5 if llama3_reorganize_model_repo else 1.0
|
| 580 |
progressTotal = files_len + (1 if llama3_reorganize_model_repo else 0)
|
| 581 |
current_progress = 0
|
| 582 |
+
|
| 583 |
txt_infos = []
|
| 584 |
output_dir = tempfile.mkdtemp()
|
| 585 |
last_processed_string = ""
|
|
|
|
| 596 |
|
| 597 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 598 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
| 599 |
+
remove_list = [tag.strip() for tag in tags_to_remove.split(",") if tag.strip()] # Parse remove tags
|
| 600 |
if prepend_list and append_list:
|
| 601 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 602 |
|
|
|
|
| 605 |
try:
|
| 606 |
file_path = file_obj.name
|
| 607 |
file_name_base = os.path.splitext(os.path.basename(file_path))[0]
|
| 608 |
+
|
| 609 |
name_counters[file_name_base] += 1
|
| 610 |
if name_counters[file_name_base] > 1:
|
| 611 |
output_file_name = f"{file_name_base}_{name_counters[file_name_base]:02d}.txt"
|
|
|
|
| 614 |
|
| 615 |
with open(file_path, 'r', encoding='utf-8') as f:
|
| 616 |
original_content = f.read()
|
| 617 |
+
|
| 618 |
# Process tags
|
| 619 |
tags_list = [tag.strip() for tag in original_content.split(',') if tag.strip()]
|
| 620 |
+
|
| 621 |
if prepend_list:
|
| 622 |
tags_list = [item for item in tags_list if item not in prepend_list]
|
| 623 |
if append_list:
|
| 624 |
tags_list = [item for item in tags_list if item not in append_list]
|
| 625 |
|
| 626 |
final_tags_list = prepend_list + tags_list + append_list
|
| 627 |
+
|
| 628 |
+
# Apply removal logic
|
| 629 |
+
if remove_list:
|
| 630 |
+
remove_set = set(remove_list)
|
| 631 |
+
final_tags_list = [tag for tag in final_tags_list if tag not in remove_set]
|
| 632 |
+
|
| 633 |
processed_string = ", ".join(final_tags_list)
|
| 634 |
|
| 635 |
current_progress += progressRatio / progressTotal
|
|
|
|
| 648 |
current_progress += progressRatio / progressTotal
|
| 649 |
progress(current_progress, desc=f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 650 |
timer.checkpoint(f"File {idx+1}/{files_len}, llama3 reorganize finished")
|
| 651 |
+
|
| 652 |
txt_file_path = self.create_file(processed_string, output_dir, output_file_name)
|
| 653 |
txt_infos.append({"path": txt_file_path, "name": output_file_name})
|
| 654 |
last_processed_string = processed_string
|
|
|
|
| 674 |
progress(1, desc="Text processing completed")
|
| 675 |
timer.report_all() # Print all recorded times
|
| 676 |
print("Text processing is complete.")
|
| 677 |
+
|
| 678 |
# Return values in the same structure as the image path, with placeholders for unused outputs
|
| 679 |
return download, last_processed_string, "{}", "", "", "", "{}", {}
|
| 680 |
|
|
|
|
| 682 |
if not selected_state:
|
| 683 |
return selected_state
|
| 684 |
|
| 685 |
+
tag_result = tag_results.get(selected_state.value["image"]["path"],
|
| 686 |
+
{"strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}"})
|
|
|
|
| 687 |
|
| 688 |
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
|
| 689 |
|
|
|
|
| 692 |
gallery = []
|
| 693 |
if not image:
|
| 694 |
return gallery, None
|
| 695 |
+
|
| 696 |
gallery.append(image)
|
| 697 |
|
| 698 |
return gallery, None
|
|
|
|
| 714 |
return gallery
|
| 715 |
|
| 716 |
try:
|
| 717 |
+
selected_image_tuple = ast.literal_eval(selected_image) #Use ast.literal_eval to parse text into a tuple.
|
| 718 |
# Remove the selected image from the gallery
|
| 719 |
+
if selected_image_tuple in gallery:
|
| 720 |
+
gallery.remove(selected_image_tuple)
|
| 721 |
except (ValueError, SyntaxError):
|
| 722 |
# Handle cases where the string is not a valid literal
|
| 723 |
print(f"Warning: Could not parse selected_image string: {selected_image}")
|
| 724 |
+
|
| 725 |
return gallery
|
| 726 |
|
| 727 |
|
|
|
|
| 753 |
SWINV2_MODEL_IS_DSV1_REPO,
|
| 754 |
EVA02_LARGE_MODEL_IS_DSV1_REPO,
|
| 755 |
]
|
| 756 |
+
|
| 757 |
llama_list = [
|
| 758 |
META_LLAMA_3_3B_REPO,
|
| 759 |
META_LLAMA_3_8B_REPO,
|
| 760 |
]
|
| 761 |
+
|
| 762 |
+
# Wrapper function to decide which prediction method to call
|
| 763 |
def run_prediction(
|
| 764 |
input_type, gallery, text_files, model_repo, general_thresh,
|
| 765 |
general_mcut_enabled, character_thresh, character_mcut_enabled,
|
| 766 |
characters_merge_enabled, llama3_reorganize_model_repo,
|
| 767 |
+
additional_tags_prepend, additional_tags_append, tags_to_remove,
|
| 768 |
+
tag_results, progress=gr.Progress()
|
| 769 |
):
|
| 770 |
if input_type == 'Image':
|
| 771 |
return predictor.predict_from_images(
|
| 772 |
gallery, model_repo, general_thresh, general_mcut_enabled,
|
| 773 |
character_thresh, character_mcut_enabled, characters_merge_enabled,
|
| 774 |
llama3_reorganize_model_repo, additional_tags_prepend,
|
| 775 |
+
additional_tags_append, tags_to_remove, tag_results, progress
|
| 776 |
)
|
| 777 |
else: # 'Text file (.txt)'
|
| 778 |
# For text files, some parameters are not used, but we must return
|
| 779 |
# a tuple of the same size. `predict_from_text` handles this.
|
| 780 |
return predictor.predict_from_text(
|
| 781 |
text_files, llama3_reorganize_model_repo,
|
| 782 |
+
additional_tags_prepend, additional_tags_append, tags_to_remove, progress
|
| 783 |
)
|
| 784 |
|
| 785 |
with gr.Blocks(title=TITLE, css=css) as demo:
|
|
|
|
| 796 |
value='Image',
|
| 797 |
label="Input Type"
|
| 798 |
)
|
| 799 |
+
|
| 800 |
# Group for image inputs, initially visible
|
| 801 |
with gr.Column(visible=True) as image_inputs_group:
|
| 802 |
with gr.Column(variant="panel"):
|
|
|
|
| 806 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
| 807 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
| 808 |
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
| 809 |
+
|
| 810 |
+
# Group for text file inputs, initially hidden
|
| 811 |
with gr.Column(visible=False) as text_inputs_group:
|
| 812 |
text_files_input = gr.Files(
|
| 813 |
label="Upload .txt files",
|
|
|
|
| 816 |
height=500
|
| 817 |
)
|
| 818 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 819 |
# Image-specific settings
|
| 820 |
model_repo = gr.Dropdown(
|
| 821 |
dropdown_list,
|
|
|
|
| 868 |
with gr.Row():
|
| 869 |
additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
|
| 870 |
additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
|
| 871 |
+
|
| 872 |
+
# NEW: Add the remove tags input box
|
| 873 |
+
tags_to_remove = gr.Text(label="Remove tags (comma split)")
|
| 874 |
+
|
| 875 |
with gr.Row():
|
| 876 |
clear = gr.ClearButton(
|
| 877 |
components=[
|
|
|
|
| 886 |
llama3_reorganize_model_repo,
|
| 887 |
additional_tags_prepend,
|
| 888 |
additional_tags_append,
|
| 889 |
+
tags_to_remove,
|
| 890 |
],
|
| 891 |
variant="secondary",
|
| 892 |
size="lg",
|
|
|
|
| 925 |
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
|
| 926 |
# Event to remove a selected image from the gallery
|
| 927 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 928 |
+
|
| 929 |
+
# Logic to show/hide input groups based on radio selection
|
| 930 |
+
def change_input_type(input_type):
|
| 931 |
+
is_image = (input_type == 'Image')
|
| 932 |
+
return {
|
| 933 |
+
image_inputs_group: gr.update(visible=is_image),
|
| 934 |
+
text_inputs_group: gr.update(visible=not is_image),
|
| 935 |
+
# Also update visibility of image-specific settings
|
| 936 |
+
model_repo: gr.update(visible=is_image),
|
| 937 |
+
general_thresh_row: gr.update(visible=is_image),
|
| 938 |
+
character_thresh_row: gr.update(visible=is_image),
|
| 939 |
+
characters_merge_enabled: gr.update(visible=is_image),
|
| 940 |
+
categorized: gr.update(visible=is_image),
|
| 941 |
+
rating: gr.update(visible=is_image),
|
| 942 |
+
character_res: gr.update(visible=is_image),
|
| 943 |
+
general_res: gr.update(visible=is_image),
|
| 944 |
+
unclassified: gr.update(visible=is_image),
|
| 945 |
+
}
|
| 946 |
+
|
| 947 |
# Connect the radio button to the visibility function
|
| 948 |
input_type_radio.change(
|
| 949 |
fn=change_input_type,
|
|
|
|
| 954 |
categorized, rating, character_res, general_res, unclassified
|
| 955 |
]
|
| 956 |
)
|
| 957 |
+
|
| 958 |
# submit click now calls the wrapper function
|
| 959 |
submit.click(
|
| 960 |
fn=run_prediction,
|
|
|
|
| 971 |
llama3_reorganize_model_repo,
|
| 972 |
additional_tags_prepend,
|
| 973 |
additional_tags_append,
|
| 974 |
+
tags_to_remove,
|
| 975 |
tag_results,
|
| 976 |
],
|
| 977 |
outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results,],
|