Spaces:
Running
Running
Change the global tag_results variable to use Gradio's State for execution.
Browse files
app.py
CHANGED
|
@@ -69,8 +69,6 @@ kaomojis = [
|
|
| 69 |
"||_||",
|
| 70 |
]
|
| 71 |
|
| 72 |
-
tag_results = {}
|
| 73 |
-
|
| 74 |
|
| 75 |
def parse_args() -> argparse.Namespace:
|
| 76 |
parser = argparse.ArgumentParser()
|
|
@@ -350,7 +348,9 @@ class Predictor:
|
|
| 350 |
llama3_reorganize_model_repo,
|
| 351 |
additional_tags_prepend,
|
| 352 |
additional_tags_append,
|
|
|
|
| 353 |
):
|
|
|
|
| 354 |
self.load_model(model_repo)
|
| 355 |
# Result
|
| 356 |
txt_infos = []
|
|
@@ -363,16 +363,15 @@ class Predictor:
|
|
| 363 |
character_res = None
|
| 364 |
general_res = None
|
| 365 |
|
| 366 |
-
tag_results.clear()
|
| 367 |
-
|
| 368 |
if llama3_reorganize_model_repo:
|
|
|
|
| 369 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
| 370 |
|
| 371 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 372 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
| 373 |
if prepend_list and append_list:
|
| 374 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 375 |
-
|
| 376 |
for idx, value in enumerate(gallery):
|
| 377 |
try:
|
| 378 |
image_path = value[0]
|
|
@@ -382,6 +381,7 @@ class Predictor:
|
|
| 382 |
|
| 383 |
input_name = self.model.get_inputs()[0].name
|
| 384 |
label_name = self.model.get_outputs()[0].name
|
|
|
|
| 385 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 386 |
|
| 387 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
@@ -429,6 +429,7 @@ class Predictor:
|
|
| 429 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
|
| 430 |
|
| 431 |
if llama3_reorganize_model_repo:
|
|
|
|
| 432 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
| 433 |
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
| 434 |
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
|
@@ -458,9 +459,11 @@ class Predictor:
|
|
| 458 |
llama3_reorganize.release_vram()
|
| 459 |
del llama3_reorganize
|
| 460 |
|
| 461 |
-
|
|
|
|
|
|
|
| 462 |
|
| 463 |
-
def get_selection_from_gallery(gallery: list, selected_state: gr.SelectData):
|
| 464 |
if not selected_state:
|
| 465 |
return selected_state
|
| 466 |
|
|
@@ -627,14 +630,15 @@ def main():
|
|
| 627 |
general_res,
|
| 628 |
]
|
| 629 |
)
|
| 630 |
-
|
|
|
|
| 631 |
# Define the event listener to add the uploaded image to the gallery
|
| 632 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
| 633 |
# When the upload button is clicked, add the new images to the gallery
|
| 634 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 635 |
# Event to update the selected image when an image is clicked in the gallery
|
| 636 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 637 |
-
gallery.select(get_selection_from_gallery, inputs=gallery, outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
|
| 638 |
# Event to remove a selected image from the gallery
|
| 639 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 640 |
|
|
@@ -651,8 +655,9 @@ def main():
|
|
| 651 |
llama3_reorganize_model_repo,
|
| 652 |
additional_tags_prepend,
|
| 653 |
additional_tags_append,
|
|
|
|
| 654 |
],
|
| 655 |
-
outputs=[download_file, sorted_general_strings, rating, character_res, general_res],
|
| 656 |
)
|
| 657 |
|
| 658 |
gr.Examples(
|
|
@@ -667,7 +672,7 @@ def main():
|
|
| 667 |
],
|
| 668 |
)
|
| 669 |
|
| 670 |
-
demo.queue(max_size=
|
| 671 |
demo.launch(inbrowser=True)
|
| 672 |
|
| 673 |
|
|
|
|
| 69 |
"||_||",
|
| 70 |
]
|
| 71 |
|
|
|
|
|
|
|
| 72 |
|
| 73 |
def parse_args() -> argparse.Namespace:
|
| 74 |
parser = argparse.ArgumentParser()
|
|
|
|
| 348 |
llama3_reorganize_model_repo,
|
| 349 |
additional_tags_prepend,
|
| 350 |
additional_tags_append,
|
| 351 |
+
tag_results,
|
| 352 |
):
|
| 353 |
+
print(f"Predict load model: {model_repo}, gallery length: {len(gallery)}")
|
| 354 |
self.load_model(model_repo)
|
| 355 |
# Result
|
| 356 |
txt_infos = []
|
|
|
|
| 363 |
character_res = None
|
| 364 |
general_res = None
|
| 365 |
|
|
|
|
|
|
|
| 366 |
if llama3_reorganize_model_repo:
|
| 367 |
+
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
| 368 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
| 369 |
|
| 370 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
| 371 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
| 372 |
if prepend_list and append_list:
|
| 373 |
append_list = [item for item in append_list if item not in prepend_list]
|
| 374 |
+
|
| 375 |
for idx, value in enumerate(gallery):
|
| 376 |
try:
|
| 377 |
image_path = value[0]
|
|
|
|
| 381 |
|
| 382 |
input_name = self.model.get_inputs()[0].name
|
| 383 |
label_name = self.model.get_outputs()[0].name
|
| 384 |
+
print(f"Gallery {idx}: Starting run wd model...")
|
| 385 |
preds = self.model.run([label_name], {input_name: image})[0]
|
| 386 |
|
| 387 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
|
|
| 429 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
|
| 430 |
|
| 431 |
if llama3_reorganize_model_repo:
|
| 432 |
+
print(f"Starting reorganize with llama3...")
|
| 433 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
| 434 |
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
| 435 |
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
|
|
|
| 459 |
llama3_reorganize.release_vram()
|
| 460 |
del llama3_reorganize
|
| 461 |
|
| 462 |
+
print("Predict is complete.")
|
| 463 |
+
|
| 464 |
+
return download, sorted_general_strings, rating, character_res, general_res, tag_results
|
| 465 |
|
| 466 |
+
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
| 467 |
if not selected_state:
|
| 468 |
return selected_state
|
| 469 |
|
|
|
|
| 630 |
general_res,
|
| 631 |
]
|
| 632 |
)
|
| 633 |
+
|
| 634 |
+
tag_results = gr.State({})
|
| 635 |
# Define the event listener to add the uploaded image to the gallery
|
| 636 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
| 637 |
# When the upload button is clicked, add the new images to the gallery
|
| 638 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
| 639 |
# Event to update the selected image when an image is clicked in the gallery
|
| 640 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
| 641 |
+
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
|
| 642 |
# Event to remove a selected image from the gallery
|
| 643 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
| 644 |
|
|
|
|
| 655 |
llama3_reorganize_model_repo,
|
| 656 |
additional_tags_prepend,
|
| 657 |
additional_tags_append,
|
| 658 |
+
tag_results,
|
| 659 |
],
|
| 660 |
+
outputs=[download_file, sorted_general_strings, rating, character_res, general_res, tag_results,],
|
| 661 |
)
|
| 662 |
|
| 663 |
gr.Examples(
|
|
|
|
| 672 |
],
|
| 673 |
)
|
| 674 |
|
| 675 |
+
demo.queue(max_size=2)
|
| 676 |
demo.launch(inbrowser=True)
|
| 677 |
|
| 678 |
|