avans06 commited on
Commit
32c8def
·
1 Parent(s): 75c6415

feat(ui): improve categorized tags interaction and merge unclassified outputs

Browse files

Interactive UI: Replaced static JSON display for categorized tags with dynamic gr.Dropdown components (multiselect=True). This allows users to easily remove or add tags within specific categories.
Copy Functionality: Added a synchronized read-only Textbox next to each category to enable copying of individual category strings.
Unified Output: Merged "Unclassified" tags into the main "Categorized" display. Unclassified items are now appended to the list (under the "Unclassified" key) instead of being displayed in a separate component.

Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +102 -15
  3. requirements.txt +4 -6
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.34.2
8
  app_file: app.py
9
  pinned: true
10
  ---
 
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.50.0
8
  app_file: app.py
9
  pinned: true
10
  ---
app.py CHANGED
@@ -531,9 +531,19 @@ class Predictor:
531
 
532
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
533
 
 
 
 
 
 
 
 
 
 
 
534
  # Store last result for UI display
535
  last_sorted_general_strings = sorted_general_strings
536
- last_classified_tags = classified_tags
537
  last_rating = rating
538
  last_character_res = character_res
539
  last_general_res = general_res
@@ -703,10 +713,32 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
703
  if not selected_state:
704
  return selected_state
705
 
 
706
  tag_result = tag_results.get(selected_state.value["image"]["path"],
707
- {"strings": "", "classified_tags": "{}", "rating": "", "character_res": "", "general_res": "", "unclassified_tags": "{}"})
708
-
709
- return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
  def append_gallery(gallery: list, image: str):
712
  if gallery is None:
@@ -765,6 +797,10 @@ def main():
765
  all: initial !important;
766
  background: #a8a8a8 !important;
767
  }
 
 
 
 
768
  """
769
  args = parse_args()
770
 
@@ -838,7 +874,7 @@ def main():
838
  with gr.Row():
839
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
840
  remove_button = gr.Button("Remove Selected Image", size="sm")
841
- gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
842
 
843
  # Group for text file inputs, initially hidden
844
  with gr.Column(visible=False) as text_inputs_group:
@@ -901,7 +937,7 @@ def main():
901
  with gr.Row():
902
  additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
903
  additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
904
-
905
  # Add the remove tags input box
906
  with gr.Row():
907
  tags_to_remove = gr.Text(label="Remove tags (comma split)")
@@ -930,14 +966,60 @@ def main():
930
  download_file = gr.File(label="Output (Download)")
931
  sorted_general_strings = gr.Textbox(label="Output (string for last processed item)", show_label=True, show_copy_button=True, lines=5)
932
 
933
- with gr.Accordion("Categorized (tags)", open=False):
934
- categorized = gr.JSON(label="Categorized")
935
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
936
  with gr.Accordion("Detailed Output (for last processed item)", open=False):
937
  rating = gr.Label(label="Rating", visible=True)
938
  character_res = gr.Label(label="Output (characters)", visible=True)
939
  general_res = gr.Label(label="Output (tags)", visible=True)
940
- unclassified = gr.JSON(label="Unclassified (tags)", visible=True)
941
 
942
  with gr.Accordion("Tags Statistics (All files)", open=False):
943
  tags_statistics = gr.Text(
@@ -952,7 +1034,7 @@ def main():
952
  [
953
  download_file,
954
  sorted_general_strings,
955
- categorized,
956
  rating,
957
  character_res,
958
  general_res,
@@ -970,7 +1052,11 @@ def main():
970
  # When the upload button is clicked, add the new images to the gallery
971
  upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
972
  # Event to update the selected image when an image is clicked in the gallery
973
- gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, categorized, rating, character_res, general_res, unclassified])
 
 
 
 
974
  # Event to remove a selected image from the gallery
975
  remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
976
 
@@ -985,7 +1071,8 @@ def main():
985
  general_thresh_row: gr.update(visible=is_image),
986
  character_thresh_row: gr.update(visible=is_image),
987
  characters_merge_enabled: gr.update(visible=is_image),
988
- categorized: gr.update(visible=is_image),
 
989
  rating: gr.update(visible=is_image),
990
  character_res: gr.update(visible=is_image),
991
  general_res: gr.update(visible=is_image),
@@ -999,7 +1086,7 @@ def main():
999
  outputs=[
1000
  image_inputs_group, text_inputs_group, model_repo,
1001
  general_thresh_row, character_thresh_row, characters_merge_enabled,
1002
- categorized, rating, character_res, general_res, unclassified
1003
  ]
1004
  )
1005
 
@@ -1022,7 +1109,7 @@ def main():
1022
  tags_to_remove,
1023
  tag_results,
1024
  ],
1025
- outputs=[download_file, sorted_general_strings, categorized, rating, character_res, general_res, unclassified, tag_results, tags_statistics],
1026
  )
1027
 
1028
  gr.Examples(
 
531
 
532
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
533
 
534
+ # Merge Unclassified into Classified for frontend display
535
+ display_classified = classified_tags.copy()
536
+ if unclassified_tags:
537
+ # If it is a list (common case), put it into the "Unclassified" category
538
+ if isinstance(unclassified_tags, list):
539
+ display_classified["Unclassified"] = unclassified_tags
540
+ # Just to be safe, if it is a dict, use update
541
+ elif isinstance(unclassified_tags, dict):
542
+ display_classified.update(unclassified_tags)
543
+
544
  # Store last result for UI display
545
  last_sorted_general_strings = sorted_general_strings
546
+ last_classified_tags = display_classified # Use the merged result
547
  last_rating = rating
548
  last_character_res = character_res
549
  last_general_res = general_res
 
713
  if not selected_state:
714
  return selected_state
715
 
716
+ # Default unclassified_tags to list (because classifyTags usually returns a list)
717
  tag_result = tag_results.get(selected_state.value["image"]["path"],
718
+ {"strings": "", "classified_tags": {}, "rating": "", "character_res": "", "general_res": "", "unclassified_tags": []})
719
+
720
+ # Retrieve original data
721
+ c_tags = tag_result["classified_tags"]
722
+ u_tags = tag_result["unclassified_tags"]
723
+
724
+ # Error handling: Ensure correct types
725
+ if isinstance(c_tags, str):
726
+ try: c_tags = ast.literal_eval(c_tags)
727
+ except: c_tags = {}
728
+ if isinstance(u_tags, str):
729
+ try: u_tags = ast.literal_eval(u_tags)
730
+ except: u_tags = []
731
+
732
+ # Merge: Copy Classified, and append Unclassified if it exists
733
+ display_classified = c_tags.copy() if isinstance(c_tags, dict) else {}
734
+
735
+ if u_tags:
736
+ if isinstance(u_tags, list):
737
+ display_classified["Unclassified"] = u_tags
738
+ elif isinstance(u_tags, dict):
739
+ display_classified.update(u_tags)
740
+
741
+ return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], display_classified, tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
742
 
743
  def append_gallery(gallery: list, image: str):
744
  if gallery is None:
 
797
  all: initial !important;
798
  background: #a8a8a8 !important;
799
  }
800
+ /* Make the Dropdown options display more compactly */
801
+ .tag-dropdown span.svelte-1f354aw {
802
+ font-family: monospace;
803
+ }
804
  """
805
  args = parse_args()
806
 
 
874
  with gr.Row():
875
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
876
  remove_button = gr.Button("Remove Selected Image", size="sm")
877
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height=500, label="Gallery that displaying a grid of images")
878
 
879
  # Group for text file inputs, initially hidden
880
  with gr.Column(visible=False) as text_inputs_group:
 
937
  with gr.Row():
938
  additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
939
  additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
940
+
941
  # Add the remove tags input box
942
  with gr.Row():
943
  tags_to_remove = gr.Text(label="Remove tags (comma split)")
 
966
  download_file = gr.File(label="Output (Download)")
967
  sorted_general_strings = gr.Textbox(label="Output (string for last processed item)", show_label=True, show_copy_button=True, lines=5)
968
 
969
+ # Use State to store categorized data
970
+ categorized_state = gr.State({})
971
+
972
+ # Wrap the dynamically rendered area with Accordion
973
+ with gr.Accordion("Categorized (tags) - Interactive", open=False) as categorized_accordion:
974
+ # Use @gr.render to dynamically generate UI based on the content of categorized_state
975
+ @gr.render(inputs=categorized_state)
976
+ def render_categorized_tags(categories_data):
977
+ if not categories_data:
978
+ gr.Markdown("No categorized tags to display yet.")
979
+ return
980
+
981
+ for category_name, tags_list in categories_data.items():
982
+ # Ensure tags_list is of type list
983
+ current_tags = tags_list if isinstance(tags_list, list) else str(tags_list).split(',')
984
+ current_tags = [t.strip() for t in current_tags if t.strip()]
985
+
986
+ with gr.Group():
987
+ with gr.Row(variant="compact", equal_height=True):
988
+ # 1. Multiselect Dropdown (Main editing area)
989
+ dd = gr.Dropdown(
990
+ choices=current_tags, # Default choices are the current tags
991
+ value=current_tags, # Default value are the current tags
992
+ label=f"{category_name} ({len(current_tags)})",
993
+ multiselect=True, # Enable multiselect (shows X button)
994
+ allow_custom_value=True, # Allow custom values (add new tags)
995
+ interactive=True,
996
+ scale=5,
997
+ elem_classes=["tag-dropdown"]
998
+ )
999
+
1000
+ # 2. Read-only Textbox (Used to provide a copy button)
1001
+ # Since Dropdown cannot directly copy raw strings, we use this Textbox to "sync display" the string
1002
+ txt_copy = gr.Textbox(
1003
+ value=", ".join(current_tags),
1004
+ label="Copy String",
1005
+ show_copy_button=True, # Copy button is here
1006
+ interactive=False, # Disable manual editing, only sync from Dropdown
1007
+ scale=1,
1008
+ min_width=100,
1009
+ max_lines=1
1010
+ )
1011
+
1012
+ # 3. Event binding: Update Textbox when Dropdown changes
1013
+ def sync_tags_to_text(selected_tags):
1014
+ return ", ".join(selected_tags)
1015
+
1016
+ dd.change(fn=sync_tags_to_text, inputs=dd, outputs=txt_copy)
1017
+
1018
  with gr.Accordion("Detailed Output (for last processed item)", open=False):
1019
  rating = gr.Label(label="Rating", visible=True)
1020
  character_res = gr.Label(label="Output (characters)", visible=True)
1021
  general_res = gr.Label(label="Output (tags)", visible=True)
1022
+ unclassified = gr.JSON(label="Unclassified (tags)", visible=False)
1023
 
1024
  with gr.Accordion("Tags Statistics (All files)", open=False):
1025
  tags_statistics = gr.Text(
 
1034
  [
1035
  download_file,
1036
  sorted_general_strings,
1037
+ categorized_state,
1038
  rating,
1039
  character_res,
1040
  general_res,
 
1052
  # When the upload button is clicked, add the new images to the gallery
1053
  upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
1054
  # Event to update the selected image when an image is clicked in the gallery
1055
+ gallery.select(
1056
+ get_selection_from_gallery,
1057
+ inputs=[gallery, tag_results],
1058
+ outputs=[selected_image, sorted_general_strings, categorized_state, rating, character_res, general_res, unclassified]
1059
+ )
1060
  # Event to remove a selected image from the gallery
1061
  remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
1062
 
 
1071
  general_thresh_row: gr.update(visible=is_image),
1072
  character_thresh_row: gr.update(visible=is_image),
1073
  characters_merge_enabled: gr.update(visible=is_image),
1074
+ # Update visibility of categorized_accordion
1075
+ categorized_accordion: gr.update(visible=is_image),
1076
  rating: gr.update(visible=is_image),
1077
  character_res: gr.update(visible=is_image),
1078
  general_res: gr.update(visible=is_image),
 
1086
  outputs=[
1087
  image_inputs_group, text_inputs_group, model_repo,
1088
  general_thresh_row, character_thresh_row, characters_merge_enabled,
1089
+ categorized_accordion, rating, character_res, general_res, unclassified
1090
  ]
1091
  )
1092
 
 
1109
  tags_to_remove,
1110
  tag_results,
1111
  ],
1112
+ outputs=[download_file, sorted_general_strings, categorized_state, rating, character_res, general_res, unclassified, tag_results, tags_statistics],
1113
  )
1114
 
1115
  gr.Examples(
requirements.txt CHANGED
@@ -1,16 +1,14 @@
1
- --extra-index-url https://download.pytorch.org/whl/cu124
2
 
3
  pillow>=9.0.0
4
  onnxruntime>=1.12.0
5
  huggingface-hub
6
 
7
- gradio==5.34.2
8
  pandas
9
 
10
  # for reorganize WD Tagger into a readable article by Llama3 model.
11
  transformers>=4.45.2
12
  ctranslate2>=4.4.0
13
- torch==2.5.0+cu124; sys_platform != 'darwin'
14
- torchvision==0.20.0+cu124; sys_platform != 'darwin'
15
- torch==2.5.0; sys_platform == 'darwin'
16
- torchvision==0.20.0; sys_platform == 'darwin'
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu128
2
 
3
  pillow>=9.0.0
4
  onnxruntime>=1.12.0
5
  huggingface-hub
6
 
7
+ gradio==5.50.0
8
  pandas
9
 
10
  # for reorganize WD Tagger into a readable article by Llama3 model.
11
  transformers>=4.45.2
12
  ctranslate2>=4.4.0
13
+ torch
14
+ torchvision