Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont, ImageOps | |
| import os | |
| from .constants import KSORT_IMAGE_DIR | |
| from .constants import COLOR1, COLOR2, COLOR3, COLOR4 | |
| from .vote_utils import save_any_image | |
| from .utils import disable_btn, enable_btn, invisible_btn | |
| from .upload import create_remote_directory, upload_ssh_all, upload_ssh_data | |
| import json | |
def reset_level(Top_btn):
    """Map a "Top N" button label to its zero-based vote level.

    Args:
        Top_btn: Button label, one of "Top 1", "Top 2", "Top 3", "Top 4".

    Returns:
        int: 0 for "Top 1", 1 for "Top 2", 2 for "Top 3", 3 for "Top 4".

    Raises:
        ValueError: If ``Top_btn`` is not one of the recognised labels
            (an unknown label would otherwise leave the level undefined).
    """
    levels = {"Top 1": 0, "Top 2": 1, "Top 3": 2, "Top 4": 3}
    try:
        return levels[Top_btn]
    except KeyError:
        raise ValueError(f"Unknown rank button label: {Top_btn!r}") from None
def reset_rank(windows, rank, vote_level):
    """Store ``vote_level`` in the slot of ``rank`` belonging to ``windows``.

    ``windows`` is one of "Model A".."Model D" (slots 0..3). ``rank`` is
    modified in place and also returned; any other label leaves it untouched.
    """
    window_labels = ("Model A", "Model B", "Model C", "Model D")
    if windows in window_labels:
        rank[window_labels.index(windows)] = vote_level
    return rank
def reset_btn_rank(windows, rank, btn, vote_level):
    """Apply a numbered rank button ("1".."4") to one model window.

    The button text maps to a zero-based level ("1" -> 0, ..., "4" -> 3).
    When ``windows`` is "Model A".."Model D" and ``btn`` is recognised, that
    model's slot in ``rank`` is set to the level (in place). A recognised
    ``btn`` also replaces ``vote_level``; otherwise ``vote_level`` is returned
    unchanged.

    Returns:
        tuple: ``(rank, vote_level)``.
    """
    window_labels = ("Model A", "Model B", "Model C", "Model D")
    button_labels = ("1", "2", "3", "4")
    if btn in button_labels:
        level = button_labels.index(btn)
        if windows in window_labels:
            rank[window_labels.index(windows)] = level
        vote_level = level
    return (rank, vote_level)
def reset_vote_text(rank):
    """Render ``rank`` for display: each entry one-based, each followed by a space.

    Unranked entries (``None``) appear as "None".
    """
    return "".join(
        (str(level) if level is None else str(level + 1)) + " "
        for level in rank
    )
def clear_rank(rank, vote_level):
    """Blank every slot of ``rank`` (in place) and reset the vote level.

    Returns:
        tuple: ``(rank, 0)`` where ``rank`` is the same list, now all ``None``.
    """
    rank[:] = [None] * len(rank)
    return rank, 0
def revote_windows(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, vote_level):
    """Restart voting: pass the four images through, blank ``rank`` in place, reset level to 0.

    Returns:
        tuple: ``(generate_ig0, generate_ig1, generate_ig2, generate_ig3, rank, 0)``.
    """
    rank[:] = [None] * len(rank)
    images = (generate_ig0, generate_ig1, generate_ig2, generate_ig3)
    return (*images, rank, 0)
def reset_submit(rank):
    """Return ``enable_btn`` once every entry of ``rank`` is set, else ``disable_btn``."""
    has_unranked = any(level is None for level in rank)
    return disable_btn if has_unranked else enable_btn
def reset_mode(mode):
    """Build the 25 Gradio component updates used when switching voting mode.

    Output layout: 5 updates, then 16 updates, then 3 updates, then one hidden
    Textbox whose value names the other mode. The mapping of these groups to
    concrete components is defined by the caller.

    - "Best": first 5 hidden and disabled; next 16 and last 3 visible and
      interactive; Textbox value "Rank".
    - "Rank": first 5 visible and interactive; next 16 hidden and disabled;
      of the last 3, two are visible but non-interactive and the third is
      visible and interactive; Textbox value "Best".

    Raises:
        ValueError: For any other ``mode``.
    """
    def _update(visible, interactive):
        return gr.update(visible=visible, interactive=interactive)

    if mode == "Best":
        head = (_update(False, False),) * 5
        middle = (_update(True, True),) * 16
        tail = (_update(True, True),) * 3
        mode_box = gr.Textbox(value="Rank", visible=False, interactive=False)
    elif mode == "Rank":
        head = (_update(True, True),) * 5
        middle = (_update(False, False),) * 16
        tail = (_update(True, False),) * 2 + (_update(True, True),)
        mode_box = gr.Textbox(value="Best", visible=False, interactive=False)
    else:
        raise ValueError("Undefined mode")
    return head + middle + tail + (mode_box,)
def reset_chatbot(mode, generate_ig0, generate_ig1, generate_ig2, generate_ig3):
    """Return the four images unchanged as a tuple; ``mode`` is accepted but unused."""
    images = (generate_ig0, generate_ig1, generate_ig2, generate_ig3)
    return images
def get_json_filename(conv_id):
    """Return the path of ``information.json`` for conversation ``conv_id``.

    The directory ``<KSORT_IMAGE_DIR>/<conv_id>/json/`` is created if it does
    not already exist. The resulting path is printed before being returned.
    """
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/json/'
    # exist_ok=True removes the exists()/makedirs() race: a concurrent request
    # creating the same directory no longer raises FileExistsError.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "information.json")
    print(output_file)
    return output_file
def get_img_filename(conv_id, i):
    """Return the path ``<KSORT_IMAGE_DIR>/<conv_id>/image/<i>.jpg``.

    The image directory is created if it does not already exist. The resulting
    path is printed before being returned.
    """
    output_dir = f'{KSORT_IMAGE_DIR}/{conv_id}/image/'
    # exist_ok=True removes the exists()/makedirs() race: a concurrent request
    # creating the same directory no longer raises FileExistsError.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{i}.jpg")
    print(output_file)
    return output_file
def vote_submit(states, textbox, rank, request: gr.Request):
    """Save each state's output image locally and append one JSON line describing the vote.

    Images go to ``get_img_filename(conv_id, index)``. The record, holding
    model names, ``rank`` and the prompt, is appended to
    ``get_json_filename(conv_id)``. ``conv_id`` is taken from ``states[0]``.
    ``request`` is accepted but unused.
    """
    conv_id = states[0].conv_id
    for index, state in enumerate(states):
        image_path = get_img_filename(conv_id, index)
        save_any_image(state.output, image_path)
    with open(get_json_filename(conv_id), "a") as log_file:
        record = {
            "models_name": [state.model_name for state in states],
            "img_rank": list(rank),
            "prompt": [textbox],
        }
        log_file.write(json.dumps(record) + "\n")
def vote_ssh_submit(states, textbox, rank, user_name, user_institution):
    """Upload an image vote remotely, then update the model skill ratings.

    A remote directory is created for the conversation of ``states[0]``. The
    states and a ``result.json`` record (model names, rank, prompt, user info)
    are passed to ``upload_ssh_all``. Afterwards ``update_skill`` is called
    with the rank and model names.
    """
    conv_id = states[0].conv_id
    remote_dir = create_remote_directory(conv_id)
    record = {
        "models_name": [state.model_name for state in states],
        "img_rank": list(rank),
        "prompt": [textbox],
        "user_info": {"name": [user_name], "institution": [user_institution]},
    }
    result_path = os.path.join(remote_dir, "result.json")
    upload_ssh_all(states, remote_dir, record, result_path)

    # Imported lazily, after the upload, exactly as before.
    from .update_skill import update_skill
    update_skill(rank, [state.model_name for state in states])
def vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution):
    """Upload a video vote record remotely, then update the video skill ratings.

    A remote directory (``video=True``) is created for the conversation of
    ``states[0]``. A ``result.json`` record (model names, rank, prompt, prompt
    path, each state's ``output`` as the video path, user info) is sent via
    ``upload_ssh_data``. Afterwards ``update_skill_video`` is called with the
    rank and model names.
    """
    conv_id = states[0].conv_id
    remote_dir = create_remote_directory(conv_id, video=True)
    record = {
        "models_name": [state.model_name for state in states],
        "video_rank": list(rank),
        "prompt": [textbox],
        "prompt_path": [prompt_path],
        "video_path": [state.output for state in states],
        "user_info": {"name": [user_name], "institution": [user_institution]},
    }
    upload_ssh_data(record, os.path.join(remote_dir, "result.json"))

    # Imported lazily, after the upload, exactly as before.
    from .update_skill_video import update_skill_video
    update_skill_video(rank, [state.model_name for state in states])
def submit_response_igm(
    state0, state1, state2, state3,
    model_selector0, model_selector1, model_selector2, model_selector3,
    textbox, rank, user_name, user_institution, request: gr.Request
):
    """Submit an image vote via ``vote_ssh_submit`` and reveal the four model names.

    Returns 11 outputs: six ``disable_btn``, four visible Markdown labels, and
    one more ``disable_btn``. When ``model_selector0`` is the empty string, each
    label reads "### Model <A-D>: <name>", using the second "_"-separated field
    of the model name. Otherwise the raw model name is shown.
    ``model_selector1..3`` and ``request`` are unused.
    """
    states = [state0, state1, state2, state3]
    vote_ssh_submit(states, textbox, rank, user_name, user_institution)
    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            text = f"### Model {letter}: {state.model_name.split('_')[1]}"
        else:
            text = state.model_name
        labels.append(gr.Markdown(text, visible=True))
    return (disable_btn,) * 6 + tuple(labels) + (disable_btn,)
def submit_response_vg(
    state0, state1, state2, state3,
    model_selector0, model_selector1, model_selector2, model_selector3,
    textbox, prompt_path, rank, user_name, user_institution, request: gr.Request
):
    """Submit a video vote via ``vote_video_ssh_submit`` and reveal the four model names.

    Returns 11 outputs: six ``disable_btn``, four visible Markdown labels, and
    one more ``disable_btn``. When ``model_selector0`` is the empty string, each
    label reads "### Model <A-D>: <name>", using the second "_"-separated field
    of the model name. Otherwise the raw model name is shown.
    ``model_selector1..3`` and ``request`` are unused.
    """
    states = [state0, state1, state2, state3]
    vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution)
    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            text = f"### Model {letter}: {state.model_name.split('_')[1]}"
        else:
            text = state.model_name
        labels.append(gr.Markdown(text, visible=True))
    return (disable_btn,) * 6 + tuple(labels) + (disable_btn,)
def submit_response_rank_igm(
    state0, state1, state2, state3,
    model_selector0, model_selector1, model_selector2, model_selector3,
    textbox, rank, right_vote_text, user_name, user_institution, request: gr.Request
):
    """Submit a rank-mode image vote if it was validated, then reveal model names.

    ``rank`` is printed first. If ``right_vote_text`` is not "right", nothing is
    submitted and 19 ``enable_btn``, "wrong" and four hidden empty Markdowns are
    returned. Otherwise the vote goes through ``vote_ssh_submit`` and 19
    ``disable_btn``, "wrong" and four visible Markdown labels are returned. The
    labels are anonymised as "### Model <A-D>: <name>" when ``model_selector0``
    is empty. ``model_selector1..3`` and ``request`` are unused.
    """
    print(rank)
    if right_vote_text != "right":
        return (enable_btn,) * 19 + ("wrong",) + (gr.Markdown("", visible=False),) * 4

    states = [state0, state1, state2, state3]
    vote_ssh_submit(states, textbox, rank, user_name, user_institution)
    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            text = f"### Model {letter}: {state.model_name.split('_')[1]}"
        else:
            text = state.model_name
        labels.append(gr.Markdown(text, visible=True))
    return (disable_btn,) * 19 + ("wrong",) + tuple(labels)
def submit_response_rank_vg(
    state0, state1, state2, state3,
    model_selector0, model_selector1, model_selector2, model_selector3,
    textbox, prompt_path, rank, right_vote_text, user_name, user_institution, request: gr.Request
):
    """Submit a rank-mode video vote if it was validated, then reveal model names.

    ``rank`` is printed first. If ``right_vote_text`` is not "right", nothing is
    submitted and 19 ``enable_btn``, "wrong" and four hidden empty Markdowns are
    returned. Otherwise the vote goes through ``vote_video_ssh_submit`` and 19
    ``disable_btn``, "wrong" and four visible Markdown labels are returned. The
    labels are anonymised as "### Model <A-D>: <name>" when ``model_selector0``
    is empty. ``model_selector1..3`` and ``request`` are unused.
    """
    print(rank)
    if right_vote_text != "right":
        return (enable_btn,) * 19 + ("wrong",) + (gr.Markdown("", visible=False),) * 4

    states = [state0, state1, state2, state3]
    vote_video_ssh_submit(states, textbox, prompt_path, rank, user_name, user_institution)
    anonymous = model_selector0 == ""
    labels = []
    for letter, state in zip("ABCD", states):
        if anonymous:
            text = f"### Model {letter}: {state.model_name.split('_')[1]}"
        else:
            text = state.model_name
        labels.append(gr.Markdown(text, visible=True))
    return (disable_btn,) * 19 + ("wrong",) + tuple(labels)
def text_response_rank_igm(generate_ig0, generate_ig1, generate_ig2, generate_ig3, Top1_text, Top2_text, Top3_text, Top4_text, vote_textbox):
    """Parse a typed ranking and draw a rank badge on each of the four images.

    Every digit in ``vote_textbox`` is taken in order. Exactly four digits are
    required, each in 1-4 (repeated values are accepted). On success, each image
    is:
      - resized to 512x512,
      - framed with a 10-pixel border in its rank colour (COLOR1..COLOR4),
      - labelled with the matching ``TopN_text`` at (180, 25).

    Returns:
        list: On success, the four decorated RGB ``PIL.Image`` objects followed
        by the display string (each digit followed by a space), ``"right"``,
        and the ranks as ints. On invalid input, the four original images
        followed by ``"error rank"``, ``"wrong"`` and ``[None, None, None, None]``.
    """
    rank_list = [char for char in vote_textbox if char.isdigit()]
    generate_ig = [generate_ig0, generate_ig1, generate_ig2, generate_ig3]
    # Validate everything up front so no image work is done for bad input.
    if len(rank_list) != 4 or any(r not in ('1', '2', '3', '4') for r in rank_list):
        return generate_ig + ["error rank"] + ["wrong"] + [[None, None, None, None]]

    # Rank digit -> (border/text colour, label text).
    styles = {
        '1': (COLOR1, Top1_text),
        '2': (COLOR2, Top2_text),
        '3': (COLOR3, Top3_text),
        '4': (COLOR4, Top4_text),
    }
    border_size = 10
    text_position = (180, 25)
    font = ImageFont.truetype("./serve/Arial.ttf", 66)  # loaded once, not per image

    chatbot = []
    for image, digit in zip(generate_ig, rank_list):
        color, label = styles[digit]
        base_image = Image.fromarray(image).convert("RGBA")
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        base_image = base_image.resize((512, 512), Image.LANCZOS)
        base_image = ImageOps.expand(base_image, border=border_size, fill=color)
        ImageDraw.Draw(base_image).text(text_position, label, font=font, fill=color)
        chatbot.append(base_image.convert("RGB"))

    rank_str = "".join(f"{digit} " for digit in rank_list)
    rank = [int(digit) for digit in rank_list]
    return chatbot + [rank_str] + ["right"] + [rank]
def text_response_rank_vg(vote_textbox):
    """Validate a typed ranking for the four videos.

    Digits are pulled from ``vote_textbox`` in order. There must be exactly
    four of them, each between 1 and 4 (repeats allowed).

    Returns:
        list: ``[rank_str, "right", ranks]`` on success, where ``rank_str`` is
        each digit followed by a space and ``ranks`` are the digits as ints;
        otherwise ``["error rank", "wrong", [None, None, None, None]]``.
    """
    allowed = {'1', '2', '3', '4'}
    digits = [ch for ch in vote_textbox if ch.isdigit()]
    if len(digits) != 4 or not all(ch in allowed for ch in digits):
        return ["error rank", "wrong", [None, None, None, None]]
    return [" ".join(digits) + " ", "right", [int(ch) for ch in digits]]
def add_foreground(image, vote_level, Top1_text, Top2_text, Top3_text, Top4_text):
    """Draw a rank badge on ``image`` for the given zero-based ``vote_level``.

    The image is:
      - converted via ``Image.fromarray``,
      - resized to 512x512,
      - framed with a 10-pixel border in the level's colour (COLOR1..COLOR4),
      - labelled with the matching ``TopN_text`` at (180, 25) in that colour.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        vote_level: 0..3, selecting Top1..Top4 styling.

    Returns:
        PIL.Image.Image: The decorated image in RGB mode.

    Raises:
        ValueError: If ``vote_level`` is not 0, 1, 2 or 3 (the border colour
            would otherwise be undefined).
    """
    styles = {
        0: (COLOR1, Top1_text),
        1: (COLOR2, Top2_text),
        2: (COLOR3, Top3_text),
        3: (COLOR4, Top4_text),
    }
    if vote_level not in styles:
        raise ValueError(f"vote_level must be 0, 1, 2 or 3, got {vote_level!r}")
    color, label = styles[vote_level]

    base_image = Image.fromarray(image).convert("RGBA")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    base_image = base_image.resize((512, 512), Image.LANCZOS)
    border_size = 10
    base_image = ImageOps.expand(base_image, border=border_size, fill=color)
    font = ImageFont.truetype("./serve/Arial.ttf", 66)
    ImageDraw.Draw(base_image).text((180, 25), label, font=font, fill=color)
    return base_image.convert("RGB")
def add_green_border(image):
    """Return ``image`` framed by a 10-pixel pure-green (0, 255, 0) border via ``ImageOps.expand``."""
    green = (0, 255, 0)
    return ImageOps.expand(image, border=10, fill=green)
def check_textbox(textbox):
    """Return ``False`` only when ``textbox`` equals the empty string, else ``True``.

    Whitespace-only text counts as non-empty.
    """
    return not (textbox == "")